temp commit

This commit is contained in:
Wassim KABALAN 2024-04-19 01:11:25 +02:00
parent 6ca4c9191e
commit 055ceedb7e
5 changed files with 220 additions and 110 deletions

View file

@ -17,13 +17,13 @@ def fftk(shape, symmetric=False, dtype=np.float32, sharding_info=None):
# ix = sharding_info[0].Get_rank()
# ny = sharding_info[1].Get_size()
# iy = sharding_info[1].Get_rank()
ix = sharding_info.rank
ix = sharding_info.rank % nx
iy = 0
shape = sharding_info.global_shape
for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d])
kd *= 2 * np.pi
kd = jnp.fft.fftfreq(shape[d])
kd *= 2 * jnp.pi
if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1]
@ -38,12 +38,8 @@ def fftk(shape, symmetric=False, dtype=np.float32, sharding_info=None):
return k
@partial(xmap,
in_axes=[['x', 'y', ...],
[['x'], ['y'], [...]]],
out_axes=['x', 'y', ...])
def apply_gradient_laplace(kfield, kvec):
kx, ky, kz = kvec
@jax.jit
def apply_gradient_laplace(kfield, kx, ky, kz):
kk = (kx**2 + ky**2 + kz**2)
kernel = jnp.where(kk == 0, 1., 1./kk)
return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)),