JaxPM/notebooks/05-MultiHost_PM.py
Wassim KABALAN df8602b318 jaxdecomp proto (#21)
* adding example of distributed solution

* put back old function

* update formatting

* add halo exchange and slice pad

* apply formatting

* implement distributed optimized cic_paint

* Use new cic_paint with halo

* Fix seed for distributed normal

* Wrap interpolation function to avoid all gather

* Return normal order frequencies for single GPU

* add example

* format

* add optimised bench script

* times in ms

* add lpt2

* update benchmark and add slurm

* Visualize only final field

* Update scripts/distributed_pm.py

Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>

* Adjust pencil type for frequencies

* fix painting issue with slabs

* Shared operations in Fourier space now take an inverted sharding axis for
slabs

* add assert to make pyright happy

* adjust test for hpc-plotter

* add PMWD test

* bench

* format

* added github workflow

* fix formatting from main

* Update for jaxDecomp pure JAX

* revert single halo extent change

* update for latest jaxDecomp

* remove fourrier_space in autoshmap

* make normal_field work with single controller

* format

* make distributed pm work in single controller

* merge bench_pm

* update to leapfrog

* add a strict dependency on jaxdecomp

* global mesh no longer needed

* kernels.py no longer uses global mesh

* quick fix in distributed

* pm.py no longer uses global mesh

* painting.py no longer uses global mesh

* update demo script

* quick fix in kernels

* quick fix in distributed

* update demo

* merge hugos LPT2 code

* format

* Small fix

* format

* remove duplicate get_ode_fn

* update visualizer

* update compensate CIC

* By default check_rep is false for shard_map

* remove experimental distributed code

* update PGDCorrection and neural ode to use new fft3d

* jaxDecomp pfft3d promotes to complex automatically

* remove deprecated stuff

* fix painting issue with read_cic

* use jnp interp instead of jc interp

* delete old slurms

* add notebook examples

* apply formatting

* add distributed zeros

* fix code in LPT2

* jit cic_paint

* update notebooks

* apply formatting

* get local shape and zeros can be used by users

* add a user facing function to create uniform particle grid

* use jax interp instead of jax_cosmo

* use float64 for enmeshing

* Allow applying weights with relative cic paint

* Weights can be traced

* remove script folder

* update example notebooks

* delete outdated design file

* add readme for tutorials

* update readme

* fix small error

* forgot particles in multi host

* clarifying why cic_paint_dx is slower

* clarifying the halo size dependence on the box size

* ability to choose snapshots number with MultiHost script

* Adding animation notebook

* Put plotting in package

* Add finite difference laplace kernel + powerspec functions from Hugo

Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>

* Put plotting utils in package

* By default use absolute painting with

* update code

* update notebooks

* add tests

* Upgrade setup.py to pyproject

* Format

* format tests

* update test dependencies

* add test workflow

* fix deprecated FftType in jaxpm.kernels

* Add aboucaud comments

* JAX version is 0.4.35 until Diffrax new release

* add numpy explicitly as dependency for tests

* fix install order for tests

* add numpy to be installed

* enforce no build isolation for fastpm

* pip install jaxpm test without build isolation

* bump jaxdecomp version

* revert test workflow

* remove outdated tests

---------

Co-authored-by: EiffL <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr>
Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>
Former-commit-id: 8c2e823d4669eac712089bf7f85ffb7912e8232d
2024-12-20 05:44:02 -05:00

179 lines
5.9 KiB
Python

"""Multi-host (distributed) particle-mesh simulation demo for JaxPM.

Initializes JAX in multi-controller mode, runs an LPT + N-body simulation
sharded over a 2D device mesh, and saves the painted density fields.
"""
import os

os.environ["EQX_ON_ERROR"] = "nan"  # avoid an allgather caused by diffrax

import jax

# Set up the multi-controller runtime before any other JAX call; each
# process then sees its own rank and the total process count.
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()
if rank == 0:
    print(f"SIZE is {jax.device_count()}")

import argparse
from functools import partial

import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
                     PIDController, SaveAt, diffeqsolve)
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_diffrax_ode

