forked from Aquila-Consortium/JaxPM_highres
* adding example of distributed solution * put back old functgion * update formatting * add halo exchange and slice pad * apply formatting * implement distributed optimized cic_paint * Use new cic_paint with halo * Fix seed for distributed normal * Wrap interpolation function to avoid all gather * Return normal order frequencies for single GPU * add example * format * add optimised bench script * times in ms * add lpt2 * update benchmark and add slurm * Visualize only final field * Update scripts/distributed_pm.py Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com> * Adjust pencil type for frequencies * fix painting issue with slabs * Shared operation in fourrier space now take inverted sharding axis for slabs * add assert to make pyright happy * adjust test for hpc-plotter * add PMWD test * bench * format * added github workflow * fix formatting from main * Update for jaxDecomp pure JAX * revert single halo extent change * update for latest jaxDecomp * remove fourrier_space in autoshmap * make normal_field work with single controller * format * make distributed pm work in single controller * merge bench_pm * update to leapfrog * add a strict dependency on jaxdecomp * global mesh no longer needed * kernels.py no longer uses global mesh * quick fix in distributed * pm.py no longer uses global mesh * painting.py no longer uses global mesh * update demo script * quick fix in kernels * quick fix in distributed * update demo * merge hugos LPT2 code * format * Small fix * format * remove duplicate get_ode_fn * update visualizer * update compensate CIC * By default check_rep is false for shard_map * remove experimental distributed code * update PGDCorrection and neural ode to use new fft3d * jaxDecomp pfft3d promotes to complex automatically * remove deprecated stuff * fix painting issue with read_cic * use jnp interp instead of jc interp * delete old slurms * add notebook examples * apply formatting * add distributed zeros * fix code in LPT2 * jit cic_paint 
* update notebooks * apply formating * get local shape and zeros can be used by users * add a user facing function to create uniform particle grid * use jax interp instead of jax_cosmo * use float64 for enmeshing * Allow applying weights with relative cic paint * Weights can be traced * remove script folder * update example notebooks * delete outdated design file * add readme for tutorials * update readme * fix small error * forgot particles in multi host * clarifying why cic_paint_dx is slower * clarifying the halo size dependence on the box size * ability to choose snapshots number with MultiHost script * Adding animation notebook * Put plotting in package * Add finite difference laplace kernel + powerspec functions from Hugo Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com> * Put plotting utils in package * By default use absoulute painting with * update code * update notebooks * add tests * Upgrade setup.py to pyproject * Format * format tests * update test dependencies * add test workflow * fix deprecated FftType in jaxpm.kernels * Add aboucaud comments * JAX version is 0.4.35 until Diffrax new release * add numpy explicitly as dependency for tests * fix install order for tests * add numpy to be installed * enforce no build isolation for fastpm * pip install jaxpm test without build isolation * bump jaxdecomp version * revert test workflow * remove outdated tests --------- Co-authored-by: EiffL <fr.eiffel@gmail.com> Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com> Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr> Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com> Former-commit-id: 8c2e823d4669eac712089bf7f85ffb7912e8232d
259 lines
9.9 KiB
Python
259 lines
9.9 KiB
Python
from functools import partial
|
|
|
|
import jax
|
|
import jax.lax as lax
|
|
import jax.numpy as jnp
|
|
from jax.sharding import NamedSharding
|
|
from jax.sharding import PartitionSpec as P
|
|
|
|
from jaxpm.distributed import (autoshmap, fft3d, get_halo_size, halo_exchange,
|
|
ifft3d, slice_pad, slice_unpad)
|
|
from jaxpm.kernels import cic_compensation, fftk
|
|
from jaxpm.painting_utils import gather, scatter
|
|
|
|
|
|
def _cic_paint_impl(grid_mesh, positions, weight=None):
    """Deposit particles onto a 3D mesh with cloud-in-cell (CIC) weights.

    Args:
        grid_mesh: 3D array ``[nx, ny, nz]`` that the contributions are
            added to.
        positions: particle coordinates in mesh-index units; any array that
            can be reshaped to ``[-1, 3]`` (e.g. ``[nx, ny, nz, 3]``).
        weight: optional scalar, or array holding one value per particle,
            multiplied into every contribution of that particle.

    Returns:
        Array shaped like ``grid_mesh`` with the particle contributions
        scatter-added into it.
    """
    # One row per particle plus a broadcast axis for the 8 cell corners.
    particles = jnp.expand_dims(positions.reshape([-1, 3]), 1)
    corner_offsets = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
                                 [0., 0, 1], [1., 1, 0], [1., 0, 1],
                                 [0., 1, 1], [1., 1, 1]]])
    corners = jnp.floor(particles) + corner_offsets

    # Linear weight along each axis, multiplied into a single corner weight.
    per_axis = 1. - jnp.abs(particles - corners)
    contributions = per_axis[..., 0] * per_axis[..., 1] * per_axis[..., 2]

    if weight is not None:
        if jnp.isscalar(weight):
            scale = jnp.expand_dims(weight, axis=-1)
        else:
            # One weight per particle, broadcast over its 8 corners.
            scale = weight.reshape(*particles.shape[:-1])
        contributions = jnp.multiply(scale, contributions)

    # Wrap corner indices periodically into the mesh.
    cell_indices = jnp.mod(
        corners.reshape([-1, 8, 3]).astype('int32'),
        jnp.array(grid_mesh.shape))

    scatter_dims = lax.ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0, 1, 2),
        scatter_dims_to_operand_dims=(0, 1, 2))
    return lax.scatter_add(grid_mesh, cell_indices,
                           contributions.reshape([-1, 8]), scatter_dims)
|
|
|
|
|
|
@partial(jax.jit, static_argnums=(3, 4))
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
    """Paint particles onto a (possibly sharded) mesh using ``_cic_paint_impl``.

    ``halo_size`` and ``sharding`` are static arguments of the jitted
    function.

    Args:
        grid_mesh: mesh the particles are painted onto.
        positions: particle positions; reshaped to ``(*grid_mesh.shape, 3)``,
            so there must be exactly one particle per mesh cell.
        weight: optional weight forwarded to ``_cic_paint_impl``.
        halo_size: halo width handed to ``get_halo_size``.
        sharding: optional ``NamedSharding``; any other value runs the
            painting with a ``None`` device mesh and an empty partition spec.

    Returns:
        The painted mesh after halo exchange and un-padding.
    """
    positions = positions.reshape((*grid_mesh.shape, 3))

    halo_size, halo_extents = get_halo_size(halo_size, sharding)
    padded_mesh = slice_pad(grid_mesh, halo_size, sharding)

    if isinstance(sharding, NamedSharding):
        device_mesh, partition = sharding.mesh, sharding.spec
    else:
        device_mesh, partition = None, P()

    # The weight is passed with an empty spec (not partitioned).
    paint_local = autoshmap(_cic_paint_impl,
                            gpu_mesh=device_mesh,
                            in_specs=(partition, partition, P()),
                            out_specs=partition)
    painted = paint_local(padded_mesh, positions, weight)

    # NOTE(review): presumably folds halo contributions back into the
    # neighbouring shards before the padding is removed - confirm against
    # jaxpm.distributed.
    painted = halo_exchange(painted,
                            halo_extents=halo_extents,
                            halo_periods=(True, True))
    return slice_unpad(painted, halo_size, sharding)
|
|
|
|
|
|
def _cic_read_impl(grid_mesh, positions):
    """Read mesh values at particle positions with cloud-in-cell weights.

    Args:
        grid_mesh: 3D array ``[nx, ny, nz]`` to interpolate from.
        positions: coordinates in mesh-index units with a trailing axis of
            size 3 (e.g. ``[nx, ny, nz, 3]``).

    Returns:
        Interpolated values with shape ``positions.shape[:-1]``.
    """
    # Remember the leading shape so the flat result can be restored.
    output_shape = positions.shape[:-1]

    # One row per particle plus a broadcast axis for the 8 cell corners.
    particles = jnp.expand_dims(positions.reshape([-1, 3]), 1)
    corner_offsets = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
                                 [0., 0, 1], [1., 1, 0], [1., 0, 1],
                                 [0., 1, 1], [1., 1, 1]]])
    corners = jnp.floor(particles) + corner_offsets

    # Linear weight along each axis, multiplied into a single corner weight.
    per_axis = 1. - jnp.abs(particles - corners)
    corner_weights = per_axis[..., 0] * per_axis[..., 1] * per_axis[..., 2]

    # Wrap corner indices periodically into the mesh.
    cell_indices = jnp.mod(corners.astype('int32'),
                           jnp.array(grid_mesh.shape))

    # Gather the 8 corner values per particle and take their weighted sum.
    corner_values = grid_mesh[cell_indices[..., 0],
                              cell_indices[..., 1],
                              cell_indices[..., 2]]
    weighted = corner_values * corner_weights
    return weighted.sum(axis=-1).reshape(output_shape)
