jaxdecomp proto (#21)

* adding example of distributed solution

* put back old function

* update formatting

* add halo exchange and slice pad

* apply formatting

* implement distributed optimized cic_paint

* Use new cic_paint with halo

* Fix seed for distributed normal

* Wrap interpolation function to avoid all gather

* Return normal order frequencies for single GPU

* add example

* format

* add optimised bench script

* times in ms

* add lpt2

* update benchmark and add slurm

* Visualize only final field

* Update scripts/distributed_pm.py

Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>

* Adjust pencil type for frequencies

* fix painting issue with slabs

* Sharded operations in Fourier space now take the inverted sharding axis for slabs

* add assert to make pyright happy

* adjust test for hpc-plotter

* add PMWD test

* bench

* format

* added github workflow

* fix formatting from main

* Update for jaxDecomp pure JAX

* revert single halo extent change

* update for latest jaxDecomp

* remove fourrier_space in autoshmap

* make normal_field work with single controller

* format

* make distributed pm work in single controller

* merge bench_pm

* update to leapfrog

* add a strict dependency on jaxdecomp

* global mesh no longer needed

* kernels.py no longer uses global mesh

* quick fix in distributed

* pm.py no longer uses global mesh

* painting.py no longer uses global mesh

* update demo script

* quick fix in kernels

* quick fix in distributed

* update demo

* merge Hugo's LPT2 code

* format

* Small fix

* format

* remove duplicate get_ode_fn

* update visualizer

* update compensate CIC

* By default check_rep is false for shard_map

* remove experimental distributed code

* update PGDCorrection and neural ode to use new fft3d

* jaxDecomp pfft3d promotes to complex automatically

* remove deprecated stuff

* fix painting issue with read_cic

* use jnp interp instead of jc interp

* delete old slurms

* add notebook examples

* apply formatting

* add distributed zeros

* fix code in LPT2

* jit cic_paint

* update notebooks

* apply formatting

* get local shape and zeros can be used by users

* add a user facing function to create uniform particle grid

* use jax interp instead of jax_cosmo

* use float64 for enmeshing

* Allow applying weights with relative cic paint

* Weights can be traced

* remove script folder

* update example notebooks

* delete outdated design file

* add readme for tutorials

* update readme

* fix small error

* forgot particles in multi host

* clarifying why cic_paint_dx is slower

* clarifying the halo size dependence on the box size

* ability to choose the number of snapshots with the MultiHost script

* Adding animation notebook

* Put plotting in package

* Add finite difference laplace kernel + powerspec functions from Hugo

Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>

* Put plotting utils in package

* By default use absolute painting

* update code

* update notebooks

* add tests

* Upgrade setup.py to pyproject

* Format

* format tests

* update test dependencies

* add test workflow

* fix deprecated FftType in jaxpm.kernels

* Add aboucaud comments

* Pin JAX version to 0.4.35 until the next Diffrax release

* add numpy explicitly as dependency for tests

* fix install order for tests

* add numpy to be installed

* enforce no build isolation for fastpm

* pip install jaxpm test without build isolation

* bump jaxdecomp version

* revert test workflow

* remove outdated tests

---------

Co-authored-by: EiffL <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr>
Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>
Wassim KABALAN 2024-12-20 11:44:02 +01:00 committed by GitHub
parent cf88d680a3
commit bf44dfdea9
32 changed files with 4321 additions and 639 deletions


@@ -0,0 +1,179 @@
import os

os.environ["EQX_ON_ERROR"] = "nan"  # avoid an allgather caused by diffrax

import jax

jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()
if rank == 0:
    print(f"SIZE is {jax.device_count()}")

import argparse
from functools import partial

import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
                     PIDController, SaveAt, diffeqsolve)
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_diffrax_ode

all_gather = partial(process_allgather, tiled=True)


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Run a cosmological simulation with JAX.")
    parser.add_argument(
        "-p",
        "--pdims",
        type=int,
        nargs=2,
        default=[1, jax.device_count()],
        help="Processor grid dimensions as two integers (e.g., 2 4).")
    parser.add_argument(
        "-m",
        "--mesh_shape",
        type=int,
        nargs=3,
        default=[512, 512, 512],
        help="Shape of the simulation mesh as three values (e.g., 512 512 512)."
    )
    parser.add_argument(
        "-b",
        "--box_size",
        type=float,
        nargs=3,
        default=[500.0, 500.0, 500.0],
        help=
        "Box size of the simulation as three values (e.g., 500.0 500.0 1000.0)."
    )
    parser.add_argument(
        "-st",
        "--snapshots",
        type=int,
        default=2,
        help="Number of snapshots to save during the simulation.")
    parser.add_argument("-H",
                        "--halo_size",
                        type=int,
                        default=64,
                        help="Halo size for the simulation.")
    parser.add_argument("-s",
                        "--solver",
                        type=str,
                        choices=['leapfrog', 'dopri5'],
                        default='leapfrog',
                        help="ODE solver choice: 'leapfrog' or 'dopri5'.")
    return parser.parse_args()


def create_mesh_and_sharding(pdims):
    devices = create_device_mesh(pdims)
    mesh = Mesh(devices, axis_names=('x', 'y'))
    sharding = NamedSharding(mesh, P('x', 'y'))
    return mesh, sharding


# Shapes, solver choice, snapshot count and the sharding object are not array
# arguments, so they are marked static for jit.
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6, 7))
def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
                   solver_choice, nb_snapshots, sharding):
    # Linear matter power spectrum, interpolated onto the field's Fourier grid
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(
        jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
    pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)

    initial_conditions = linear_field(mesh_shape,
                                      box_size,
                                      pk_fn,
                                      seed=jax.random.PRNGKey(0),
                                      sharding=sharding)

    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)

    # LPT displacement and momentum at the initial scale factor a=0.1
    dx, p, _ = lpt(cosmo,
                   initial_conditions,
                   a=0.1,
                   halo_size=halo_size,
                   sharding=sharding)

    ode_fn = ODETerm(
        make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))

    # Choose solver
    solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
    stepsize_controller = ConstantStepSize(
    ) if solver_choice == "leapfrog" else PIDController(rtol=1e-5, atol=1e-5)
    res = diffeqsolve(ode_fn,
                      solver,
                      t0=0.1,
                      t1=1.,
                      dt0=0.01,
                      y0=jnp.stack([dx, p], axis=0),
                      args=cosmo,
                      saveat=SaveAt(ts=jnp.linspace(0.2, 1., nb_snapshots)),
                      stepsize_controller=stepsize_controller)

    # Paint the displacement snapshots back onto the density mesh
    ode_fields = [
        cic_paint_dx(sol[0], halo_size=halo_size, sharding=sharding)
        for sol in res.ys
    ]
    lpt_field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding)
    return initial_conditions, lpt_field, ode_fields, res.stats


def main():
    args = parse_arguments()
    mesh_shape = args.mesh_shape
    box_size = args.box_size
    halo_size = args.halo_size
    solver_choice = args.solver
    nb_snapshots = args.snapshots

    _, sharding = create_mesh_and_sharding(args.pdims)

    initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(
        0.25, 0.8, tuple(mesh_shape), tuple(box_size), halo_size,
        solver_choice, nb_snapshots, sharding)

    if rank == 0:
        os.makedirs("fields", exist_ok=True)
        print(f"[{rank}] Simulation done")
        print(f"Solver stats: {solver_stats}")

    # Save initial conditions
    initial_conditions_g = all_gather(initial_conditions)
    if rank == 0:
        print(f"[{rank}] Saving initial_conditions")
        np.save("fields/initial_conditions.npy", initial_conditions_g)
        print(f"[{rank}] initial_conditions saved")
    del initial_conditions_g, initial_conditions

    # Save LPT displacements
    lpt_displacements_g = all_gather(lpt_displacements)
    if rank == 0:
        print(f"[{rank}] Saving lpt_displacements")
        np.save("fields/lpt_displacements.npy", lpt_displacements_g)
        print(f"[{rank}] lpt_displacements saved")
    del lpt_displacements_g, lpt_displacements

    # Save each ODE solution separately
    for i, sol in enumerate(ode_solutions):
        sol_g = all_gather(sol)
        if rank == 0:
            print(f"[{rank}] Saving ode_solution_{i}")
            np.save(f"fields/ode_solution_{i}.npy", sol_g)
            print(f"[{rank}] ode_solution_{i} saved")
        del sol_g


if __name__ == "__main__":
    main()
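
On rank 0 the script writes each gathered field to fields/*.npy. A minimal sketch of how those outputs could be inspected afterwards on a single machine (matplotlib is an assumed extra here, not something the script itself requires):

import numpy as np
import matplotlib.pyplot as plt

# With the default of two snapshots, ode_solution_1.npy is the final field
final_field = np.load("fields/ode_solution_1.npy")
plt.imshow(final_field.sum(axis=0), cmap="magma")  # project the density along one axis
plt.colorbar(label="projected density")
plt.savefig("final_field.png")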


notebooks/README.md (new file)

@@ -0,0 +1,39 @@
# Particle Mesh Simulation with JAXPM on Multi-GPU and Multi-Host Systems
This collection of notebooks demonstrates how to perform Particle Mesh (PM) simulations using **JAXPM**, leveraging JAX for efficient computation on multi-GPU and multi-host systems. Each notebook progressively covers different setups, from single-GPU simulations to advanced, distributed, multi-host simulations across multiple nodes.
## Table of Contents
1. **[Single-GPU Particle Mesh Simulation](01-Introduction.ipynb)**
   - Introduction to basic PM simulations on a single GPU.
   - Uses JAXPM to run simulations with absolute particle positions and Cloud-in-Cell (CIC) painting.
2. **[Advanced Particle Mesh Simulation on a Single GPU](02-Advanced_usage.ipynb)**
   - Explores the use of Diffrax solvers in the ODE integration step.
   - Explores second-order Lagrangian Perturbation Theory (LPT) simulations.
   - Introduces weighted density field projections.
3. **[Multi-GPU Particle Mesh Simulation with Halo Exchange](03-MultiGPU_PM_Halo.ipynb)**
   - Extends PM simulation to multi-GPU setups with halo exchange.
   - Uses sharding and device mesh configurations to manage distributed data across GPUs (see the sketch after this list).
4. **[Multi-GPU Particle Mesh Simulation with Advanced Solvers](04-MultiGPU_PM_Solvers.ipynb)**
   - Compares different ODE solvers (Leapfrog and Dopri5) in multi-GPU simulations.
   - Highlights performance, memory considerations, and solver impact on simulation quality.
5. **[Multi-Host Particle Mesh Simulation](05-MultiHost_PM.ipynb)**
   - Extends PM simulations to multi-host, multi-GPU setups for large-scale simulations.
   - Walks through job submission, device initialization, and retrieving results across nodes.
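
The multi-GPU notebooks all distribute fields the same way: build a 2D device mesh and a `NamedSharding`, then hand it to JAXPM functions through their `sharding` argument. A minimal sketch of that setup (the processor grid `pdims = (2, 4)` is an assumption and must match the number of available devices):

```python
import jax
from jax.experimental.mesh_utils import create_device_mesh
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

pdims = (2, 4)  # assumed processor grid
assert pdims[0] * pdims[1] == jax.device_count()

devices = create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('x', 'y'))
# Simulation fields are sharded along the first two axes of the mesh
sharding = NamedSharding(mesh, P('x', 'y'))
```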
## Getting Started
Each notebook includes installation instructions and guidelines for configuring JAXPM and required dependencies. Follow the setup instructions in each notebook to ensure an optimal environment.
## Requirements
- **JAXPM** (included in the installation commands within notebooks)
- **Diffrax** for ODE solvers
- **JAX** with CUDA support for multi-GPU or TPU setups
- **SLURM** for job scheduling on clusters (if running multi-host setups)
> **Note**: These notebooks are tested on the **Jean Zay** supercomputer and may require configuration changes for different HPC clusters.