Adding an example of jaxdecomp implementation

This commit is contained in:
EiffL 2022-11-26 17:27:14 +01:00
parent 6644b35d71
commit 6ca4c9191e
5 changed files with 166 additions and 192 deletions

View file

@ -3,19 +3,23 @@ from jax.experimental.maps import xmap
import numpy as np
import jax.numpy as jnp
from functools import partial
import jaxdecomp
def fftk(shape, symmetric=False, dtype=np.float32, comms=None):
def fftk(shape, symmetric=False, dtype=np.float32, sharding_info=None):
""" Return k_vector given a shape (nc, nc, nc)
"""
k = []
if comms is not None:
nx = comms[0].Get_size()
ix = comms[0].Get_rank()
ny = comms[1].Get_size()
iy = comms[1].Get_rank()
shape = [shape[0]*nx, shape[1]*ny] + list(shape[2:])
if sharding_info is not None:
nx = sharding_info.pdims[1]
ny = sharding_info.pdims[0]
# nx = sharding_info[0].Get_size()
# ix = sharding_info[0].Get_rank()
# ny = sharding_info[1].Get_size()
# iy = sharding_info[1].Get_rank()
ix = sharding_info.rank
iy = 0
shape = sharding_info.global_shape
for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d])
@ -24,10 +28,10 @@ def fftk(shape, symmetric=False, dtype=np.float32, comms=None):
if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1]
if (comms is not None) and d == 0:
if (sharding_info is not None) and d == 0:
kd = kd.reshape([nx, -1])[ix]
if (comms is not None) and d == 1:
if (sharding_info is not None) and d == 1:
kd = kd.reshape([ny, -1])[iy]
k.append(kd.astype(dtype))
@ -42,10 +46,9 @@ def apply_gradient_laplace(kfield, kvec):
kx, ky, kz = kvec
kk = (kx**2 + ky**2 + kz**2)
kernel = jnp.where(kk == 0, 1., 1./kk)
return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
kfield * kernel * 1j * 1 / 6.0 *
(8 * jnp.sin(kz) - jnp.sin(2 * kz)),
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))], axis=-1)
return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)),
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)),
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky))], axis=-1)
def cic_compensation(kvec):