JaxPM/notebooks/02-Advanced_usage.ipynb
Wassim KABALAN 6693e5c725
Fix sharding error (#37)
Co-authored-by: Francois Lanusse <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
2025-06-28 23:07:31 +02:00


Advanced Particle Mesh Simulation on a Single GPU


In [ ]:
!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git
!pip install diffrax
In [1]:
import jax
import jax.numpy as jnp
import jax_cosmo as jc

from jaxpm.painting import cic_paint, cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
from jaxpm.distributed import uniform_particles
from diffrax import PIDController, Tsit5, ODETerm, SaveAt, diffeqsolve

Particle Mesh Simulation with the Diffrax Tsit5 Solver

In this setup, we use the Tsit5 solver from the diffrax library to evolve particle displacements over time in our Particle Mesh simulation.

  • Efficient Displacement Tracking: We store only the particle displacements (dx) rather than their absolute positions. Densities are then painted with cic_paint_dx, the pmwd-style CIC painting algorithm, which is more memory-efficient at the cost of being slightly slower.
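To make displacement-based painting concrete, here is a minimal NumPy sketch of what it computes. It assumes one particle per grid cell that starts at the cell's integer coordinates, positions in grid units, and a periodic mesh; cic_paint_from_displacements is a hypothetical name, and the sketch reproduces the result, not the memory-saving implementation used by JaxPM/pmwd.

```python
import numpy as np

def cic_paint_from_displacements(dx):
    """CIC density from per-particle displacements dx of shape (*mesh_shape, 3).

    Assumption: particle (i, j, k) starts at integer grid point (i, j, k), so its
    current position is grid + dx and only dx has to be stored between steps.
    """
    mesh_shape = dx.shape[:-1]
    grid = np.stack(np.meshgrid(*[np.arange(n) for n in mesh_shape], indexing="ij"), axis=-1)
    pos = (grid + dx).reshape(-1, 3)
    base = np.floor(pos).astype(int)  # lower-corner cell of each particle
    frac = pos - base                 # offset inside that cell, in [0, 1)
    mesh = np.zeros(mesh_shape)
    # Deposit into the 8 surrounding cells with trilinear (CIC) weights
    for offset in np.ndindex(2, 2, 2):
        off = np.array(offset)
        kernel = np.prod(np.where(off == 1, frac, 1.0 - frac), axis=1)
        idx = (base + off) % np.array(mesh_shape)  # periodic wrap
        np.add.at(mesh, tuple(idx.T), kernel)
    return mesh

# Zero displacement: every cell receives exactly one unit of mass
density = cic_paint_from_displacements(np.zeros((4, 4, 4, 3)))
```

Since each particle's CIC weights sum to one, the painted mesh always contains as many units of mass as there are particles, whatever the displacements.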
In [2]:
mesh_shape = [128, 128, 128]
box_size = [128., 128., 128.]
snapshots = jnp.array([0.5, 1.0])

@jax.jit
def run_simulation(omega_c, sigma8):
    # Create a small function to generate the matter power spectrum
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
    pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)

    # Create initial conditions
    initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))

    # Create the cosmology
    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)

    # Initial displacement from first-order LPT at a = 0.1
    dx, p, f = lpt(cosmo, initial_conditions, particles=None, a=0.1, order=1)
    
    # Evolve the simulation forward
    ode_fn = ODETerm(
        make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
    solver = Tsit5()

    stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)
    res = diffeqsolve(ode_fn,
                      solver,
                      t0=0.1,
                      t1=1.,
                      dt0=0.01,
                      y0=jnp.stack([dx, p], axis=0),
                      args=cosmo,
                      saveat=SaveAt(ts=snapshots),
                      stepsize_controller=stepsize_controller)

    ode_solutions = [sol[0] for sol in res.ys]
    return initial_conditions ,  dx , ode_solutions , res.stats

initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)
ode_solutions[-1].block_until_ready()
%timeit initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8);ode_solutions[-1].block_until_ready()
print(f"Solver Stats : {solver_stats}")
/home/wassim/micromamba/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:122: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
35 s ± 4.73 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
Solver Stats : {'max_steps': Array(4096, dtype=int32, weak_type=True), 'num_accepted_steps': Array(156, dtype=int32, weak_type=True), 'num_rejected_steps': Array(139, dtype=int32, weak_type=True), 'num_steps': Array(295, dtype=int32, weak_type=True)}
In [3]:
from jaxpm.plotting import plot_fields_single_projection

fields = {"Initial Conditions" : initial_conditions , "LPT Field" : jnp.log10(cic_paint_dx(lpt_displacements) + 1)}
for i , field in enumerate(ode_solutions):
    fields[f"field_{i}"] = jnp.log10(cic_paint_dx(field) + 1)
plot_fields_single_projection(fields)
[Figure: panels "Initial Conditions", "LPT Field", "field_0", "field_1"]

First and Second Order Lagrangian Perturbation Theory (LPT) Displacements

This section introduces first-order and second-order LPT simulations, controlled by the order argument. First-order LPT captures linear displacements, while second-order LPT includes nonlinear corrections, allowing more accurate modeling of structure formation.
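As an illustration of the first-order case, the Zel'dovich displacement Ψ can be obtained from the linear density contrast δ by solving ∇·Ψ = -δ in Fourier space, Ψ(k) = i k δ(k) / k². The NumPy sketch below assumes a cubic periodic mesh and lengths in box units; zeldovich_displacement is a hypothetical helper, not JaxPM's lpt, and it omits the second-order term and the growth-factor bookkeeping.

```python
import numpy as np

def zeldovich_displacement(delta, box_size):
    """First-order LPT displacement Psi for a periodic linear density field.

    Solves div(Psi) = -delta via Psi(k) = i k / k^2 * delta(k).
    Returns an array of shape (*delta.shape, 3) in the same length units as box_size.
    """
    n = delta.shape[0]  # assumes a cubic mesh
    k1d = 2 * np.pi * np.fft.fftfreq(n, d=box_size / n)
    kx, ky, kz = np.meshgrid(k1d, k1d, k1d, indexing="ij")
    k2 = kx**2 + ky**2 + kz**2
    k2[0, 0, 0] = 1.0  # avoid 0/0; the k = 0 mode is zeroed below
    delta_k = np.fft.fftn(delta)
    components = []
    for ki in (kx, ky, kz):
        psi_k = 1j * ki / k2 * delta_k
        psi_k[0, 0, 0] = 0.0
        components.append(np.fft.ifftn(psi_k).real)
    return np.stack(components, axis=-1)

# Example: a single plane wave along x
n, box = 16, 100.0
x = np.arange(n) * box / n
k0 = 2 * np.pi * 2 / box
delta = 0.1 * np.cos(k0 * x)[:, None, None] * np.ones((n, n, n))
psi = zeldovich_displacement(delta, box)
# At scale factor a, particles sit at q + D(a) * psi, with D the linear growth factor
```

For a plane wave δ = A cos(k₀x), the displacement is Ψ_x = -A sin(k₀x)/k₀, pointing particles toward the overdense crests.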

In [4]:
from functools import partial 

mesh_shape = [128, 128, 128]
box_size = [128., 128., 128.]
snapshots = jnp.array([0.5,1.])

@partial(jax.jit , static_argnums=(2,))
def lpt_simulation(omega_c, sigma8, order=1):
    # Create a small function to generate the matter power spectrum
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
    pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)

    # Create initial conditions
    initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))

    # Create the cosmology
    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)

    # LPT displacement of the requested order at a = 0.8
    dx, p, f = lpt(cosmo, initial_conditions, a=0.8, order=order)

    return initial_conditions ,  dx

initial_conditions_1 , lpt_displacements_1 = lpt_simulation(0.25, 0.8 , order=1)
lpt_displacements_1.block_until_ready()
initial_conditions_2 , lpt_displacements_2 = lpt_simulation(0.25, 0.8 , order=2)
lpt_displacements_2.block_until_ready()
%timeit initial_conditions_1 , lpt_displacements_1 = lpt_simulation(0.25, 0.8 , order=1);lpt_displacements_1.block_until_ready()
%timeit initial_conditions_2 , lpt_displacements_2 = lpt_simulation(0.25, 0.8, order=2);lpt_displacements_2.block_until_ready()
/home/wassim/micromamba/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:122: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
19 ms ± 69.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
28.6 ms ± 82 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [5]:
lpt_fields = {"First Order" : jnp.log10(cic_paint_dx(lpt_displacements_1) + 1) , "Second Order" : jnp.log10(cic_paint_dx(lpt_displacements_2) + 1)}
plot_fields_single_projection(lpt_fields)
[Figure: panels "First Order", "Second Order"]

Custom ODE Solver with Absolute Positions

Just like in the introduction notebook, this example uses absolute particle positions initialized on a uniform grid. We evolve these absolute positions directly; at each step the particles are painted onto the mesh with the Cloud-in-Cell (CIC) scheme to compute the gravitational forces, which makes it straightforward to track particle movement across the simulation volume.

Here, we integrate with diffeqsolve and the Tsit5 solver, saving the state at several snapshots.
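The solver works on a single state array: positions and momenta are stacked along a leading axis, and one state is returned per requested snapshot. The NumPy sketch below shows that layout under stated assumptions: a fixed-step RK4 integrator replaces diffrax's adaptive Tsit5, and a toy harmonic-oscillator force replaces the PM forces. rk4_with_snapshots and harmonic_rhs are hypothetical names.

```python
import numpy as np

def rk4_with_snapshots(rhs, y0, t0, dt, save_ts):
    """Fixed-step RK4; returns the state at each time in save_ts, stacked on axis 0.

    save_ts must be sorted and >= t0 (a stand-in for diffrax's SaveAt(ts=...)).
    """
    ys = []
    t, y = t0, y0
    for t_save in save_ts:
        # Take steps until the requested snapshot time is reached
        while t < t_save - 1e-12:
            h = min(dt, t_save - t)
            k1 = rhs(t, y)
            k2 = rhs(t + h / 2, y + h / 2 * k1)
            k3 = rhs(t + h / 2, y + h / 2 * k2)
            k4 = rhs(t + h, y + h * k3)
            y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
            t += h
        ys.append(y)
    return np.stack(ys, axis=0)

def harmonic_rhs(t, y):
    """Toy dynamics: y[0] are positions, y[1] momenta; d(pos)/dt = p, dp/dt = -pos."""
    pos, mom = y[0], y[1]
    return np.stack([mom, -pos], axis=0)

positions = np.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])  # 2 particles in 3D
momenta = np.zeros_like(positions)
y0 = np.stack([positions, momenta], axis=0)  # shape (2, n_particles, 3)
snapshots = np.array([0.0, np.pi / 2, np.pi])
ys = rk4_with_snapshots(harmonic_rhs, y0, 0.0, 0.01, snapshots)
snap_positions = [sol[0] for sol in ys]  # same indexing as the notebook's res.ys
print(ys.shape)  # (3, 2, 2, 3)
```

Because the first requested time equals t0, ys[0] is just the initial state; this is why the plotting cell that follows skips ode_particles[0].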

In [6]:
mesh_shape = [128, 128, 128]
box_size = [128., 128., 128.]
snapshots = jnp.array([0.1 ,0.5, 1.])

@jax.jit
def run_simulation(omega_c, sigma8):
    # Create a small function to generate the matter power spectrum
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
    pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)

    # Create initial conditions
    initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))

    # Create the cosmology
    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)

    # Create particles on a uniform grid (absolute positions)
    particles = uniform_particles(mesh_shape)
    # Second-order LPT displacement at a = 0.1
    dx, p, f = lpt(cosmo, initial_conditions, particles=particles, a=0.1, order=2)
    
    # Evolve the simulation forward
    ode_fn = ODETerm(
        make_diffrax_ode(mesh_shape))
    solver = Tsit5()

    stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)
    res = diffeqsolve(ode_fn,
                      solver,
                      t0=0.1,
                      t1=1.,
                      dt0=0.2,
                      y0=jnp.stack([particles + dx, p], axis=0),
                      args=cosmo,
                      saveat=SaveAt(ts=snapshots),
                      stepsize_controller=stepsize_controller)

    ode_particles = [sol[0] for sol in res.ys]
    return initial_conditions ,  particles + dx , ode_particles , res.stats

initial_conditions , lpt_particles , ode_particles , solver_stats = run_simulation(0.25, 0.8)
print(f"Solver Stats : {solver_stats}")
Solver Stats : {'max_steps': Array(4096, dtype=int32, weak_type=True), 'num_accepted_steps': Array(67, dtype=int32, weak_type=True), 'num_rejected_steps': Array(8, dtype=int32, weak_type=True), 'num_steps': Array(75, dtype=int32, weak_type=True)}
In [7]:
from jaxpm.plotting import plot_fields_single_projection

fields = {"Initial Conditions" : initial_conditions , "LPT Field" : jnp.log10(cic_paint(jnp.zeros(mesh_shape) ,lpt_particles) + 1)}
for i , field in enumerate(ode_particles[1:]):
    fields[f"field_{i}"] = jnp.log10(cic_paint(jnp.zeros(mesh_shape) , field)+1)
plot_fields_single_projection(fields)
[Figure: panels "Initial Conditions", "LPT Field", "field_0", "field_1"]

Weighted Field Projection for Central Region

In this cell, we apply custom per-particle weights to boost the painted density in the central region of the grid. Particles whose initial grid index lies in the central half of axes 1 and 2 (spanning the full extent of axis 0, which is also the projection axis) receive a weight of 10, so each of them deposits ten times the mass of an ordinary particle.

We compare:

  • Weighted: Density increased in the central region.
  • Unweighted: Standard CIC painting without additional weighting.
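The weighting can be sketched in NumPy: each particle carries a weight tied to its initial grid index, and painting adds that weight instead of a unit mass. To keep the sketch short it swaps CIC for the simpler nearest-grid-point (NGP) deposit; ngp_paint is a hypothetical helper, not part of JaxPM, and the mesh is assumed periodic with positions in grid units.

```python
import numpy as np

def ngp_paint(positions, weights, mesh_shape):
    """Nearest-grid-point deposit: each particle adds its weight to the nearest cell.

    A simplified stand-in for CIC painting (periodic mesh, positions in grid units).
    """
    idx = np.rint(positions).astype(int) % np.array(mesh_shape)
    mesh = np.zeros(mesh_shape)
    np.add.at(mesh, (idx[..., 0].ravel(), idx[..., 1].ravel(), idx[..., 2].ravel()),
              weights.ravel())
    return mesh

mesh_shape = (8, 8, 8)
# One particle per cell, labelled by its initial grid index (i, j, k)
grid = np.stack(np.meshgrid(*[np.arange(s) for s in mesh_shape], indexing="ij"), axis=-1)

# Same mask as the notebook: central half of axes 1 and 2, full extent of axis 0
center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4)
center3d = (slice(None), center, center)
weights = np.ones(mesh_shape)
weights[center3d] *= 10  # NumPy equivalent of weights.at[center3d].multiply(10)

rng = np.random.default_rng(0)
positions = grid + rng.normal(scale=0.3, size=grid.shape)  # small random displacements
weighted = ngp_paint(positions, weights, mesh_shape)
unweighted = ngp_paint(positions, np.ones(mesh_shape), mesh_shape)
print(unweighted.sum(), weighted.sum())  # 512.0 1664.0
```

The total painted mass grows from 512 to 1664 = 512 + 9 × 128, since each of the 128 particles in the central slab contributes 9 extra units.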
In [8]:
from jaxpm.plotting import plot_fields_single_projection

center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4 )
center3d = (slice(None) , center,center) 
weights = jnp.ones_like(initial_conditions)
weights = weights.at[center3d].multiply(10)

weighted = jnp.log10(cic_paint_dx(ode_solutions[0], weight=weights) + 1)
unweighted = jnp.log10(cic_paint_dx(ode_solutions[0] , weight=1.0) + 1)

plot_fields_single_projection({"Weighted" : weighted , "Unweighted" : unweighted} , project_axis=0)
[Figure: panels "Weighted", "Unweighted"]

Weighted Field Projection with Absolute Positions

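For simulations with absolute positions, we pass the weights to cic_paint together with the particle positions, boosting the same central region by a factor of 5. As with displacements, each weight belongs to a particle (indexed by its initial grid cell) and scales the mass that particle deposits wherever it currently sits; the weights do not modify the positions themselves.

Note: ode_particles[0] is the snapshot at a = 0.1 (the first entry of snapshots in this run), whereas the previous cell painted ode_solutions[0], the a = 0.5 snapshot of the displacement-based run. The two weighted maps therefore differ in epoch as well as in weight factor.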
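That weights follow particles rather than fixed grid cells can be checked with a small NumPy sketch of CIC painting from absolute positions with per-particle weights. cic_paint_abs is a hypothetical helper, not JaxPM's cic_paint; positions are assumed to be in grid units on a periodic mesh.

```python
import numpy as np

def cic_paint_abs(mesh_shape, positions, weights=1.0):
    """Cloud-in-Cell deposit of weighted particles at absolute positions (periodic mesh)."""
    mesh = np.zeros(mesh_shape)
    pos = positions.reshape(-1, 3)
    # One weight per particle; a scalar weight is broadcast to all particles
    w = np.broadcast_to(np.asarray(weights, dtype=float), positions.shape[:-1]).ravel()
    base = np.floor(pos).astype(int)
    frac = pos - base
    shape = np.array(mesh_shape)
    for offset in np.ndindex(2, 2, 2):
        off = np.array(offset)
        # Trilinear kernel: frac for the upper neighbour, 1 - frac for the lower one
        kernel = np.prod(np.where(off == 1, frac, 1.0 - frac), axis=1)
        idx = (base + off) % shape
        np.add.at(mesh, (idx[:, 0], idx[:, 1], idx[:, 2]), w * kernel)
    return mesh

# A single particle of weight 5, halfway between cells 2 and 3 along x
mesh = cic_paint_abs((4, 4, 4), np.array([[2.5, 0.0, 0.0]]), weights=5.0)
# The same particle after moving one cell: its weight moves with it (and wraps)
moved = cic_paint_abs((4, 4, 4), np.array([[3.5, 0.0, 0.0]]), weights=5.0)
```

In both cases the particle deposits 2.5 units into each of its two neighbouring cells; only the cells change as the particle moves, and the total stays equal to its weight.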

In [9]:
from jaxpm.plotting import plot_fields_single_projection

center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4 )
center3d = (slice(None) , center,center)  
weights = jnp.ones_like(initial_conditions)
weights = weights.at[center3d].multiply(5)

weighted = jnp.log10(cic_paint(jnp.zeros(mesh_shape),ode_particles[0], weight=weights) + 1)
unweighted = jnp.log10(cic_paint(jnp.zeros(mesh_shape),ode_particles[0] , weight=1.0) + 1)

plot_fields_single_projection({"Weighted" : weighted , "Unweighted" : unweighted} , project_axis=0)
[Figure: panels "Weighted", "Unweighted"]