mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
Adding an example of jaxdecomp implementation
This commit is contained in:
parent
6644b35d71
commit
6ca4c9191e
5 changed files with 166 additions and 192 deletions
|
@ -3,19 +3,23 @@ from jax.experimental.maps import xmap
|
|||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
from functools import partial
|
||||
import jaxdecomp
|
||||
|
||||
|
||||
def fftk(shape, symmetric=False, dtype=np.float32, comms=None):
|
||||
def fftk(shape, symmetric=False, dtype=np.float32, sharding_info=None):
|
||||
""" Return k_vector given a shape (nc, nc, nc)
|
||||
"""
|
||||
k = []
|
||||
|
||||
if comms is not None:
|
||||
nx = comms[0].Get_size()
|
||||
ix = comms[0].Get_rank()
|
||||
ny = comms[1].Get_size()
|
||||
iy = comms[1].Get_rank()
|
||||
shape = [shape[0]*nx, shape[1]*ny] + list(shape[2:])
|
||||
if sharding_info is not None:
|
||||
nx = sharding_info.pdims[1]
|
||||
ny = sharding_info.pdims[0]
|
||||
# nx = sharding_info[0].Get_size()
|
||||
# ix = sharding_info[0].Get_rank()
|
||||
# ny = sharding_info[1].Get_size()
|
||||
# iy = sharding_info[1].Get_rank()
|
||||
ix = sharding_info.rank
|
||||
iy = 0
|
||||
shape = sharding_info.global_shape
|
||||
|
||||
for d in range(len(shape)):
|
||||
kd = np.fft.fftfreq(shape[d])
|
||||
|
@ -24,10 +28,10 @@ def fftk(shape, symmetric=False, dtype=np.float32, comms=None):
|
|||
if symmetric and d == len(shape) - 1:
|
||||
kd = kd[:shape[d] // 2 + 1]
|
||||
|
||||
if (comms is not None) and d == 0:
|
||||
if (sharding_info is not None) and d == 0:
|
||||
kd = kd.reshape([nx, -1])[ix]
|
||||
|
||||
if (comms is not None) and d == 1:
|
||||
if (sharding_info is not None) and d == 1:
|
||||
kd = kd.reshape([ny, -1])[iy]
|
||||
|
||||
k.append(kd.astype(dtype))
|
||||
|
@ -42,10 +46,9 @@ def apply_gradient_laplace(kfield, kvec):
|
|||
kx, ky, kz = kvec
|
||||
kk = (kx**2 + ky**2 + kz**2)
|
||||
kernel = jnp.where(kk == 0, 1., 1./kk)
|
||||
return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
|
||||
kfield * kernel * 1j * 1 / 6.0 *
|
||||
(8 * jnp.sin(kz) - jnp.sin(2 * kz)),
|
||||
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))], axis=-1)
|
||||
return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)),
|
||||
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)),
|
||||
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky))], axis=-1)
|
||||
|
||||
|
||||
def cic_compensation(kvec):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue