mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-18 17:10:54 +00:00
update for latest jaxDecomp
This commit is contained in:
parent
afecb13cde
commit
01b952701e
2 changed files with 16 additions and 59 deletions
|
@ -1,40 +1,18 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from jax._src import mesh as mesh_lib
|
from jax._src import mesh as mesh_lib
|
||||||
|
from jax.lib.xla_client import FftType
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
|
from jaxdecomp import fftfreq3d, get_output_specs
|
||||||
|
|
||||||
from jaxpm.distributed import autoshmap
|
from jaxpm.distributed import autoshmap
|
||||||
|
|
||||||
|
|
||||||
class PencilType(Enum):
|
|
||||||
NO_DECOMP = 0
|
|
||||||
SLAB_XY = 1
|
|
||||||
SLAB_YZ = 2
|
|
||||||
PENCILS = 3
|
|
||||||
|
|
||||||
|
def fftk(k_array):
|
||||||
def get_pencil_type():
|
|
||||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
|
||||||
if mesh.empty:
|
|
||||||
pdims = None
|
|
||||||
else:
|
|
||||||
pdims = mesh.devices.shape[::-1]
|
|
||||||
|
|
||||||
if pdims == (1, 1) or pdims == None:
|
|
||||||
return PencilType.NO_DECOMP
|
|
||||||
elif pdims[0] == 1:
|
|
||||||
return PencilType.SLAB_XY
|
|
||||||
elif pdims[1] == 1:
|
|
||||||
return PencilType.SLAB_YZ
|
|
||||||
else:
|
|
||||||
return PencilType.PENCILS
|
|
||||||
|
|
||||||
|
|
||||||
def fftk(shape, dtype=np.float32):
|
|
||||||
"""
|
"""
|
||||||
Generate Fourier transform wave numbers for a given mesh.
|
Generate Fourier transform wave numbers for a given mesh.
|
||||||
|
|
||||||
|
@ -44,31 +22,8 @@ def fftk(shape, dtype=np.float32):
|
||||||
Returns:
|
Returns:
|
||||||
list: List of wave number arrays for each dimension in
|
list: List of wave number arrays for each dimension in
|
||||||
the order [kx, ky, kz].
|
the order [kx, ky, kz].
|
||||||
"""
|
"""
|
||||||
kx, ky, kz = [jnp.fft.fftfreq(s, dtype=dtype) * 2 * np.pi for s in shape]
|
kx, ky, kz = fftfreq3d(k_array)
|
||||||
|
|
||||||
@partial(autoshmap,
|
|
||||||
in_specs=(P('x'), P('y'), P(None)),
|
|
||||||
out_specs=(P('x'), P(None, 'y'), P(None)),
|
|
||||||
in_fourrier_space=True)
|
|
||||||
def get_kvec(ky, kz, kx):
|
|
||||||
return (ky.reshape([-1, 1, 1]),
|
|
||||||
kz.reshape([1, -1, 1]),
|
|
||||||
kx.reshape([1, 1, -1])) # yapf: disable
|
|
||||||
|
|
||||||
pencil_type = get_pencil_type()
|
|
||||||
# YZ returns Y pencil
|
|
||||||
# XY and pencils returns a Z pencil
|
|
||||||
# NO_DECOMP returns a X pencil
|
|
||||||
if pencil_type == PencilType.NO_DECOMP:
|
|
||||||
kx, ky, kz = get_kvec(kx, ky, kz) # Z Y X ==> X pencil
|
|
||||||
elif pencil_type == PencilType.SLAB_YZ:
|
|
||||||
kz, kx, ky = get_kvec(kz, kx, ky) # X Z Y ==> Y pencil
|
|
||||||
elif pencil_type == PencilType.SLAB_XY or pencil_type == PencilType.PENCILS:
|
|
||||||
ky, kz, kx = get_kvec(ky, kz, kx) # Z X Y ==> Z pencil
|
|
||||||
else:
|
|
||||||
raise ValueError("Unknown pencil type")
|
|
||||||
|
|
||||||
# to the order of dimensions in the transposed FFT
|
# to the order of dimensions in the transposed FFT
|
||||||
return kx, ky, kz
|
return kx, ky, kz
|
||||||
|
|
||||||
|
@ -77,10 +32,11 @@ def interpolate_power_spectrum(input, k, pk):
|
||||||
|
|
||||||
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
|
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
|
||||||
).reshape(x.shape)
|
).reshape(x.shape)
|
||||||
return autoshmap(pk_fn,
|
specs = P('x', 'y')
|
||||||
in_specs=P('x', 'y'),
|
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||||
out_specs=P('x', 'y'),
|
out_specs = P(*get_output_specs(FftType.FFT, specs, mesh))
|
||||||
in_fourrier_space=True)(input)
|
|
||||||
|
return autoshmap(pk_fn, in_specs=out_specs, out_specs=out_specs)(input)
|
||||||
|
|
||||||
|
|
||||||
def gradient_kernel(kvec, direction, order=1):
|
def gradient_kernel(kvec, direction, order=1):
|
||||||
|
|
11
jaxpm/pm.py
11
jaxpm/pm.py
|
@ -22,13 +22,13 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
|
||||||
assert (delta is not None
|
assert (delta is not None
|
||||||
), "If mesh_shape is not provided, delta should be provided"
|
), "If mesh_shape is not provided, delta should be provided"
|
||||||
mesh_shape = delta.shape
|
mesh_shape = delta.shape
|
||||||
kvec = fftk(mesh_shape)
|
|
||||||
|
|
||||||
if delta is None:
|
if delta is None:
|
||||||
delta_k = fft3d(cic_paint_dx(positions, halo_size=halo_size))
|
delta_k = fft3d(cic_paint_dx(positions, halo_size=halo_size))
|
||||||
else:
|
else:
|
||||||
delta_k = fft3d(delta)
|
delta_k = fft3d(delta)
|
||||||
|
|
||||||
|
kvec = fftk(delta_k)
|
||||||
# Computes gravitational potential
|
# Computes gravitational potential
|
||||||
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec,
|
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec,
|
||||||
r_split=r_split)
|
r_split=r_split)
|
||||||
|
@ -137,15 +137,16 @@ def linear_field(mesh_shape, box_size, pk, seed):
|
||||||
"""
|
"""
|
||||||
Generate initial conditions.
|
Generate initial conditions.
|
||||||
"""
|
"""
|
||||||
kvec = fftk(mesh_shape)
|
# Initialize a random field with one slice on each gpu
|
||||||
|
field = normal_field(mesh_shape, seed=seed)
|
||||||
|
field = fft3d(field)
|
||||||
|
kvec = fftk(field)
|
||||||
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
|
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
|
||||||
for i, kk in enumerate(kvec))**0.5
|
for i, kk in enumerate(kvec))**0.5
|
||||||
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
|
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
|
||||||
box_size[0] * box_size[1] * box_size[2])
|
box_size[0] * box_size[1] * box_size[2])
|
||||||
|
|
||||||
# Initialize a random field with one slice on each gpu
|
field = field * (pkmesh)**0.5
|
||||||
field = normal_field(mesh_shape, seed=seed)
|
|
||||||
field = fft3d(field) * pkmesh**0.5
|
|
||||||
field = ifft3d(field)
|
field = ifft3d(field)
|
||||||
return field
|
return field
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue