mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-15 04:21:12 +00:00
format
This commit is contained in:
parent
831291c1f9
commit
ece8c93540
12 changed files with 210 additions and 170 deletions
|
@ -44,17 +44,20 @@ def autoshmap(f: Callable,
|
|||
return f
|
||||
else:
|
||||
if in_fourrier_space and 1 in mesh.devices.shape:
|
||||
in_specs , out_specs = switch_specs((in_specs , out_specs))
|
||||
in_specs, out_specs = switch_specs((in_specs, out_specs))
|
||||
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
|
||||
|
||||
|
||||
def switch_specs(specs):
|
||||
if isinstance(specs, P):
|
||||
new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax for ax in specs)
|
||||
return P(*new_axes)
|
||||
elif isinstance(specs, tuple):
|
||||
return tuple(switch_specs(sub_spec) for sub_spec in specs)
|
||||
else:
|
||||
raise TypeError("Element must be either a PartitionSpec or a tuple")
|
||||
if isinstance(specs, P):
|
||||
new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax
|
||||
for ax in specs)
|
||||
return P(*new_axes)
|
||||
elif isinstance(specs, tuple):
|
||||
return tuple(switch_specs(sub_spec) for sub_spec in specs)
|
||||
else:
|
||||
raise TypeError("Element must be either a PartitionSpec or a tuple")
|
||||
|
||||
|
||||
def fft3d(x):
|
||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||
|
@ -105,14 +108,15 @@ def slice_unpad_impl(x, pad_width):
|
|||
# Apply corrections along y
|
||||
x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
|
||||
x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])
|
||||
|
||||
|
||||
unpad_slice = [slice(None)] * 3
|
||||
if halo_x > 0:
|
||||
unpad_slice[0] = slice(halo_x , -halo_x)
|
||||
unpad_slice[0] = slice(halo_x, -halo_x)
|
||||
if halo_y > 0:
|
||||
unpad_slice[1] = slice(halo_y , -halo_y)
|
||||
|
||||
return x[tuple(unpad_slice)]
|
||||
unpad_slice[1] = slice(halo_y, -halo_y)
|
||||
|
||||
return x[tuple(unpad_slice)]
|
||||
|
||||
|
||||
def slice_pad(x, pad_width):
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from enum import Enum
|
||||
from functools import partial
|
||||
|
||||
import jax.numpy as jnp
|
||||
|
@ -7,29 +8,31 @@ from jax._src import mesh as mesh_lib
|
|||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jaxpm.distributed import autoshmap
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class PencilType(Enum):
|
||||
NO_DECOMP = 0
|
||||
SLAB_XY = 1
|
||||
SLAB_YZ = 2
|
||||
PENCILS = 3
|
||||
NO_DECOMP = 0
|
||||
SLAB_XY = 1
|
||||
SLAB_YZ = 2
|
||||
PENCILS = 3
|
||||
|
||||
|
||||
def get_pencil_type():
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
if mesh.empty:
|
||||
pdims = None
|
||||
else:
|
||||
pdims = mesh.devices.shape[::-1]
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
if mesh.empty:
|
||||
pdims = None
|
||||
else:
|
||||
pdims = mesh.devices.shape[::-1]
|
||||
|
||||
if pdims == (1, 1) or pdims == None:
|
||||
return PencilType.NO_DECOMP
|
||||
elif pdims[0] == 1:
|
||||
return PencilType.SLAB_XY
|
||||
elif pdims[1] == 1:
|
||||
return PencilType.SLAB_YZ
|
||||
else:
|
||||
return PencilType.PENCILS
|
||||
|
||||
if pdims == (1, 1) or pdims == None:
|
||||
return PencilType.NO_DECOMP
|
||||
elif pdims[0] == 1:
|
||||
return PencilType.SLAB_XY
|
||||
elif pdims[1] == 1:
|
||||
return PencilType.SLAB_YZ
|
||||
else:
|
||||
return PencilType.PENCILS
|
||||
|
||||
def fftk(shape, dtype=np.float32):
|
||||
"""
|
||||
|
@ -46,22 +49,23 @@ def fftk(shape, dtype=np.float32):
|
|||
|
||||
@partial(autoshmap,
|
||||
in_specs=(P('x'), P('y'), P(None)),
|
||||
out_specs=(P('x'), P(None, 'y'), P(None)),in_fourrier_space=True)
|
||||
out_specs=(P('x'), P(None, 'y'), P(None)),
|
||||
in_fourrier_space=True)
|
||||
def get_kvec(ky, kz, kx):
|
||||
return (ky.reshape([-1, 1, 1]),
|
||||
kz.reshape([1, -1, 1]),
|
||||
kx.reshape([1, 1, -1])) # yapf: disable
|
||||
|
||||
pencil_type = get_pencil_type()
|
||||
pencil_type = get_pencil_type()
|
||||
# YZ returns Y pencil
|
||||
# XY and pencils returns a Z pencil
|
||||
# NO_DECOMP returns a X pencil
|
||||
if pencil_type == PencilType.NO_DECOMP:
|
||||
kx, ky, kz = get_kvec(kx, ky, kz) # Z Y X ==> X pencil
|
||||
kx, ky, kz = get_kvec(kx, ky, kz) # Z Y X ==> X pencil
|
||||
elif pencil_type == PencilType.SLAB_YZ:
|
||||
kz, kx, ky = get_kvec(kz, kx, ky) # X Z Y ==> Y pencil
|
||||
kz, kx, ky = get_kvec(kz, kx, ky) # X Z Y ==> Y pencil
|
||||
elif pencil_type == PencilType.SLAB_XY or pencil_type == PencilType.PENCILS:
|
||||
ky, kz, kx = get_kvec(ky, kz, kx) # Z X Y ==> Z pencil
|
||||
ky, kz, kx = get_kvec(ky, kz, kx) # Z X Y ==> Z pencil
|
||||
else:
|
||||
raise ValueError("Unknown pencil type")
|
||||
|
||||
|
@ -73,7 +77,10 @@ def interpolate_power_spectrum(input, k, pk):
|
|||
|
||||
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
|
||||
).reshape(x.shape)
|
||||
return autoshmap(pk_fn, in_specs=P('x', 'y'), out_specs=P('x', 'y'),in_fourrier_space=True)(input)
|
||||
return autoshmap(pk_fn,
|
||||
in_specs=P('x', 'y'),
|
||||
out_specs=P('x', 'y'),
|
||||
in_fourrier_space=True)(input)
|
||||
|
||||
|
||||
def gradient_kernel(kvec, direction, order=1):
|
||||
|
|
|
@ -150,7 +150,7 @@ def cic_paint_dx_impl(displacements, halo_size):
|
|||
jnp.arange(particle_mesh.shape[1]),
|
||||
jnp.arange(particle_mesh.shape[2]),
|
||||
indexing='ij')
|
||||
|
||||
|
||||
particle_mesh = jnp.pad(particle_mesh, halo_size)
|
||||
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
|
||||
pmid = pmid.reshape([-1, 3])
|
||||
|
@ -159,13 +159,13 @@ def cic_paint_dx_impl(displacements, halo_size):
|
|||
|
||||
@partial(jax.jit, static_argnums=(1, ))
|
||||
def cic_paint_dx(displacements, halo_size=0):
|
||||
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size)
|
||||
|
||||
|
||||
mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size),
|
||||
in_specs=(P('x', 'y')),
|
||||
out_specs=P('x', 'y'))(displacements)
|
||||
|
||||
|
||||
mesh = halo_exchange(mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True, True))
|
||||
|
@ -173,19 +173,21 @@ def cic_paint_dx(displacements, halo_size=0):
|
|||
return mesh
|
||||
|
||||
|
||||
def cic_read_dx_impl(mesh , halo_size):
|
||||
def cic_read_dx_impl(mesh, halo_size):
|
||||
|
||||
halo_x, _ = halo_size[0]
|
||||
halo_y, _ = halo_size[1]
|
||||
|
||||
original_shape = [dim - 2 * halo[0] for dim , halo in zip(mesh.shape, halo_size)]
|
||||
original_shape = [
|
||||
dim - 2 * halo[0] for dim, halo in zip(mesh.shape, halo_size)
|
||||
]
|
||||
a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]),
|
||||
jnp.arange(original_shape[1]),
|
||||
jnp.arange(original_shape[2]),
|
||||
indexing='ij')
|
||||
|
||||
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
|
||||
|
||||
|
||||
pmid = pmid.reshape([-1, 3])
|
||||
|
||||
return gather(pmid, jnp.zeros_like(pmid), mesh).reshape(original_shape)
|
||||
|
@ -199,7 +201,7 @@ def cic_read_dx(mesh, halo_size=0):
|
|||
mesh = halo_exchange(mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True, True))
|
||||
displacements = autoshmap(partial(cic_read_dx_impl , halo_size=halo_size),
|
||||
displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
|
||||
in_specs=(P('x', 'y')),
|
||||
out_specs=P('x', 'y'))(mesh)
|
||||
|
||||
|
|
|
@ -19,10 +19,11 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
|
|||
Computes gravitational forces on particles using a PM scheme
|
||||
"""
|
||||
if mesh_shape is None:
|
||||
assert(delta is not None) , "If mesh_shape is not provided, delta should be provided"
|
||||
assert (delta is not None
|
||||
), "If mesh_shape is not provided, delta should be provided"
|
||||
mesh_shape = delta.shape
|
||||
kvec = fftk(mesh_shape)
|
||||
|
||||
|
||||
if delta is None:
|
||||
delta_k = fft3d(cic_paint_dx(positions, halo_size=halo_size))
|
||||
else:
|
||||
|
@ -33,8 +34,8 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
|
|||
r_split=r_split)
|
||||
# Computes gravitational forces
|
||||
forces = jnp.stack([
|
||||
cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k), halo_size=halo_size)
|
||||
for i in range(3)
|
||||
cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k),
|
||||
halo_size=halo_size) for i in range(3)
|
||||
],
|
||||
axis=-1)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue