mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-14 03:51:11 +00:00
jaxdecomp proto (#21)
* adding example of distributed solution * put back old functgion * update formatting * add halo exchange and slice pad * apply formatting * implement distributed optimized cic_paint * Use new cic_paint with halo * Fix seed for distributed normal * Wrap interpolation function to avoid all gather * Return normal order frequencies for single GPU * add example * format * add optimised bench script * times in ms * add lpt2 * update benchmark and add slurm * Visualize only final field * Update scripts/distributed_pm.py Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com> * Adjust pencil type for frequencies * fix painting issue with slabs * Shared operation in fourrier space now take inverted sharding axis for slabs * add assert to make pyright happy * adjust test for hpc-plotter * add PMWD test * bench * format * added github workflow * fix formatting from main * Update for jaxDecomp pure JAX * revert single halo extent change * update for latest jaxDecomp * remove fourrier_space in autoshmap * make normal_field work with single controller * format * make distributed pm work in single controller * merge bench_pm * update to leapfrog * add a strict dependency on jaxdecomp * global mesh no longer needed * kernels.py no longer uses global mesh * quick fix in distributed * pm.py no longer uses global mesh * painting.py no longer uses global mesh * update demo script * quick fix in kernels * quick fix in distributed * update demo * merge hugos LPT2 code * format * Small fix * format * remove duplicate get_ode_fn * update visualizer * update compensate CIC * By default check_rep is false for shard_map * remove experimental distributed code * update PGDCorrection and neural ode to use new fft3d * jaxDecomp pfft3d promotes to complex automatically * remove deprecated stuff * fix painting issue with read_cic * use jnp interp instead of jc interp * delete old slurms * add notebook examples * apply formatting * add distributed zeros * fix code in LPT2 * jit cic_paint * update notebooks * apply formating * get local shape and zeros can be used by users * add a user facing function to create uniform particle grid * use jax interp instead of jax_cosmo * use float64 for enmeshing * Allow applying weights with relative cic paint * Weights can be traced * remove script folder * update example notebooks * delete outdated design file * add readme for tutorials * update readme * fix small error * forgot particles in multi host * clarifying why cic_paint_dx is slower * clarifying the halo size dependence on the box size * ability to choose snapshots number with MultiHost script * Adding animation notebook * Put plotting in package * Add finite difference laplace kernel + powerspec functions from Hugo Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com> * Put plotting utils in package * By default use absoulute painting with * update code * update notebooks * add tests * Upgrade setup.py to pyproject * Format * format tests * update test dependencies * add test workflow * fix deprecated FftType in jaxpm.kernels * Add aboucaud comments * JAX version is 0.4.35 until Diffrax new release * add numpy explicitly as dependency for tests * fix install order for tests * add numpy to be installed * enforce no build isolation for fastpm * pip install jaxpm test without build isolation * bump jaxdecomp version * revert test workflow * remove outdated tests --------- Co-authored-by: EiffL <fr.eiffel@gmail.com> Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com> Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr> Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com> Former-commit-id: 8c2e823d4669eac712089bf7f85ffb7912e8232d
This commit is contained in:
parent
a0a79277e5
commit
df8602b318
26 changed files with 1871 additions and 434 deletions
190
jaxpm/painting_utils.py
Normal file
190
jaxpm/painting_utils.py
Normal file
|
@ -0,0 +1,190 @@
|
|||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.lax import scan
|
||||
|
||||
|
||||
def _chunk_split(ptcl_num, chunk_size, *arrays):
|
||||
"""Split and reshape particle arrays into chunks and remainders, with the remainders
|
||||
preceding the chunks. 0D ones are duplicated as full arrays in the chunks."""
|
||||
chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
|
||||
remainder_size = ptcl_num % chunk_size
|
||||
chunk_num = ptcl_num // chunk_size
|
||||
|
||||
remainder = None
|
||||
chunks = arrays
|
||||
if remainder_size:
|
||||
remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays]
|
||||
chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays]
|
||||
|
||||
# `scan` triggers errors in scatter and gather without the `full`
|
||||
chunks = [
|
||||
x.reshape(chunk_num, chunk_size, *x.shape[1:])
|
||||
if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks
|
||||
]
|
||||
|
||||
return remainder, chunks
|
||||
|
||||
|
||||
def enmesh(base_indices, displacements, cell_size, base_shape, offset,
|
||||
new_cell_size, new_shape):
|
||||
"""Multilinear enmeshing."""
|
||||
base_indices = jnp.asarray(base_indices)
|
||||
displacements = jnp.asarray(displacements)
|
||||
with jax.experimental.enable_x64():
|
||||
cell_size = jnp.float64(
|
||||
cell_size) if new_cell_size is not None else jnp.array(
|
||||
cell_size, dtype=displacements.dtype)
|
||||
if base_shape is not None:
|
||||
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
|
||||
offset = jnp.float64(offset)
|
||||
if new_cell_size is not None:
|
||||
new_cell_size = jnp.float64(new_cell_size)
|
||||
if new_shape is not None:
|
||||
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
|
||||
|
||||
spatial_dim = base_indices.shape[1]
|
||||
neighbor_offsets = (
|
||||
jnp.arange(2**spatial_dim, dtype=base_indices.dtype)[:, jnp.newaxis] >>
|
||||
jnp.arange(spatial_dim, dtype=base_indices.dtype)) & 1
|
||||
|
||||
if new_cell_size is not None:
|
||||
particle_positions = base_indices * cell_size + displacements - offset
|
||||
particle_positions = particle_positions[:, jnp.
|
||||
newaxis] # insert neighbor axis
|
||||
new_indices = particle_positions + neighbor_offsets * new_cell_size # multilinear
|
||||
|
||||
if base_shape is not None:
|
||||
grid_length = base_shape * cell_size
|
||||
new_indices %= grid_length
|
||||
|
||||
new_indices //= new_cell_size
|
||||
new_displacements = particle_positions - new_indices * new_cell_size
|
||||
|
||||
if base_shape is not None:
|
||||
new_displacements -= jnp.rint(
|
||||
new_displacements / grid_length
|
||||
) * grid_length # also abs(new_displacements) < new_cell_size is expected
|
||||
|
||||
new_indices = new_indices.astype(base_indices.dtype)
|
||||
new_displacements = new_displacements.astype(displacements.dtype)
|
||||
new_cell_size = new_cell_size.astype(displacements.dtype)
|
||||
|
||||
new_displacements /= new_cell_size
|
||||
else:
|
||||
offset_indices, offset_displacements = jnp.divmod(offset, cell_size)
|
||||
base_indices -= offset_indices.astype(base_indices.dtype)
|
||||
displacements -= offset_displacements.astype(displacements.dtype)
|
||||
|
||||
# insert neighbor axis
|
||||
base_indices = base_indices[:, jnp.newaxis]
|
||||
displacements = displacements[:, jnp.newaxis]
|
||||
|
||||
# multilinear
|
||||
displacements /= cell_size
|
||||
new_indices = jnp.floor(displacements).astype(base_indices.dtype)
|
||||
new_indices += neighbor_offsets
|
||||
new_displacements = displacements - new_indices
|
||||
new_indices += base_indices
|
||||
|
||||
if base_shape is not None:
|
||||
new_indices %= base_shape
|
||||
|
||||
weights = 1 - jnp.abs(new_displacements)
|
||||
|
||||
if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None
|
||||
new_indices = jnp.where(new_indices < 0, new_shape, new_indices)
|
||||
|
||||
weights = weights.prod(axis=-1)
|
||||
|
||||
return new_indices, weights
|
||||
|
||||
|
||||
def _scatter_chunk(carry, chunk):
|
||||
mesh, offset, cell_size, mesh_shape = carry
|
||||
pmid, disp, val = chunk
|
||||
spatial_ndim = pmid.shape[1]
|
||||
spatial_shape = mesh.shape
|
||||
|
||||
# multilinear mesh indices and fractions
|
||||
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
|
||||
spatial_shape)
|
||||
# scatter
|
||||
ind = tuple(ind[..., i] for i in range(spatial_ndim))
|
||||
mesh = mesh.at[ind].add(jnp.multiply(jnp.expand_dims(val, axis=-1), frac))
|
||||
carry = mesh, offset, cell_size, mesh_shape
|
||||
return carry, None
|
||||
|
||||
|
||||
def scatter(pmid,
|
||||
disp,
|
||||
mesh,
|
||||
chunk_size=2**24,
|
||||
val=1.,
|
||||
offset=0,
|
||||
cell_size=1.):
|
||||
ptcl_num, spatial_ndim = pmid.shape
|
||||
val = jnp.asarray(val)
|
||||
mesh = jnp.asarray(mesh)
|
||||
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
|
||||
carry = mesh, offset, cell_size, mesh.shape
|
||||
if remainder is not None:
|
||||
carry = _scatter_chunk(carry, remainder)[0]
|
||||
carry = scan(_scatter_chunk, carry, chunks)[0]
|
||||
mesh = carry[0]
|
||||
return mesh
|
||||
|
||||
|
||||
def _chunk_cat(remainder_array, chunked_array):
|
||||
"""Reshape and concatenate one remainder and one chunked particle arrays."""
|
||||
array = chunked_array.reshape(-1, *chunked_array.shape[2:])
|
||||
|
||||
if remainder_array is not None:
|
||||
array = jnp.concatenate((remainder_array, array), axis=0)
|
||||
|
||||
return array
|
||||
|
||||
|
||||
def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
|
||||
ptcl_num, spatial_ndim = pmid.shape
|
||||
|
||||
mesh = jnp.asarray(mesh)
|
||||
|
||||
val = jnp.asarray(val)
|
||||
|
||||
if mesh.shape[spatial_ndim:] != val.shape[1:]:
|
||||
raise ValueError('channel shape mismatch: '
|
||||
f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}')
|
||||
|
||||
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
|
||||
|
||||
carry = mesh, offset, cell_size, mesh.shape
|
||||
val_0 = None
|
||||
if remainder is not None:
|
||||
val_0 = _gather_chunk(carry, remainder)[1]
|
||||
val = scan(_gather_chunk, carry, chunks)[1]
|
||||
|
||||
val = _chunk_cat(val_0, val)
|
||||
|
||||
return val
|
||||
|
||||
|
||||
def _gather_chunk(carry, chunk):
|
||||
mesh, offset, cell_size, mesh_shape = carry
|
||||
pmid, disp, val = chunk
|
||||
|
||||
spatial_ndim = pmid.shape[1]
|
||||
|
||||
spatial_shape = mesh.shape[:spatial_ndim]
|
||||
chan_ndim = mesh.ndim - spatial_ndim
|
||||
chan_axis = tuple(range(-chan_ndim, 0))
|
||||
|
||||
# multilinear mesh indices and fractions
|
||||
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
|
||||
spatial_shape)
|
||||
|
||||
# gather
|
||||
ind = tuple(ind[..., i] for i in range(spatial_ndim))
|
||||
frac = jnp.expand_dims(frac, chan_axis)
|
||||
val += (mesh.at[ind].get(mode='drop', fill_value=0) * frac).sum(axis=1)
|
||||
|
||||
return carry, val
|
Loading…
Add table
Add a link
Reference in a new issue