from conftest import initialize_distributed

# initialize_distributed() must run before JAX is imported (it configures the
# distributed/device environment), so the imports below are deliberately late
# and E402 is suppressed on each of them.
initialize_distributed()

import jax  # noqa: E402
import jax.numpy as jnp  # noqa: E402
import pytest  # noqa: E402
from diffrax import SaveAt  # noqa: E402
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve  # noqa: E402
from helpers import MSE  # noqa: E402
from jax import lax  # noqa: E402
from jax.experimental.multihost_utils import process_allgather  # noqa: E402
from jax.sharding import NamedSharding  # noqa: E402
from jax.sharding import PartitionSpec as P  # noqa: E402

from jaxpm.distributed import uniform_particles  # noqa: E402
from jaxpm.painting import cic_paint, cic_paint_dx  # noqa: E402
from jaxpm.pm import lpt, make_diffrax_ode  # noqa: E402

# Maximum MSE tolerated between the single-device and multi-device fields
# (deliberately loose).
_TOLERANCE = 3.0
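
# The test below runs the same LPT + ODE evolution twice: once on a single
# device, then sharded across a device mesh. The two painted density fields
# must agree to within _TOLERANCE.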


@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("absolute_painting", [True, False])
def test_distributed_pm(simulation_config, initial_conditions, cosmo, order,
                        absolute_painting):

    mesh_shape, box_shape = simulation_config

    # SINGLE DEVICE RUN
    cosmo._workspace = {}  # clear cached cosmology computations
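    # Two painting schemes are exercised: absolute positions with cic_paint,
    # and displacement-only (relative) painting with cic_paint_dx.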
    if absolute_painting:
        particles = uniform_particles(mesh_shape)
        # Initial displacement
        dx, p, _ = lpt(cosmo,
                       initial_conditions,
                       particles,
                       a=0.1,
                       order=order)
        ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
        y0 = jnp.stack([particles + dx, p])
    else:
        dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order)
        ode_fn = ODETerm(
            make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
        y0 = jnp.stack([dx, p])

    solver = Dopri5()
    controller = PIDController(rtol=1e-8,
                               atol=1e-8,
                               pcoeff=0.4,
                               icoeff=1,
                               dcoeff=0)

    # Only the final state (t1) is stored.
    saveat = SaveAt(t1=True)

    solutions = diffeqsolve(ode_fn,
                            solver,
                            t0=0.1,
                            t1=1.0,
                            dt0=None,
                            y0=y0,
                            stepsize_controller=controller,
                            saveat=saveat)

    if absolute_painting:
        single_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
                                              solutions.ys[-1, 0])
    else:
        single_device_final_field = cic_paint_dx(solutions.ys[-1, 0])

    print("Done with single device run")

    # MULTI DEVICE RUN
    # A 1x8 device mesh: this test expects exactly 8 devices.
    mesh = jax.make_mesh((1, 8), ('x', 'y'))
    sharding = NamedSharding(mesh, P('x', 'y'))
    # Halo of half the mesh extent per dimension, generous enough that no
    # particle drifts past the exchanged region.
    halo_size = mesh_shape[0] // 2

    initial_conditions = lax.with_sharding_constraint(initial_conditions,
                                                      sharding)

    print(f"sharded initial conditions {initial_conditions.sharding}")

    cosmo._workspace = {}  # clear the cache again so both runs start equal
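    # Same LPT / ODE setup as the single-device run, except every stage now
    # receives halo_size and sharding so the computation stays distributed.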
    if absolute_painting:
        particles = uniform_particles(mesh_shape, sharding=sharding)
        # Initial displacement
        dx, p, _ = lpt(cosmo,
                       initial_conditions,
                       particles,
                       a=0.1,
                       order=order,
                       halo_size=halo_size,
                       sharding=sharding)

        ode_fn = ODETerm(
            make_diffrax_ode(cosmo,
                             mesh_shape,
                             halo_size=halo_size,
                             sharding=sharding))

        y0 = jnp.stack([particles + dx, p])
    else:
        dx, p, _ = lpt(cosmo,
                       initial_conditions,
                       a=0.1,
                       order=order,
                       halo_size=halo_size,
                       sharding=sharding)
        ode_fn = ODETerm(
            make_diffrax_ode(cosmo,
                             mesh_shape,
                             paint_absolute_pos=False,
                             halo_size=halo_size,
                             sharding=sharding))
        y0 = jnp.stack([dx, p])

    solver = Dopri5()
    controller = PIDController(rtol=1e-8,
                               atol=1e-8,
                               pcoeff=0.4,
                               icoeff=1,
                               dcoeff=0)

    saveat = SaveAt(t1=True)

    solutions = diffeqsolve(ode_fn,
                            solver,
                            t0=0.1,
                            t1=1.0,
                            dt0=None,
                            y0=y0,
                            stepsize_controller=controller,
                            saveat=saveat)

    if absolute_painting:
        multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
                                             solutions.ys[-1, 0],
                                             halo_size=halo_size,
                                             sharding=sharding)
    else:
        multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
                                                halo_size=halo_size,
                                                sharding=sharding)

    # Gather the sharded field onto every process as one full array so it
    # can be compared elementwise with the single-device result.
    multi_device_final_field = process_allgather(multi_device_final_field,
                                                 tiled=True)

    mse = MSE(single_device_final_field, multi_device_final_field)
    print(f"MSE is {mse}")

    assert mse < _TOLERANCE
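
# A minimal sketch of how this file might be invoked; the path is an
# assumption, while the `distributed` marker matches the decorator above:
#   pytest -m distributed tests/test_distributed_pm.py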