Fix sharding error (#37)
Some checks failed
Code Formatting / formatting (push) Failing after 4m30s
Tests / run_tests (3.10) (push) Failing after 1m41s
Tests / run_tests (3.11) (push) Failing after 1m42s
Tests / run_tests (3.12) (push) Failing after 1m15s

* Use cosmo as arg for the ODE function

* Update examples

* format

* notebook update

* fix tests

* add correct annotations for weights in painting and warning for cic_paint in distributed pm

* update test_against_fpm

* update distributed tests and add jacfwd jacrev and vmap tests

* format

* add Caveats to notebook readme

* final touches

* update Growth.py to allow using FastPM solver

* fix 2D painting when input has shape (X, Y, 2)

* update cic read halo size and notebooks examples

* Allow env variable control of caching in growth

* Format

* update test jax version

* update notebooks/03-MultiGPU_PM_Halo.ipynb

* update numpy install in wf

* update tolerance :)

* reorganize install in test workflow

* update tests

* add mpi4py

* update tests.yml

* update tests

* update wf

* format

* make normal_field signature consistent with jax.random.normal

* update the default normal_field dtype to match JAX

* format

* debug test workflow

* format

* debug test workflow

* updating tests

* fix accuracy

* fixed tolerance

* adding caching

* Update conftest.py

* Update tolerance and precision settings in distributed PM tests

* reverting changes to growth.py

---------

Co-authored-by: Francois Lanusse <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
This commit is contained in:
Wassim KABALAN 2025-06-28 23:07:31 +02:00 committed by GitHub
parent cb2a7ab17f
commit 6693e5c725
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 675 additions and 298 deletions

View file

@ -166,11 +166,11 @@ def uniform_particles(mesh_shape, sharding=None):
axis=-1)
def normal_field(mesh_shape, seed, sharding=None):
def normal_field(seed, shape, sharding=None, dtype=float):
"""Generate a Gaussian random field with the given power spectrum."""
gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is not None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(mesh_shape, sharding)
local_mesh_shape = get_local_shape(shape, sharding)
size = jax.device_count()
# rank = jax.process_index()
@ -190,9 +190,9 @@ def normal_field(mesh_shape, seed, sharding=None):
return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)
return shard_map(
partial(normal, shape=local_mesh_shape, dtype='float32'),
partial(normal, shape=local_mesh_shape, dtype=dtype),
mesh=gpu_mesh,
in_specs=P(None),
out_specs=spec)(keys) # yapf: disable
else:
return jax.random.normal(shape=mesh_shape, key=seed)
return jax.random.normal(shape=shape, key=seed, dtype=dtype)

View file

@ -12,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk
from jaxpm.painting_utils import gather, scatter
def _cic_paint_impl(grid_mesh, positions, weight=None):
def _cic_paint_impl(grid_mesh, positions, weight=1.):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
displacement field: [nx, ny, nz, 3]
@ -27,12 +27,10 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
if weight is not None:
if jnp.isscalar(weight):
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
else:
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
kernel)
if jnp.isscalar(weight):
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
else:
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), kernel)
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
@ -48,7 +46,13 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
@partial(jax.jit, static_argnums=(3, 4))
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
def cic_paint(grid_mesh, positions, weight=1., halo_size=0, sharding=None):
if sharding is not None:
print("""
WARNING : absolute painting is not recommended in multi-device mode.
Please use relative painting instead.
""")
positions = positions.reshape((*grid_mesh.shape, 3))
@ -57,9 +61,11 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
weight_spec = P() if jnp.isscalar(weight) else spec
grid_mesh = autoshmap(_cic_paint_impl,
gpu_mesh=gpu_mesh,
in_specs=(spec, spec, P()),
in_specs=(spec, spec, weight_spec),
out_specs=spec)(grid_mesh, positions, weight)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,
@ -128,6 +134,7 @@ def cic_paint_2d(mesh, positions, weight):
positions: [npart, 2]
weight: [npart]
"""
positions = positions.reshape([-1, 2])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
@ -136,7 +143,7 @@ def cic_paint_2d(mesh, positions, weight):
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1]
if weight is not None:
kernel = kernel * weight[..., jnp.newaxis]
kernel = kernel * weight.reshape(*positions.shape[:-1])
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
@ -151,13 +158,16 @@ def cic_paint_2d(mesh, positions, weight):
return mesh
def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
def _cic_paint_dx_impl(displacements,
weight=1.,
halo_size=0,
chunk_size=2**24):
halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1]
original_shape = displacements.shape
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
particle_mesh = jnp.zeros(original_shape[:-1], dtype=displacements.dtype)
if not jnp.isscalar(weight):
if weight.shape != original_shape[:-1]:
raise ValueError("Weight shape must match particle shape")
@ -175,7 +185,7 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
return scatter(pmid.reshape([-1, 3]),
displacements.reshape([-1, 3]),
particle_mesh,
chunk_size=2**24,
chunk_size=chunk_size,
val=weight)
@ -190,13 +200,13 @@ def cic_paint_dx(displacements,
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
weight_spec = P() if jnp.isscalar(weight) else spec
grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
halo_size=halo_size,
weight=weight,
chunk_size=chunk_size),
gpu_mesh=gpu_mesh,
in_specs=spec,
out_specs=spec)(displacements)
in_specs=(spec, weight_spec),
out_specs=spec)(displacements, weight)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,
@ -230,6 +240,12 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
# Halo size is halved for the read operation
# We only need to read the density field
# while in the painting operation we need to exchange and reduce the halo
# We chose to do that since it is much easier to write a custom jvp rule for exchange
# while it is a bit harder if there is a reduction involved
halo_size = jax.tree.map(lambda x: x // 2, halo_size)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,

View file

@ -30,17 +30,14 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
"""Multilinear enmeshing."""
base_indices = jnp.asarray(base_indices)
displacements = jnp.asarray(displacements)
with jax.experimental.enable_x64():
cell_size = jnp.float64(
cell_size) if new_cell_size is not None else jnp.array(
cell_size, dtype=displacements.dtype)
if base_shape is not None:
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
offset = jnp.float64(offset)
if new_cell_size is not None:
new_cell_size = jnp.float64(new_cell_size)
if new_shape is not None:
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
cell_size = jnp.array(cell_size, dtype=displacements.dtype)
if base_shape is not None:
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
offset = offset.astype(base_indices.dtype)
if new_cell_size is not None:
new_cell_size = jnp.array(new_cell_size, dtype=displacements.dtype)
if new_shape is not None:
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
spatial_dim = base_indices.shape[1]
neighbor_offsets = (

View file

@ -131,7 +131,7 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
Generate initial conditions.
"""
# Initialize a random field with one slice on each gpu
field = normal_field(mesh_shape, seed=seed, sharding=sharding)
field = normal_field(seed=seed, shape=mesh_shape, sharding=sharding)
field = fft3d(field)
kvec = fftk(field)
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
@ -172,8 +172,7 @@ def make_ode_fn(mesh_shape,
return nbody_ode
def make_diffrax_ode(cosmo,
mesh_shape,
def make_diffrax_ode(mesh_shape,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
@ -183,6 +182,7 @@ def make_diffrax_ode(cosmo,
state is a tuple (position, velocities)
"""
pos, vel = state
cosmo = args
forces = pm_forces(pos,
mesh_shape=mesh_shape,