|
|
|
|
|
|
@partial(jax.jit, static_argnums=(2, 3))
def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
    """Read a (possibly sharded) mesh at particle positions via ``_cic_read_impl``.

    ``halo_size`` and ``sharding`` are static arguments of the jitted
    function.

    Args:
        grid_mesh: mesh to read from.
        positions: particle positions; reshaped to ``(*grid_mesh.shape, 3)``,
            so there must be exactly one particle per mesh cell.
        halo_size: halo width handed to ``get_halo_size``.
        sharding: optional ``NamedSharding``; any other value runs the read
            with a ``None`` device mesh and an empty partition spec.

    Returns:
        The values read for each particle, shaped ``positions.shape[:-1]``.
    """
    output_shape = positions.shape[:-1]
    positions = positions.reshape((*grid_mesh.shape, 3))

    halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)

    # Pad and exchange halos before the local read.
    padded_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
    padded_mesh = halo_exchange(padded_mesh,
                                halo_extents=halo_extents,
                                halo_periods=(True, True))

    if isinstance(sharding, NamedSharding):
        device_mesh, partition = sharding.mesh, sharding.spec
    else:
        device_mesh, partition = None, P()

    read_local = autoshmap(_cic_read_impl,
                           gpu_mesh=device_mesh,
                           in_specs=(partition, partition),
                           out_specs=partition)
    values = read_local(padded_mesh, positions)

    return values.reshape(output_shape)
|
|
|
|
|
|
def cic_paint_2d(mesh, positions, weight=None):
    """Paint particles onto a 2D mesh with cloud-in-cell weights.

    Args:
        mesh: 2D array ``[nx, ny]`` that the contributions are added to.
        positions: particle coordinates in mesh-index units, ``[npart, 2]``.
        weight: optional per-particle weights ``[npart]`` or a scalar applied
            to every particle. ``None`` (the default) paints unit weights.

    Returns:
        Array shaped like ``mesh`` with the particle contributions
        scatter-added into it.
    """
    # Broadcast axis for the 4 corners of the cell containing each particle.
    positions = jnp.expand_dims(positions, 1)
    floor = jnp.floor(positions)
    connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])

    neighboor_coords = floor + connection
    # Linear weight along each axis, multiplied into a single corner weight.
    kernel = 1. - jnp.abs(positions - neighboor_coords)
    kernel = kernel[..., 0] * kernel[..., 1]
    if weight is not None:
        # asarray lets a plain Python scalar be indexed like an array; a
        # scalar then broadcasts over all particles and corners.
        kernel = kernel * jnp.asarray(weight)[..., jnp.newaxis]

    # Wrap corner indices periodically into the mesh.
    neighboor_coords = jnp.mod(
        neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
        jnp.array(mesh.shape))

    dnums = jax.lax.ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0, 1),
        scatter_dims_to_operand_dims=(0, 1))
    mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 4]),
                           dnums)
    return mesh
|
|
|
|
|
|
def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
    """Paint particles given as displacements from a uniform grid.

    Each particle starts on a cell of a grid shaped like
    ``displacements.shape[:-1]``; the grid index (shifted by the leading x/y
    halo widths) and the displacement are handed to ``scatter``.

    Args:
        displacements: array ``[nx, ny, nz, 3]`` of per-particle
            displacements.
        halo_size: per-axis ``(before, after)`` padding, in the form accepted
            by ``jnp.pad``; only the x and y leading widths offset the
            particle indices.
        weight: scalar, or array of shape ``displacements.shape[:-1]``.
        chunk_size: chunk size forwarded to ``scatter``.

    Returns:
        The halo-padded mesh returned by ``scatter``.

    Raises:
        ValueError: if a non-scalar ``weight`` does not match the particle
            grid shape.
    """
    halo_x, _ = halo_size[0]
    halo_y, _ = halo_size[1]

    original_shape = displacements.shape
    particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
    if not jnp.isscalar(weight):
        if weight.shape != original_shape[:-1]:
            raise ValueError("Weight shape must match particle shape")
        weight = weight.flatten()

    # Integer grid index of every particle's starting cell.
    a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
                           jnp.arange(particle_mesh.shape[1]),
                           jnp.arange(particle_mesh.shape[2]),
                           indexing='ij')

    # NOTE (original author): padding is forced to be zero in a single gpu
    # run - presumably by the caller's get_halo_size; TODO confirm.
    particle_mesh = jnp.pad(particle_mesh, halo_size)
    # Shift indices so they address the interior of the padded mesh.
    pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
    # Forward the caller's chunk_size; it used to be hard-coded to 2**24
    # here, which silently ignored the parameter.
    return scatter(pmid.reshape([-1, 3]),
                   displacements.reshape([-1, 3]),
                   particle_mesh,
                   chunk_size=chunk_size,
                   val=weight)
|
|
|
|
|
|
@partial(jax.jit, static_argnums=(1, 2, 4))
def cic_paint_dx(displacements,
                 halo_size=0,
                 sharding=None,
                 weight=1.0,
                 chunk_size=2**24):
    """Paint particles described by displacements via ``_cic_paint_dx_impl``.

    ``halo_size``, ``sharding`` and ``chunk_size`` are static arguments of
    the jitted function.

    Args:
        displacements: per-particle displacement array.
        halo_size: halo width handed to ``get_halo_size``.
        sharding: optional ``NamedSharding``; any other value runs the
            painting with a ``None`` device mesh and an empty partition spec.
        weight: weight forwarded to ``_cic_paint_dx_impl``.
        chunk_size: chunk size forwarded to ``_cic_paint_dx_impl``.

    Returns:
        The painted mesh after halo exchange and un-padding.
    """
    halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)

    if isinstance(sharding, NamedSharding):
        device_mesh, partition = sharding.mesh, sharding.spec
    else:
        device_mesh, partition = None, P()

    # Bind the non-array arguments so only the displacements are sharded.
    paint_local = partial(_cic_paint_dx_impl,
                          halo_size=halo_size,
                          weight=weight,
                          chunk_size=chunk_size)
    painted = autoshmap(paint_local,
                        gpu_mesh=device_mesh,
                        in_specs=partition,
                        out_specs=partition)(displacements)

    painted = halo_exchange(painted,
                            halo_extents=halo_extents,
                            halo_periods=(True, True))
    return slice_unpad(painted, halo_size, sharding)
|
|
|
|
|
|
def _cic_read_dx_impl(grid_mesh, disp, halo_size):
    """Read a mesh at particles given as displacements from a uniform grid.

    Args:
        grid_mesh: 3D mesh, treated as already padded by ``halo_size``.
        disp: displacements with a trailing axis of size 3, one per cell of
            the un-padded grid.
        halo_size: per-axis ``(before, after)`` halo widths; the un-padded
            shape is computed as ``dim - 2 * before`` on each axis.

    Returns:
        The result of ``gather`` reshaped to the un-padded grid shape.
    """
    (x_offset, _), (y_offset, _) = halo_size[0], halo_size[1]

    # Shape of the mesh interior once the halo is removed from both sides.
    interior_shape = []
    for padded_dim, halo in zip(grid_mesh.shape, halo_size):
        interior_shape.append(padded_dim - 2 * halo[0])

    # Integer grid index of every particle's starting cell.
    ix, iy, iz = jnp.meshgrid(
        *[jnp.arange(interior_shape[axis]) for axis in range(3)],
        indexing='ij')

    # Shift indices so they address the interior of the padded mesh.
    cell_ids = jnp.stack([ix + x_offset, iy + y_offset, iz], axis=-1)

    values = gather(cell_ids.reshape([-1, 3]), disp.reshape([-1, 3]),
                    grid_mesh)
    return values.reshape(interior_shape)
|
|
|
|
|
|
@partial(jax.jit, static_argnums=(2, 3))
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
    """Read a (possibly sharded) mesh at displaced grid particles.

    ``halo_size`` and ``sharding`` are static arguments of the jitted
    function.

    Args:
        grid_mesh: mesh to read from.
        disp: per-particle displacements forwarded to ``_cic_read_dx_impl``.
        halo_size: halo width handed to ``get_halo_size``.
        sharding: optional ``NamedSharding``; any other value runs the read
            with a ``None`` device mesh and an empty partition spec.

    Returns:
        The values produced by ``_cic_read_dx_impl``.
    """
    halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)

    # Pad and exchange halos before the local read.
    padded_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
    padded_mesh = halo_exchange(padded_mesh,
                                halo_extents=halo_extents,
                                halo_periods=(True, True))

    if isinstance(sharding, NamedSharding):
        device_mesh, partition = sharding.mesh, sharding.spec
    else:
        device_mesh, partition = None, P()

    # A single spec is given for both the mesh and the displacements.
    read_local = autoshmap(partial(_cic_read_dx_impl, halo_size=halo_size),
                           gpu_mesh=device_mesh,
                           in_specs=partition,
                           out_specs=partition)
    return read_local(padded_mesh, disp)
|
|
|
|
|
|
def compensate_cic(field):
    """
    Compensate for CiC painting

    The field is transformed with ``fft3d``, multiplied by
    ``cic_compensation`` evaluated on the ``fftk`` wave-vectors of the
    transformed field, and transformed back with ``ifft3d``.

    Args:
      field: input 3D cic-painted field
    Returns:
      compensated_field
    """
    field_k = fft3d(field)
    wavevectors = fftk(field_k)
    compensated_k = cic_compensation(wavevectors) * field_k
    return ifft3d(compensated_k)
|