mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-30 09:01:11 +00:00
Fix sharding error (#37)
* Use cosmo as arg for the ODE function * Update examples * format * notebook update * fix tests * add correct annotations for weights in painting and warning for cic_paint in distributed pm * update test_against_fpm * update distributed tests and add jacfwd jacrev and vmap tests * format * add Caveats to notebook readme * final touches * update Growth.py to allow using FastPM solver * fix 2D painting when input is (X , Y , 2) shape * update cic read halo size and notebooks examples * Allow env variable control of caching in growth * Format * update test jax version * update notebooks/03-MultiGPU_PM_Halo.ipynb * update numpy install in wf * update tolerance :) * reorganize install in test workflow * update tests * add mpi4py * update tests.yml * update tests * update wf * format * make normal_field signature consistent with jax.random.normal * update by default normal_field dtype to match JAX * format * debug test workflow * format * debug test workflow * updating tests * fix accuracy * fixed tolerance * adding caching * Update conftest.py * Update tolerance and precision settings in distributed PM tests * revererting back changes to growth.py --------- Co-authored-by: Francois Lanusse <fr.eiffel@gmail.com> Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
This commit is contained in:
parent
cb2a7ab17f
commit
6693e5c725
17 changed files with 675 additions and 298 deletions
|
@ -166,11 +166,11 @@ def uniform_particles(mesh_shape, sharding=None):
|
|||
axis=-1)
|
||||
|
||||
|
||||
def normal_field(mesh_shape, seed, sharding=None):
|
||||
def normal_field(seed, shape, sharding=None, dtype=float):
|
||||
"""Generate a Gaussian random field with the given power spectrum."""
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
if gpu_mesh is not None and not (gpu_mesh.empty):
|
||||
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
||||
local_mesh_shape = get_local_shape(shape, sharding)
|
||||
|
||||
size = jax.device_count()
|
||||
# rank = jax.process_index()
|
||||
|
@ -190,9 +190,9 @@ def normal_field(mesh_shape, seed, sharding=None):
|
|||
return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)
|
||||
|
||||
return shard_map(
|
||||
partial(normal, shape=local_mesh_shape, dtype='float32'),
|
||||
partial(normal, shape=local_mesh_shape, dtype=dtype),
|
||||
mesh=gpu_mesh,
|
||||
in_specs=P(None),
|
||||
out_specs=spec)(keys) # yapf: disable
|
||||
else:
|
||||
return jax.random.normal(shape=mesh_shape, key=seed)
|
||||
return jax.random.normal(shape=shape, key=seed, dtype=dtype)
|
||||
|
|
|
@ -12,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk
|
|||
from jaxpm.painting_utils import gather, scatter
|
||||
|
||||
|
||||
def _cic_paint_impl(grid_mesh, positions, weight=None):
|
||||
def _cic_paint_impl(grid_mesh, positions, weight=1.):
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
displacement field: [nx, ny, nz, 3]
|
||||
|
@ -27,12 +27,10 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
|
|||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
if weight is not None:
|
||||
if jnp.isscalar(weight):
|
||||
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
|
||||
else:
|
||||
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
|
||||
kernel)
|
||||
if jnp.isscalar(weight):
|
||||
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
|
||||
else:
|
||||
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), kernel)
|
||||
|
||||
neighboor_coords = jnp.mod(
|
||||
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
||||
|
@ -48,7 +46,13 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
|
|||
|
||||
|
||||
@partial(jax.jit, static_argnums=(3, 4))
|
||||
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
|
||||
def cic_paint(grid_mesh, positions, weight=1., halo_size=0, sharding=None):
|
||||
|
||||
if sharding is not None:
|
||||
print("""
|
||||
WARNING : absolute painting is not recommended in multi-device mode.
|
||||
Please use relative painting instead.
|
||||
""")
|
||||
|
||||
positions = positions.reshape((*grid_mesh.shape, 3))
|
||||
|
||||
|
@ -57,9 +61,11 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
|
|||
|
||||
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
||||
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
||||
weight_spec = P() if jnp.isscalar(weight) else spec
|
||||
|
||||
grid_mesh = autoshmap(_cic_paint_impl,
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(spec, spec, P()),
|
||||
in_specs=(spec, spec, weight_spec),
|
||||
out_specs=spec)(grid_mesh, positions, weight)
|
||||
grid_mesh = halo_exchange(grid_mesh,
|
||||
halo_extents=halo_extents,
|
||||
|
@ -128,6 +134,7 @@ def cic_paint_2d(mesh, positions, weight):
|
|||
positions: [npart, 2]
|
||||
weight: [npart]
|
||||
"""
|
||||
positions = positions.reshape([-1, 2])
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
|
||||
|
@ -136,7 +143,7 @@ def cic_paint_2d(mesh, positions, weight):
|
|||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1]
|
||||
if weight is not None:
|
||||
kernel = kernel * weight[..., jnp.newaxis]
|
||||
kernel = kernel * weight.reshape(*positions.shape[:-1])
|
||||
|
||||
neighboor_coords = jnp.mod(
|
||||
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
|
||||
|
@ -151,13 +158,16 @@ def cic_paint_2d(mesh, positions, weight):
|
|||
return mesh
|
||||
|
||||
|
||||
def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
|
||||
def _cic_paint_dx_impl(displacements,
|
||||
weight=1.,
|
||||
halo_size=0,
|
||||
chunk_size=2**24):
|
||||
|
||||
halo_x, _ = halo_size[0]
|
||||
halo_y, _ = halo_size[1]
|
||||
|
||||
original_shape = displacements.shape
|
||||
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
|
||||
particle_mesh = jnp.zeros(original_shape[:-1], dtype=displacements.dtype)
|
||||
if not jnp.isscalar(weight):
|
||||
if weight.shape != original_shape[:-1]:
|
||||
raise ValueError("Weight shape must match particle shape")
|
||||
|
@ -175,7 +185,7 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
|
|||
return scatter(pmid.reshape([-1, 3]),
|
||||
displacements.reshape([-1, 3]),
|
||||
particle_mesh,
|
||||
chunk_size=2**24,
|
||||
chunk_size=chunk_size,
|
||||
val=weight)
|
||||
|
||||
|
||||
|
@ -190,13 +200,13 @@ def cic_paint_dx(displacements,
|
|||
|
||||
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
||||
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
||||
weight_spec = P() if jnp.isscalar(weight) else spec
|
||||
grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
|
||||
halo_size=halo_size,
|
||||
weight=weight,
|
||||
chunk_size=chunk_size),
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=spec,
|
||||
out_specs=spec)(displacements)
|
||||
in_specs=(spec, weight_spec),
|
||||
out_specs=spec)(displacements, weight)
|
||||
|
||||
grid_mesh = halo_exchange(grid_mesh,
|
||||
halo_extents=halo_extents,
|
||||
|
@ -230,6 +240,12 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
|
|||
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
|
||||
# Halo size is halved for the read operation
|
||||
# We only need to read the density field
|
||||
# while in the painting operation we need to exchange and reduce the halo
|
||||
# We chose to do that since it is much easier to write a custom jvp rule for exchange
|
||||
# while it is a bit harder if there is a reduction involved
|
||||
halo_size = jax.tree.map(lambda x: x // 2, halo_size)
|
||||
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
|
||||
grid_mesh = halo_exchange(grid_mesh,
|
||||
halo_extents=halo_extents,
|
||||
|
|
|
@ -30,17 +30,14 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
|
|||
"""Multilinear enmeshing."""
|
||||
base_indices = jnp.asarray(base_indices)
|
||||
displacements = jnp.asarray(displacements)
|
||||
with jax.experimental.enable_x64():
|
||||
cell_size = jnp.float64(
|
||||
cell_size) if new_cell_size is not None else jnp.array(
|
||||
cell_size, dtype=displacements.dtype)
|
||||
if base_shape is not None:
|
||||
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
|
||||
offset = jnp.float64(offset)
|
||||
if new_cell_size is not None:
|
||||
new_cell_size = jnp.float64(new_cell_size)
|
||||
if new_shape is not None:
|
||||
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
|
||||
cell_size = jnp.array(cell_size, dtype=displacements.dtype)
|
||||
if base_shape is not None:
|
||||
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
|
||||
offset = offset.astype(base_indices.dtype)
|
||||
if new_cell_size is not None:
|
||||
new_cell_size = jnp.array(new_cell_size, dtype=displacements.dtype)
|
||||
if new_shape is not None:
|
||||
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
|
||||
|
||||
spatial_dim = base_indices.shape[1]
|
||||
neighbor_offsets = (
|
||||
|
|
|
@ -131,7 +131,7 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
|
|||
Generate initial conditions.
|
||||
"""
|
||||
# Initialize a random field with one slice on each gpu
|
||||
field = normal_field(mesh_shape, seed=seed, sharding=sharding)
|
||||
field = normal_field(seed=seed, shape=mesh_shape, sharding=sharding)
|
||||
field = fft3d(field)
|
||||
kvec = fftk(field)
|
||||
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
|
||||
|
@ -172,8 +172,7 @@ def make_ode_fn(mesh_shape,
|
|||
return nbody_ode
|
||||
|
||||
|
||||
def make_diffrax_ode(cosmo,
|
||||
mesh_shape,
|
||||
def make_diffrax_ode(mesh_shape,
|
||||
paint_absolute_pos=True,
|
||||
halo_size=0,
|
||||
sharding=None):
|
||||
|
@ -183,6 +182,7 @@ def make_diffrax_ode(cosmo,
|
|||
state is a tuple (position, velocities)
|
||||
"""
|
||||
pos, vel = state
|
||||
cosmo = args
|
||||
|
||||
forces = pm_forces(pos,
|
||||
mesh_shape=mesh_shape,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue