mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
temp commit
This commit is contained in:
parent
6ca4c9191e
commit
055ceedb7e
5 changed files with 220 additions and 110 deletions
|
@ -17,13 +17,13 @@ def fftk(shape, symmetric=False, dtype=np.float32, sharding_info=None):
|
|||
# ix = sharding_info[0].Get_rank()
|
||||
# ny = sharding_info[1].Get_size()
|
||||
# iy = sharding_info[1].Get_rank()
|
||||
ix = sharding_info.rank
|
||||
ix = sharding_info.rank % nx
|
||||
iy = 0
|
||||
shape = sharding_info.global_shape
|
||||
|
||||
for d in range(len(shape)):
|
||||
kd = np.fft.fftfreq(shape[d])
|
||||
kd *= 2 * np.pi
|
||||
kd = jnp.fft.fftfreq(shape[d])
|
||||
kd *= 2 * jnp.pi
|
||||
|
||||
if symmetric and d == len(shape) - 1:
|
||||
kd = kd[:shape[d] // 2 + 1]
|
||||
|
@ -38,12 +38,8 @@ def fftk(shape, symmetric=False, dtype=np.float32, sharding_info=None):
|
|||
return k
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=[['x', 'y', ...],
|
||||
[['x'], ['y'], [...]]],
|
||||
out_axes=['x', 'y', ...])
|
||||
def apply_gradient_laplace(kfield, kvec):
|
||||
kx, ky, kz = kvec
|
||||
@jax.jit
|
||||
def apply_gradient_laplace(kfield, kx, ky, kz):
|
||||
kk = (kx**2 + ky**2 + kz**2)
|
||||
kernel = jnp.where(kk == 0, 1., 1./kk)
|
||||
return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue