# Distributed particle-mesh simulation example.
# Mirror of https://github.com/DifferentiableUniverseInitiative/JaxPM.git
import os

os.environ["EQX_ON_ERROR"] = "nan"  # avoid an allgather caused by diffrax

import jax

jax.distributed.initialize()

rank = jax.process_index()
size = jax.process_count()
if rank == 0:
    print(f"SIZE is {jax.device_count()}")

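# jax.distributed.initialize() reads the coordinator address, process count,
# and process id from the launcher environment (e.g. SLURM variables), so the
# same script works on a single host or across several.
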
import argparse
from functools import partial

import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
                     PIDController, SaveAt, diffeqsolve)
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
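# With tiled=True, process_allgather concatenates the per-process shards into
# one array of the global shape instead of stacking them along a new leading
# process axis.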
all_gather = partial(process_allgather, tiled=True)

def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Run a cosmological simulation with JAX.")
    parser.add_argument(
        "-p",
        "--pdims",
        type=int,
        nargs=2,
        default=[1, jax.device_count()],
        help="Processor grid dimensions as two integers (e.g., 2 4).")
    parser.add_argument(
        "-m",
        "--mesh_shape",
        type=int,
        nargs=3,
        default=[512, 512, 512],
        help="Shape of the simulation mesh as three values (e.g., 512 512 512).")
    parser.add_argument(
        "-b",
        "--box_size",
        type=float,
        nargs=3,
        default=[500.0, 500.0, 500.0],
        help="Box size of the simulation as three values (e.g., 500.0 500.0 1000.0).")
    parser.add_argument(
        "-st",
        "--snapshots",
        type=int,
        default=2,
        help="Number of snapshots to save during the simulation.")
    parser.add_argument(
        "-H",
        "--halo_size",
        type=int,
        default=64,
        help="Halo size for the simulation.")
    parser.add_argument(
        "-s",
        "--solver",
        type=str,
        choices=['leapfrog', 'dopri5'],
        default='leapfrog',
        help="ODE solver choice: 'leapfrog' or 'dopri5'.")
    return parser.parse_args()

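# Example launch (a sketch assuming a SLURM allocation with 8 GPUs; adapt the
# launcher to your cluster):
#   srun -n 8 python distributed_pm.py --pdims 2 4 \
#       --mesh_shape 512 512 512 --box_size 500. 500. 500. --solver leapfrog
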
def create_mesh_and_sharding(pdims):
    devices = create_device_mesh(pdims)
    mesh = Mesh(devices, axis_names=('x', 'y'))
    sharding = NamedSharding(mesh, P('x', 'y'))
    return mesh, sharding

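# P('x', 'y') shards the first two array axes across the processor grid, i.e.
# a 2D pencil decomposition of the 3D mesh. A quick interactive way to check
# the resulting layout (a sketch, not part of this script's flow):
#   field = jax.device_put(jnp.zeros((512, 512, 512)), sharding)
#   jax.debug.visualize_array_sharding(field[..., 0])
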
# mesh_shape, box_size, halo_size, solver_choice, nb_snapshots, and the
# sharding are not array arguments, so they must be static for jit; the
# original static_argnums left out the sharding (argument 7).
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6, 7))
def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
                   solver_choice, nb_snapshots, sharding):
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(
        jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
    pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)

    initial_conditions = linear_field(mesh_shape,
                                      box_size,
                                      pk_fn,
                                      seed=jax.random.PRNGKey(0),
                                      sharding=sharding)

    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)

    dx, p, _ = lpt(cosmo,
                   initial_conditions,
                   a=0.1,
                   halo_size=halo_size,
                   sharding=sharding)

    ode_fn = ODETerm(
        make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
    # Choose solver
    solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
    stepsize_controller = (ConstantStepSize() if solver_choice == "leapfrog"
                           else PIDController(rtol=1e-5, atol=1e-5))
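    # LeapfrogMidpoint is a fixed-step scheme, hence ConstantStepSize; Dopri5
    # is an adaptive Runge-Kutta method, so it is paired with a PIDController
    # that adjusts dt to meet the requested tolerances.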
    res = diffeqsolve(ode_fn,
                      solver,
                      t0=0.1,
                      t1=1.,
                      dt0=0.01,
                      y0=jnp.stack([dx, p], axis=0),
                      args=cosmo,
                      saveat=SaveAt(ts=jnp.linspace(0.2, 1., nb_snapshots)),
                      stepsize_controller=stepsize_controller)
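    # Each entry of res.ys is the stacked [displacement, momentum] state at
    # one saved time; sol[0] selects the displacements, which is what
    # cic_paint_dx needs to paint the density field.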
    ode_fields = [
        cic_paint_dx(sol[0], halo_size=halo_size, sharding=sharding)
        for sol in res.ys
    ]
    lpt_field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding)
    return initial_conditions, lpt_field, ode_fields, res.stats

def main():
    args = parse_arguments()
    mesh_shape = args.mesh_shape
    box_size = args.box_size
    halo_size = args.halo_size
    solver_choice = args.solver
    nb_snapshots = args.snapshots

    # create_mesh_and_sharding returns (mesh, sharding); only the sharding
    # is needed below.
    _, sharding = create_mesh_and_sharding(args.pdims)

    initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(
        0.25, 0.8, tuple(mesh_shape), tuple(box_size), halo_size,
        solver_choice, nb_snapshots, sharding)
    if rank == 0:
        os.makedirs("fields", exist_ok=True)
        print(f"[{rank}] Simulation done")
        print(f"Solver stats: {solver_stats}")
    # Save initial conditions
    initial_conditions_g = all_gather(initial_conditions)
    if rank == 0:
        print(f"[{rank}] Saving initial_conditions")
        np.save("fields/initial_conditions.npy", initial_conditions_g)
        print(f"[{rank}] initial_conditions saved")
    del initial_conditions_g, initial_conditions

    # Save LPT displacements
    lpt_displacements_g = all_gather(lpt_displacements)
    if rank == 0:
        print(f"[{rank}] Saving lpt_displacements")
        np.save("fields/lpt_displacements.npy", lpt_displacements_g)
        print(f"[{rank}] lpt_displacements saved")
    del lpt_displacements_g, lpt_displacements

    # Save each ODE solution separately
    for i, sol in enumerate(ode_solutions):
        sol_g = all_gather(sol)
        if rank == 0:
            print(f"[{rank}] Saving ode_solution_{i}")
            np.save(f"fields/ode_solution_{i}.npy", sol_g)
            print(f"[{rank}] ode_solution_{i} saved")
        del sol_g
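    # Gathering and saving one field at a time, then deleting the gathered
    # copy, keeps peak host memory near a single global field rather than
    # holding every snapshot at once.
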
if __name__ == "__main__":
    main()