mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-11 21:50:55 +00:00
update for latest jaxDecomp
This commit is contained in:
parent
afecb13cde
commit
01b952701e
2 changed files with 16 additions and 59 deletions
jaxpm
|
@ -1,40 +1,18 @@
|
|||
from enum import Enum
|
||||
from functools import partial
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
import numpy as np
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax.lib.xla_client import FftType
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jaxdecomp import fftfreq3d, get_output_specs
|
||||
|
||||
from jaxpm.distributed import autoshmap
|
||||
|
||||
|
||||
class PencilType(Enum):
|
||||
NO_DECOMP = 0
|
||||
SLAB_XY = 1
|
||||
SLAB_YZ = 2
|
||||
PENCILS = 3
|
||||
|
||||
|
||||
def get_pencil_type():
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
if mesh.empty:
|
||||
pdims = None
|
||||
else:
|
||||
pdims = mesh.devices.shape[::-1]
|
||||
|
||||
if pdims == (1, 1) or pdims == None:
|
||||
return PencilType.NO_DECOMP
|
||||
elif pdims[0] == 1:
|
||||
return PencilType.SLAB_XY
|
||||
elif pdims[1] == 1:
|
||||
return PencilType.SLAB_YZ
|
||||
else:
|
||||
return PencilType.PENCILS
|
||||
|
||||
|
||||
def fftk(shape, dtype=np.float32):
|
||||
def fftk(k_array):
|
||||
"""
|
||||
Generate Fourier transform wave numbers for a given mesh.
|
||||
|
||||
|
@ -44,31 +22,8 @@ def fftk(shape, dtype=np.float32):
|
|||
Returns:
|
||||
list: List of wave number arrays for each dimension in
|
||||
the order [kx, ky, kz].
|
||||
"""
|
||||
kx, ky, kz = [jnp.fft.fftfreq(s, dtype=dtype) * 2 * np.pi for s in shape]
|
||||
|
||||
@partial(autoshmap,
|
||||
in_specs=(P('x'), P('y'), P(None)),
|
||||
out_specs=(P('x'), P(None, 'y'), P(None)),
|
||||
in_fourrier_space=True)
|
||||
def get_kvec(ky, kz, kx):
|
||||
return (ky.reshape([-1, 1, 1]),
|
||||
kz.reshape([1, -1, 1]),
|
||||
kx.reshape([1, 1, -1])) # yapf: disable
|
||||
|
||||
pencil_type = get_pencil_type()
|
||||
# YZ returns Y pencil
|
||||
# XY and pencils returns a Z pencil
|
||||
# NO_DECOMP returns a X pencil
|
||||
if pencil_type == PencilType.NO_DECOMP:
|
||||
kx, ky, kz = get_kvec(kx, ky, kz) # Z Y X ==> X pencil
|
||||
elif pencil_type == PencilType.SLAB_YZ:
|
||||
kz, kx, ky = get_kvec(kz, kx, ky) # X Z Y ==> Y pencil
|
||||
elif pencil_type == PencilType.SLAB_XY or pencil_type == PencilType.PENCILS:
|
||||
ky, kz, kx = get_kvec(ky, kz, kx) # Z X Y ==> Z pencil
|
||||
else:
|
||||
raise ValueError("Unknown pencil type")
|
||||
|
||||
"""
|
||||
kx, ky, kz = fftfreq3d(k_array)
|
||||
# to the order of dimensions in the transposed FFT
|
||||
return kx, ky, kz
|
||||
|
||||
|
@ -77,10 +32,11 @@ def interpolate_power_spectrum(input, k, pk):
|
|||
|
||||
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
|
||||
).reshape(x.shape)
|
||||
return autoshmap(pk_fn,
|
||||
in_specs=P('x', 'y'),
|
||||
out_specs=P('x', 'y'),
|
||||
in_fourrier_space=True)(input)
|
||||
specs = P('x', 'y')
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
out_specs = P(*get_output_specs(FftType.FFT, specs, mesh))
|
||||
|
||||
return autoshmap(pk_fn, in_specs=out_specs, out_specs=out_specs)(input)
|
||||
|
||||
|
||||
def gradient_kernel(kvec, direction, order=1):
|
||||
|
|
11
jaxpm/pm.py
11
jaxpm/pm.py
|
@ -22,13 +22,13 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
|
|||
assert (delta is not None
|
||||
), "If mesh_shape is not provided, delta should be provided"
|
||||
mesh_shape = delta.shape
|
||||
kvec = fftk(mesh_shape)
|
||||
|
||||
if delta is None:
|
||||
delta_k = fft3d(cic_paint_dx(positions, halo_size=halo_size))
|
||||
else:
|
||||
delta_k = fft3d(delta)
|
||||
|
||||
kvec = fftk(delta_k)
|
||||
# Computes gravitational potential
|
||||
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec,
|
||||
r_split=r_split)
|
||||
|
@ -137,15 +137,16 @@ def linear_field(mesh_shape, box_size, pk, seed):
|
|||
"""
|
||||
Generate initial conditions.
|
||||
"""
|
||||
kvec = fftk(mesh_shape)
|
||||
# Initialize a random field with one slice on each gpu
|
||||
field = normal_field(mesh_shape, seed=seed)
|
||||
field = fft3d(field)
|
||||
kvec = fftk(field)
|
||||
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
|
||||
for i, kk in enumerate(kvec))**0.5
|
||||
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
|
||||
box_size[0] * box_size[1] * box_size[2])
|
||||
|
||||
# Initialize a random field with one slice on each gpu
|
||||
field = normal_field(mesh_shape, seed=seed)
|
||||
field = fft3d(field) * pkmesh**0.5
|
||||
field = field * (pkmesh)**0.5
|
||||
field = ifft3d(field)
|
||||
return field
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue