mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-30 09:01:11 +00:00
Fix sharding error (#37)
* Use cosmo as arg for the ODE function * Update examples * format * notebook update * fix tests * add correct annotations for weights in painting and warning for cic_paint in distributed pm * update test_against_fpm * update distributed tests and add jacfwd jacrev and vmap tests * format * add Caveats to notebook readme * final touches * update Growth.py to allow using FastPM solver * fix 2D painting when input is (X , Y , 2) shape * update cic read halo size and notebooks examples * Allow env variable control of caching in growth * Format * update test jax version * update notebooks/03-MultiGPU_PM_Halo.ipynb * update numpy install in wf * update tolerance :) * reorganize install in test workflow * update tests * add mpi4py * update tests.yml * update tests * update wf * format * make normal_field signature consistent with jax.random.normal * update by default normal_field dtype to match JAX * format * debug test workflow * format * debug test workflow * updating tests * fix accuracy * fixed tolerance * adding caching * Update conftest.py * Update tolerance and precision settings in distributed PM tests * revererting back changes to growth.py --------- Co-authored-by: Francois Lanusse <fr.eiffel@gmail.com> Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
This commit is contained in:
parent
cb2a7ab17f
commit
6693e5c725
17 changed files with 675 additions and 298 deletions
|
@ -12,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk
|
|||
from jaxpm.painting_utils import gather, scatter
|
||||
|
||||
|
||||
def _cic_paint_impl(grid_mesh, positions, weight=None):
|
||||
def _cic_paint_impl(grid_mesh, positions, weight=1.):
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
displacement field: [nx, ny, nz, 3]
|
||||
|
@ -27,12 +27,10 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
|
|||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
if weight is not None:
|
||||
if jnp.isscalar(weight):
|
||||
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
|
||||
else:
|
||||
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
|
||||
kernel)
|
||||
if jnp.isscalar(weight):
|
||||
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
|
||||
else:
|
||||
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), kernel)
|
||||
|
||||
neighboor_coords = jnp.mod(
|
||||
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
||||
|
@ -48,7 +46,13 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
|
|||
|
||||
|
||||
@partial(jax.jit, static_argnums=(3, 4))
|
||||
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
|
||||
def cic_paint(grid_mesh, positions, weight=1., halo_size=0, sharding=None):
|
||||
|
||||
if sharding is not None:
|
||||
print("""
|
||||
WARNING : absolute painting is not recommended in multi-device mode.
|
||||
Please use relative painting instead.
|
||||
""")
|
||||
|
||||
positions = positions.reshape((*grid_mesh.shape, 3))
|
||||
|
||||
|
@ -57,9 +61,11 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
|
|||
|
||||
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
||||
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
||||
weight_spec = P() if jnp.isscalar(weight) else spec
|
||||
|
||||
grid_mesh = autoshmap(_cic_paint_impl,
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(spec, spec, P()),
|
||||
in_specs=(spec, spec, weight_spec),
|
||||
out_specs=spec)(grid_mesh, positions, weight)
|
||||
grid_mesh = halo_exchange(grid_mesh,
|
||||
halo_extents=halo_extents,
|
||||
|
@ -128,6 +134,7 @@ def cic_paint_2d(mesh, positions, weight):
|
|||
positions: [npart, 2]
|
||||
weight: [npart]
|
||||
"""
|
||||
positions = positions.reshape([-1, 2])
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
|
||||
|
@ -136,7 +143,7 @@ def cic_paint_2d(mesh, positions, weight):
|
|||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1]
|
||||
if weight is not None:
|
||||
kernel = kernel * weight[..., jnp.newaxis]
|
||||
kernel = kernel * weight.reshape(*positions.shape[:-1])
|
||||
|
||||
neighboor_coords = jnp.mod(
|
||||
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
|
||||
|
@ -151,13 +158,16 @@ def cic_paint_2d(mesh, positions, weight):
|
|||
return mesh
|
||||
|
||||
|
||||
def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
|
||||
def _cic_paint_dx_impl(displacements,
|
||||
weight=1.,
|
||||
halo_size=0,
|
||||
chunk_size=2**24):
|
||||
|
||||
halo_x, _ = halo_size[0]
|
||||
halo_y, _ = halo_size[1]
|
||||
|
||||
original_shape = displacements.shape
|
||||
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
|
||||
particle_mesh = jnp.zeros(original_shape[:-1], dtype=displacements.dtype)
|
||||
if not jnp.isscalar(weight):
|
||||
if weight.shape != original_shape[:-1]:
|
||||
raise ValueError("Weight shape must match particle shape")
|
||||
|
@ -175,7 +185,7 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
|
|||
return scatter(pmid.reshape([-1, 3]),
|
||||
displacements.reshape([-1, 3]),
|
||||
particle_mesh,
|
||||
chunk_size=2**24,
|
||||
chunk_size=chunk_size,
|
||||
val=weight)
|
||||
|
||||
|
||||
|
@ -190,13 +200,13 @@ def cic_paint_dx(displacements,
|
|||
|
||||
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
||||
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
||||
weight_spec = P() if jnp.isscalar(weight) else spec
|
||||
grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
|
||||
halo_size=halo_size,
|
||||
weight=weight,
|
||||
chunk_size=chunk_size),
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=spec,
|
||||
out_specs=spec)(displacements)
|
||||
in_specs=(spec, weight_spec),
|
||||
out_specs=spec)(displacements, weight)
|
||||
|
||||
grid_mesh = halo_exchange(grid_mesh,
|
||||
halo_extents=halo_extents,
|
||||
|
@ -230,6 +240,12 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
|
|||
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
|
||||
# Halo size is halved for the read operation
|
||||
# We only need to read the density field
|
||||
# while in the painting operation we need to exchange and reduce the halo
|
||||
# We chose to do that since it is much easier to write a custom jvp rule for exchange
|
||||
# while it is a bit harder if there is a reduction involved
|
||||
halo_size = jax.tree.map(lambda x: x // 2, halo_size)
|
||||
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
|
||||
grid_mesh = halo_exchange(grid_mesh,
|
||||
halo_extents=halo_extents,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue