This commit is contained in:
Wassim KABALAN 2024-08-03 00:23:40 +02:00
parent 831291c1f9
commit ece8c93540
12 changed files with 210 additions and 170 deletions

View file

@ -44,17 +44,20 @@ def autoshmap(f: Callable,
return f
else:
if in_fourrier_space and 1 in mesh.devices.shape:
in_specs , out_specs = switch_specs((in_specs , out_specs))
in_specs, out_specs = switch_specs((in_specs, out_specs))
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
def switch_specs(specs):
if isinstance(specs, P):
new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax for ax in specs)
return P(*new_axes)
elif isinstance(specs, tuple):
return tuple(switch_specs(sub_spec) for sub_spec in specs)
else:
raise TypeError("Element must be either a PartitionSpec or a tuple")
if isinstance(specs, P):
new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax
for ax in specs)
return P(*new_axes)
elif isinstance(specs, tuple):
return tuple(switch_specs(sub_spec) for sub_spec in specs)
else:
raise TypeError("Element must be either a PartitionSpec or a tuple")
def fft3d(x):
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
@ -105,14 +108,15 @@ def slice_unpad_impl(x, pad_width):
# Apply corrections along y
x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])
unpad_slice = [slice(None)] * 3
if halo_x > 0:
unpad_slice[0] = slice(halo_x , -halo_x)
unpad_slice[0] = slice(halo_x, -halo_x)
if halo_y > 0:
unpad_slice[1] = slice(halo_y , -halo_y)
return x[tuple(unpad_slice)]
unpad_slice[1] = slice(halo_y, -halo_y)
return x[tuple(unpad_slice)]
def slice_pad(x, pad_width):
mesh = mesh_lib.thread_resources.env.physical_mesh

View file

@ -1,3 +1,4 @@
from enum import Enum
from functools import partial
import jax.numpy as jnp
@ -7,29 +8,31 @@ from jax._src import mesh as mesh_lib
from jax.sharding import PartitionSpec as P
from jaxpm.distributed import autoshmap
from enum import Enum
class PencilType(Enum):
NO_DECOMP = 0
SLAB_XY = 1
SLAB_YZ = 2
PENCILS = 3
NO_DECOMP = 0
SLAB_XY = 1
SLAB_YZ = 2
PENCILS = 3
def get_pencil_type():
mesh = mesh_lib.thread_resources.env.physical_mesh
if mesh.empty:
pdims = None
else:
pdims = mesh.devices.shape[::-1]
mesh = mesh_lib.thread_resources.env.physical_mesh
if mesh.empty:
pdims = None
else:
pdims = mesh.devices.shape[::-1]
if pdims == (1, 1) or pdims == None:
return PencilType.NO_DECOMP
elif pdims[0] == 1:
return PencilType.SLAB_XY
elif pdims[1] == 1:
return PencilType.SLAB_YZ
else:
return PencilType.PENCILS
if pdims == (1, 1) or pdims == None:
return PencilType.NO_DECOMP
elif pdims[0] == 1:
return PencilType.SLAB_XY
elif pdims[1] == 1:
return PencilType.SLAB_YZ
else:
return PencilType.PENCILS
def fftk(shape, dtype=np.float32):
"""
@ -46,22 +49,23 @@ def fftk(shape, dtype=np.float32):
@partial(autoshmap,
in_specs=(P('x'), P('y'), P(None)),
out_specs=(P('x'), P(None, 'y'), P(None)),in_fourrier_space=True)
out_specs=(P('x'), P(None, 'y'), P(None)),
in_fourrier_space=True)
def get_kvec(ky, kz, kx):
return (ky.reshape([-1, 1, 1]),
kz.reshape([1, -1, 1]),
kx.reshape([1, 1, -1])) # yapf: disable
pencil_type = get_pencil_type()
pencil_type = get_pencil_type()
# YZ returns Y pencil
# XY and pencils returns a Z pencil
# NO_DECOMP returns a X pencil
if pencil_type == PencilType.NO_DECOMP:
kx, ky, kz = get_kvec(kx, ky, kz) # Z Y X ==> X pencil
kx, ky, kz = get_kvec(kx, ky, kz) # Z Y X ==> X pencil
elif pencil_type == PencilType.SLAB_YZ:
kz, kx, ky = get_kvec(kz, kx, ky) # X Z Y ==> Y pencil
kz, kx, ky = get_kvec(kz, kx, ky) # X Z Y ==> Y pencil
elif pencil_type == PencilType.SLAB_XY or pencil_type == PencilType.PENCILS:
ky, kz, kx = get_kvec(ky, kz, kx) # Z X Y ==> Z pencil
ky, kz, kx = get_kvec(ky, kz, kx) # Z X Y ==> Z pencil
else:
raise ValueError("Unknown pencil type")
@ -73,7 +77,10 @@ def interpolate_power_spectrum(input, k, pk):
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
).reshape(x.shape)
return autoshmap(pk_fn, in_specs=P('x', 'y'), out_specs=P('x', 'y'),in_fourrier_space=True)(input)
return autoshmap(pk_fn,
in_specs=P('x', 'y'),
out_specs=P('x', 'y'),
in_fourrier_space=True)(input)
def gradient_kernel(kvec, direction, order=1):

View file

@ -150,7 +150,7 @@ def cic_paint_dx_impl(displacements, halo_size):
jnp.arange(particle_mesh.shape[1]),
jnp.arange(particle_mesh.shape[2]),
indexing='ij')
particle_mesh = jnp.pad(particle_mesh, halo_size)
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
pmid = pmid.reshape([-1, 3])
@ -159,13 +159,13 @@ def cic_paint_dx_impl(displacements, halo_size):
@partial(jax.jit, static_argnums=(1, ))
def cic_paint_dx(displacements, halo_size=0):
halo_size, halo_extents = get_halo_size(halo_size)
mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size),
in_specs=(P('x', 'y')),
out_specs=P('x', 'y'))(displacements)
mesh = halo_exchange(mesh,
halo_extents=halo_extents,
halo_periods=(True, True, True))
@ -173,19 +173,21 @@ def cic_paint_dx(displacements, halo_size=0):
return mesh
def cic_read_dx_impl(mesh , halo_size):
def cic_read_dx_impl(mesh, halo_size):
halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1]
original_shape = [dim - 2 * halo[0] for dim , halo in zip(mesh.shape, halo_size)]
original_shape = [
dim - 2 * halo[0] for dim, halo in zip(mesh.shape, halo_size)
]
a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]),
jnp.arange(original_shape[1]),
jnp.arange(original_shape[2]),
indexing='ij')
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
pmid = pmid.reshape([-1, 3])
return gather(pmid, jnp.zeros_like(pmid), mesh).reshape(original_shape)
@ -199,7 +201,7 @@ def cic_read_dx(mesh, halo_size=0):
mesh = halo_exchange(mesh,
halo_extents=halo_extents,
halo_periods=(True, True, True))
displacements = autoshmap(partial(cic_read_dx_impl , halo_size=halo_size),
displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
in_specs=(P('x', 'y')),
out_specs=P('x', 'y'))(mesh)

View file

@ -19,10 +19,11 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
Computes gravitational forces on particles using a PM scheme
"""
if mesh_shape is None:
assert(delta is not None) , "If mesh_shape is not provided, delta should be provided"
assert (delta is not None
), "If mesh_shape is not provided, delta should be provided"
mesh_shape = delta.shape
kvec = fftk(mesh_shape)
if delta is None:
delta_k = fft3d(cic_paint_dx(positions, halo_size=halo_size))
else:
@ -33,8 +34,8 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
r_split=r_split)
# Computes gravitational forces
forces = jnp.stack([
cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k), halo_size=halo_size)
for i in range(3)
cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k),
halo_size=halo_size) for i in range(3)
],
axis=-1)