# Gather a distributed array onto every process as one contiguous array.
all_gather = partial(process_allgather, tiled=True)
def parse_arguments():
    """Parse command-line options for the distributed PM simulation.

    Returns:
        argparse.Namespace with ``pdims``, ``mesh_shape``, ``box_size``,
        ``snapshots``, ``halo_size`` and ``solver`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Run a cosmological simulation with JAX.")
    parser.add_argument(
        "-p",
        "--pdims",
        type=int,
        nargs=2,
        # BUG FIX: the default must be two ints (a 1 x n_devices slab
        # decomposition); jax.devices() returns a list of Device objects,
        # so use jax.device_count() instead.
        default=[1, jax.device_count()],
        help="Processor grid dimensions as two integers (e.g., 2 4).")
    parser.add_argument(
        "-m",
        "--mesh_shape",
        type=int,
        nargs=3,
        default=[512, 512, 512],
        help="Shape of the simulation mesh as three values (e.g., 512 512 512)."
    )
    parser.add_argument(
        "-b",
        "--box_size",
        type=float,
        nargs=3,
        default=[500.0, 500.0, 500.0],
        help=
        "Box size of the simulation as three values (e.g., 500.0 500.0 1000.0)."
    )
    parser.add_argument(
        "-st",
        "--snapshots",
        type=int,
        default=2,
        help="Number of snapshots to save during the simulation.")
    parser.add_argument("-H",
                        "--halo_size",
                        type=int,
                        default=64,
                        help="Halo size for the simulation.")
    parser.add_argument("-s",
                        "--solver",
                        type=str,
                        choices=['leapfrog', 'dopri8'],
                        default='leapfrog',
                        help="ODE solver choice: 'leapfrog' or 'dopri8'.")
    return parser.parse_args()
def create_mesh_and_sharding(pdims):
    """Build a 2D device mesh over ``pdims`` processors and its sharding.

    Returns:
        A ``(mesh, sharding)`` tuple where the NamedSharding partitions the
        first two array axes over the mesh axes 'x' and 'y'.
    """
    device_grid = create_device_mesh(pdims)
    device_mesh = Mesh(device_grid, axis_names=('x', 'y'))
    return device_mesh, NamedSharding(device_mesh, P('x', 'y'))
# BUG FIX: `sharding` (argument index 7) must be static — a NamedSharding is
# hashable but is not a JAX array/pytree type, so jit cannot trace it as a
# dynamic argument.
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6, 7))
def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
                   solver_choice, nb_snapshots, sharding):
    """Run an LPT + N-body simulation and paint the resulting fields.

    Args:
        omega_c: CDM density parameter (traced).
        sigma8: power-spectrum normalization (traced).
        mesh_shape: simulation mesh shape; static tuple of 3 ints.
        box_size: box size; static tuple of 3 floats.
        halo_size: halo (ghost-cell) width for distributed painting; static.
        solver_choice: 'leapfrog' or 'dopri8'; static.
        nb_snapshots: number of snapshots to save; static.
        sharding: NamedSharding describing the device layout; static.

    Returns:
        (initial_conditions, lpt_field, ode_fields, solver_stats).
    """
    # Tabulate the linear matter power spectrum on a log-spaced k grid.
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(
        jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
    pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)

    # Gaussian initial conditions consistent with that power spectrum.
    initial_conditions = linear_field(mesh_shape,
                                      box_size,
                                      pk_fn,
                                      seed=jax.random.PRNGKey(0),
                                      sharding=sharding)
    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)

    # First-order LPT displacement and momentum at scale factor a=0.1.
    dx, p, _ = lpt(cosmo,
                   initial_conditions,
                   a=0.1,
                   halo_size=halo_size,
                   sharding=sharding)
    ode_fn = ODETerm(
        make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))

    # Choose solver: fixed-step leapfrog or adaptive Dopri5 with a PID
    # step-size controller.
    solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
    stepsize_controller = ConstantStepSize(
    ) if solver_choice == "leapfrog" else PIDController(rtol=1e-5, atol=1e-5)
    res = diffeqsolve(ode_fn,
                      solver,
                      t0=0.1,
                      t1=1.,
                      dt0=0.01,
                      y0=jnp.stack([dx, p], axis=0),
                      args=cosmo,
                      saveat=SaveAt(ts=jnp.linspace(0.2, 1., nb_snapshots)),
                      stepsize_controller=stepsize_controller)

    # Paint each saved displacement snapshot onto the density mesh.
    ode_fields = [
        cic_paint_dx(sol[0], halo_size=halo_size, sharding=sharding)
        for sol in res.ys
    ]
    lpt_field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding)
    return initial_conditions, lpt_field, ode_fields, res.stats
def main():
    """Entry point: parse options, run the simulation, save all fields.

    Every process participates in the gathers; only rank 0 writes the
    resulting arrays under ./fields.
    """
    args = parse_arguments()
    mesh_shape = args.mesh_shape
    box_size = args.box_size
    halo_size = args.halo_size
    solver_choice = args.solver
    nb_snapshots = args.snapshots
    # BUG FIX: create_mesh_and_sharding returns (mesh, sharding); the
    # original bound the whole tuple to `sharding`, breaking every
    # sharding= keyword downstream. Keep only the NamedSharding.
    _, sharding = create_mesh_and_sharding(args.pdims)
    initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(
        0.25, 0.8, tuple(mesh_shape), tuple(box_size), halo_size,
        solver_choice, nb_snapshots, sharding)
    if rank == 0:
        os.makedirs("fields", exist_ok=True)
        print(f"[{rank}] Simulation done")
        print(f"Solver stats: {solver_stats}")

    # Save initial conditions (gather the distributed array first).
    initial_conditions_g = all_gather(initial_conditions)
    if rank == 0:
        print(f"[{rank}] Saving initial_conditions")
        np.save("fields/initial_conditions.npy", initial_conditions_g)
        print(f"[{rank}] initial_conditions saved")
    # Drop the gathered copy before the next gather to limit peak memory.
    del initial_conditions_g, initial_conditions

    # Save LPT displacements
    lpt_displacements_g = all_gather(lpt_displacements)
    if rank == 0:
        print(f"[{rank}] Saving lpt_displacements")
        np.save("fields/lpt_displacements.npy", lpt_displacements_g)
        print(f"[{rank}] lpt_displacements saved")
    del lpt_displacements_g, lpt_displacements

    # Save each ODE solution separately
    for i, sol in enumerate(ode_solutions):
        sol_g = all_gather(sol)
        if rank == 0:
            print(f"[{rank}] Saving ode_solution_{i}")
            np.save(f"fields/ode_solution_{i}.npy", sol_g)
            print(f"[{rank}] ode_solution_{i} saved")
        del sol_g


if __name__ == "__main__":
    main()