update for latest jaxDecomp

This commit is contained in:
Wassim KABALAN 2024-10-21 13:55:48 -04:00
parent afecb13cde
commit 01b952701e
2 changed files with 16 additions and 59 deletions

View file

@ -1,40 +1,18 @@
from enum import Enum
from functools import partial
import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from jax._src import mesh as mesh_lib
from jax.lib.xla_client import FftType
from jax.sharding import PartitionSpec as P
from jaxdecomp import fftfreq3d, get_output_specs
from jaxpm.distributed import autoshmap
class PencilType(Enum):
NO_DECOMP = 0
SLAB_XY = 1
SLAB_YZ = 2
PENCILS = 3
def get_pencil_type():
mesh = mesh_lib.thread_resources.env.physical_mesh
if mesh.empty:
pdims = None
else:
pdims = mesh.devices.shape[::-1]
if pdims == (1, 1) or pdims == None:
return PencilType.NO_DECOMP
elif pdims[0] == 1:
return PencilType.SLAB_XY
elif pdims[1] == 1:
return PencilType.SLAB_YZ
else:
return PencilType.PENCILS
def fftk(shape, dtype=np.float32):
def fftk(k_array):
"""
Generate Fourier transform wave numbers for a given mesh.
@ -44,31 +22,8 @@ def fftk(shape, dtype=np.float32):
Returns:
list: List of wave number arrays for each dimension in
the order [kx, ky, kz].
"""
kx, ky, kz = [jnp.fft.fftfreq(s, dtype=dtype) * 2 * np.pi for s in shape]
@partial(autoshmap,
in_specs=(P('x'), P('y'), P(None)),
out_specs=(P('x'), P(None, 'y'), P(None)),
in_fourrier_space=True)
def get_kvec(ky, kz, kx):
return (ky.reshape([-1, 1, 1]),
kz.reshape([1, -1, 1]),
kx.reshape([1, 1, -1])) # yapf: disable
pencil_type = get_pencil_type()
# YZ returns Y pencil
# XY and pencils returns a Z pencil
# NO_DECOMP returns a X pencil
if pencil_type == PencilType.NO_DECOMP:
kx, ky, kz = get_kvec(kx, ky, kz) # Z Y X ==> X pencil
elif pencil_type == PencilType.SLAB_YZ:
kz, kx, ky = get_kvec(kz, kx, ky) # X Z Y ==> Y pencil
elif pencil_type == PencilType.SLAB_XY or pencil_type == PencilType.PENCILS:
ky, kz, kx = get_kvec(ky, kz, kx) # Z X Y ==> Z pencil
else:
raise ValueError("Unknown pencil type")
"""
kx, ky, kz = fftfreq3d(k_array)
# to the order of dimensions in the transposed FFT
return kx, ky, kz
@ -77,10 +32,11 @@ def interpolate_power_spectrum(input, k, pk):
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
).reshape(x.shape)
return autoshmap(pk_fn,
in_specs=P('x', 'y'),
out_specs=P('x', 'y'),
in_fourrier_space=True)(input)
specs = P('x', 'y')
mesh = mesh_lib.thread_resources.env.physical_mesh
out_specs = P(*get_output_specs(FftType.FFT, specs, mesh))
return autoshmap(pk_fn, in_specs=out_specs, out_specs=out_specs)(input)
def gradient_kernel(kvec, direction, order=1):

View file

@ -22,13 +22,13 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
assert (delta is not None
), "If mesh_shape is not provided, delta should be provided"
mesh_shape = delta.shape
kvec = fftk(mesh_shape)
if delta is None:
delta_k = fft3d(cic_paint_dx(positions, halo_size=halo_size))
else:
delta_k = fft3d(delta)
kvec = fftk(delta_k)
# Computes gravitational potential
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec,
r_split=r_split)
@ -137,15 +137,16 @@ def linear_field(mesh_shape, box_size, pk, seed):
"""
Generate initial conditions.
"""
kvec = fftk(mesh_shape)
# Initialize a random field with one slice on each gpu
field = normal_field(mesh_shape, seed=seed)
field = fft3d(field)
kvec = fftk(field)
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
for i, kk in enumerate(kvec))**0.5
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
box_size[0] * box_size[1] * box_size[2])
# Initialize a random field with one slice on each gpu
field = normal_field(mesh_shape, seed=seed)
field = fft3d(field) * pkmesh**0.5
field = field * (pkmesh)**0.5
field = ifft3d(field)
return field