JaxPM/notebooks/02-Advanced_usage.ipynb
Wassim KABALAN 695e359f56 jaxdecomp proto (#21)

Co-authored-by: EiffL <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr>
Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>
2024-12-20 05:44:02 -05:00


Advanced Particle Mesh Simulation on a Single GPU


In [ ]:
!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git
!pip install diffrax
In [ ]:
import jax
import jax.numpy as jnp
import jax_cosmo as jc

from jaxpm.painting import cic_paint, cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
from jaxpm.distributed import uniform_particles
from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve

Particle Mesh Simulation with Diffrax Leapfrog Solver

In this setup, we use the LeapfrogMidpoint solver from the diffrax library, with a constant step size, to evolve particle displacements and momenta over time in our Particle Mesh simulation.

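  • Leapfrog integration: LeapfrogMidpoint is an explicit, second-order, reversible two-step method that evaluates the vector field (and therefore the PM force) only once per step. It is the two-step leapfrog/midpoint scheme, not the kick-drift-kick (Verlet) integrator that is also often called leapfrog. With a constant step size the cost is fixed in advance: going from a = 0.1 to a = 1 with dt0 = 0.01 takes 90 steps.
  • Displacement tracking: the state stores displacements (dx) relative to the uniform particle grid rather than absolute positions. With paint_absolute_pos=False, the grid positions are implied by the mesh, so no explicit array of particle coordinates needs to be built.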
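To make the two-step update concrete, here is a minimal NumPy sketch of the leapfrog/midpoint recurrence y_{n+1} = y_{n-1} + 2 dt f(t_n, y_n), applied to a harmonic oscillator. It is an illustration under stated assumptions, not diffrax's implementation. The function name leapfrog_midpoint is hypothetical, and bootstrapping the first step with forward Euler is a choice made for this sketch. It shows that each step costs exactly one vector-field evaluation, which in the PM setting means one force computation.

```python
import numpy as np


def leapfrog_midpoint(f, t0, t1, dt, y0):
    """Integrate dy/dt = f(t, y) with the two-step leapfrog/midpoint recurrence

        y_{n+1} = y_{n-1} + 2 * dt * f(t_n, y_n).

    The first step is bootstrapped with forward Euler (a choice made for this sketch).
    Returns the state at t1 and the number of vector-field evaluations.
    """
    n_steps = int(round((t1 - t0) / dt))
    t = t0
    y_prev = y0
    y = y0 + dt * f(t, y0)  # bootstrap: one Euler step to obtain y_1
    n_evals = 1
    for _ in range(n_steps - 1):
        t += dt  # t is now the time of the current state y
        y_prev, y = y, y_prev + 2 * dt * f(t, y)
        n_evals += 1
    return y, n_evals


def oscillator(t, y):
    # y = (position, velocity) of a unit-frequency harmonic oscillator
    return np.array([y[1], -y[0]])


y_end, n_evals = leapfrog_midpoint(oscillator, 0.0, 1.0, 0.01, np.array([1.0, 0.0]))
print(n_evals)                      # 100: exactly one evaluation per step
print(abs(y_end[0] - np.cos(1.0)))  # small, since the scheme is second order
```

The state y_prev is the only extra memory the two-step scheme needs compared to a one-step method.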
In [ ]:
mesh_shape = [128, 128, 128]
box_size = [128., 128., 128.]
snapshots = jnp.array([0.5, 1.0])

@jax.jit
def run_simulation(omega_c, sigma8):
    # Cosmology and linear matter power spectrum, interpolated on a log-spaced k grid
    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(cosmo, k)
    pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)

    # Create initial conditions
    initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))

    # First-order LPT displacements and momenta at a = 0.1
    dx, p, f = lpt(cosmo, initial_conditions, a=0.1, order=1)

    # Evolve the displacements forward with a constant-step leapfrog solver
    ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
    solver = LeapfrogMidpoint()

    stepsize_controller = ConstantStepSize()
    res = diffeqsolve(ode_fn,
                      solver,
                      t0=0.1,
                      t1=1.,
                      dt0=0.01,
                      y0=jnp.stack([dx, p], axis=0),
                      args=cosmo,
                      saveat=SaveAt(ts=snapshots),
                      stepsize_controller=stepsize_controller)

    # Each saved state stacks (displacements, momenta); keep the displacements
    ode_solutions = [sol[0] for sol in res.ys]
    return initial_conditions, dx, ode_solutions, res.stats

initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)
ode_solutions[-1].block_until_ready()
%timeit initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8);ode_solutions[-1].block_until_ready()
print(f"Solver Stats : {solver_stats}")
4.05 s ± 1.54 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Solver Stats : {'max_steps': Array(4096, dtype=int32, weak_type=True), 'num_accepted_steps': Array(90, dtype=int32, weak_type=True), 'num_rejected_steps': Array(0, dtype=int32, weak_type=True), 'num_steps': Array(90, dtype=int32, weak_type=True)}
In [ ]:
from jaxpm.plotting import plot_fields_single_projection

fields = {"Initial Conditions" : initial_conditions , "LPT Field" : cic_paint_dx(lpt_displacements)}
for i , field in enumerate(ode_solutions):
    fields[f"field_{i}"] = cic_paint_dx(field)
plot_fields_single_projection(fields)

First and Second Order Lagrangian Perturbation Theory (LPT) Displacements

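This section compares first-order and second-order LPT, selected with the order argument of lpt. First-order LPT (the Zel'dovich approximation) moves particles along a displacement that is linear in the initial density field, while second-order LPT adds the leading nonlinear correction, which models structure formation more accurately. Here the displacements are evaluated directly at a = 0.8, without any ODE integration.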
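As a rough illustration of what first-order LPT computes, the sketch below builds the Zel'dovich displacement field Psi on a periodic grid by solving div(Psi) = -delta in Fourier space, Psi(k) = i k delta(k) / |k|^2. It is a minimal NumPy sketch, not jaxpm's lpt. The helper name zeldovich_displacement is hypothetical, and the sketch omits the growth-factor scaling, the momenta and the second-order term that lpt handles.

```python
import numpy as np


def zeldovich_displacement(delta, box_size):
    """First-order LPT (Zel'dovich) displacement Psi with div(Psi) = -delta.

    delta: real density contrast on a periodic grid, shape (N0, N1, N2).
    box_size: physical side lengths (L0, L1, L2).
    Returns Psi with shape (N0, N1, N2, 3), in the length units of box_size.
    """
    delta_k = np.fft.fftn(delta)
    # Angular wavenumbers along each axis: 2*pi * frequency, grid spacing L / N
    kvecs = [2 * np.pi * np.fft.fftfreq(n, d=L / n) for n, L in zip(delta.shape, box_size)]
    kgrid = np.meshgrid(*kvecs, indexing="ij")
    k2 = sum(ki**2 for ki in kgrid)
    k2[0, 0, 0] = 1.0  # avoid 0/0; the k = 0 mode is zeroed below
    psi = []
    for ki in kgrid:
        psi_k = 1j * ki / k2 * delta_k
        psi_k[0, 0, 0] = 0.0  # the mean carries no displacement
        psi.append(np.fft.ifftn(psi_k).real)
    return np.stack(psi, axis=-1)


# Single cosine mode along axis 0: delta = A cos(k0 x), analytic Psi_x = -(A / k0) sin(k0 x)
n, L, A = 16, 16.0, 0.1
k0 = 2 * np.pi / L
x = np.arange(n) * (L / n)
delta = A * np.cos(k0 * x)[:, None, None] * np.ones((n, n, n))
psi = zeldovich_displacement(delta, (L, L, L))
expected = -(A / k0) * np.sin(k0 * x)[:, None, None]
print(np.allclose(psi[..., 0], expected), np.allclose(psi[..., 1:], 0.0))  # True True
```

For a single cosine mode the numerical result matches the analytic displacement, and particles move toward the density peak at x = 0.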

In [ ]:
from functools import partial

mesh_shape = [128, 128, 128]
box_size = [128., 128., 128.]

# `order` is a Python integer that selects the LPT order, so it is marked static for jit
@partial(jax.jit, static_argnums=(2,))
def lpt_simulation(omega_c, sigma8, order=1):
    # Cosmology and linear matter power spectrum, interpolated on a log-spaced k grid
    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(cosmo, k)
    pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)

    # Create initial conditions
    initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))

    # LPT displacements evaluated directly at a = 0.8 (no ODE integration)
    dx, p, f = lpt(cosmo, initial_conditions, a=0.8, order=order)

    return initial_conditions, dx

initial_conditions_1 , lpt_displacements_1 = lpt_simulation(0.25, 0.8 , order=1)
lpt_displacements_1.block_until_ready()
initial_conditions_2 , lpt_displacements_2 = lpt_simulation(0.25, 0.8 , order=2)
lpt_displacements_2.block_until_ready()
%timeit initial_conditions_1 , lpt_displacements_1 = lpt_simulation(0.25, 0.8 , order=1);lpt_displacements_1.block_until_ready()
%timeit initial_conditions_2 , lpt_displacements_2 = lpt_simulation(0.25, 0.8, order=2);lpt_displacements_2.block_until_ready()
32.3 ms ± 9.42 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
42.8 ms ± 9.74 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [ ]:
lpt_fields = {"First Order" : cic_paint_dx(lpt_displacements_1) , "Second Order" : cic_paint_dx(lpt_displacements_2)}
plot_fields_single_projection(lpt_fields)

Custom ODE Solver with Absolute Positions

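Just like in the introduction notebook, this example uses absolute particle positions, initialized on a uniform grid with uniform_particles and displaced with second-order LPT. The ODE then evolves these absolute positions directly: at each step, the particles are painted onto the mesh with Cloud-in-Cell (CIC) interpolation to compute the gravitational forces, which makes it easy to track where each particle moves in the simulation volume.

Here, we save three snapshots (a = 0.1, 0.5 and 1.0) with diffeqsolve and the LeapfrogMidpoint solver, using a coarser constant step (dt0 = 0.2), so the solver takes only 5 steps.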
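The relation between the uniform grid, the LPT displacements and the absolute positions can be sketched in NumPy as follows. The helper names uniform_grid and to_absolute are hypothetical. uniform_grid assumes one particle per mesh cell at integer grid coordinates, which is the layout the weight arrays later in this notebook rely on. to_absolute wraps positions into the periodic box; a painting routine could equally wrap the cell indices instead.

```python
import numpy as np


def uniform_grid(mesh_shape):
    """Hypothetical helper: one particle per mesh cell at integer grid coordinates.

    Returns an array of shape (*mesh_shape, 3).
    """
    axes = [np.arange(n, dtype=np.float64) for n in mesh_shape]
    return np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)


def to_absolute(grid, dx, mesh_shape):
    """Hypothetical helper: absolute positions grid + dx, wrapped into [0, n) per axis."""
    return np.mod(grid + dx, np.asarray(mesh_shape, dtype=np.float64))


mesh_shape = [4, 4, 4]
grid = uniform_grid(mesh_shape)
dx = np.zeros(mesh_shape + [3])
dx[..., 0] = 1.5  # shift every particle by 1.5 cells along axis 0
pos = to_absolute(grid, dx, mesh_shape)

print(grid.shape)    # (4, 4, 4, 3)
print(pos[3, 0, 0])  # the particle from cell (3, 0, 0) wraps around to x = 0.5
```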

In [ ]:
mesh_shape = [128, 128, 128]
box_size = [128., 128., 128.]
snapshots = jnp.array([0.1, 0.5, 1.])

@jax.jit
def run_simulation(omega_c, sigma8):
    # Cosmology and linear matter power spectrum, interpolated on a log-spaced k grid
    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(cosmo, k)
    pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)

    # Create initial conditions
    initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))

    # Create particles on a uniform grid, one per mesh cell
    particles = uniform_particles(mesh_shape)

    # Second-order LPT displacements and momenta at a = 0.1
    dx, p, f = lpt(cosmo, initial_conditions, particles=particles, a=0.1, order=2)

    # Evolve the absolute positions (particles + dx) forward; absolute painting is the default
    ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
    solver = LeapfrogMidpoint()

    stepsize_controller = ConstantStepSize()
    res = diffeqsolve(ode_fn,
                      solver,
                      t0=0.1,
                      t1=1.,
                      dt0=0.2,
                      y0=jnp.stack([particles + dx, p], axis=0),
                      args=cosmo,
                      saveat=SaveAt(ts=snapshots),
                      stepsize_controller=stepsize_controller)

    # Each saved state stacks (positions, momenta); keep the positions
    ode_particles = [sol[0] for sol in res.ys]
    return initial_conditions, particles + dx, ode_particles, res.stats

initial_conditions , lpt_particles , ode_particles , solver_stats = run_simulation(0.25, 0.8)
print(f"Solver Stats : {solver_stats}")
Solver Stats : {'max_steps': Array(4096, dtype=int32, weak_type=True), 'num_accepted_steps': Array(5, dtype=int32, weak_type=True), 'num_rejected_steps': Array(0, dtype=int32, weak_type=True), 'num_steps': Array(5, dtype=int32, weak_type=True)}
In [ ]:
from jaxpm.plotting import plot_fields_single_projection

fields = {"Initial Conditions": initial_conditions, "LPT Field": cic_paint(jnp.zeros(mesh_shape), lpt_particles)}
# Skip ode_particles[0]: the snapshot at a = 0.1 is the LPT starting point already shown
for i, field in enumerate(ode_particles[1:]):
    fields[f"field_{i}"] = cic_paint(jnp.zeros(mesh_shape), field)
plot_fields_single_projection(fields)

Weighted Field Projection for Central Region

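In this cell, we apply custom weights to the particles that start in the central region of the grid. The selection center3d covers the full extent of axis 0 and the middle half of axes 1 and 2, so it appears as a central square when the field is projected along axis 0. These particles are painted with a weight of 3, i.e. they contribute three times as much mass to the density field as the other particles.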

We compare:

  • Weighted: Density increased in the central region.
  • Unweighted: Standard CIC painting without additional weighting.
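The effect of per-particle weights on CIC painting can be illustrated with a minimal NumPy deposit. This is a sketch under simplifying assumptions (periodic mesh, positions in grid units, one weight per particle in the same order as the particle grid), and the name cic_paint_np is hypothetical; it is not jaxpm's cic_paint or cic_paint_dx. Each particle spreads its weight over the eight surrounding cells with trilinear kernels that sum to one, so the painted mass always equals the sum of the weights.

```python
import numpy as np


def cic_paint_np(mesh_shape, positions, weight=1.0):
    """Cloud-in-Cell deposit onto a periodic mesh (illustrative sketch).

    positions: array of shape (..., 3) in grid units.
    weight: scalar, or one value per particle (any shape with that many elements).
    """
    mesh = np.zeros(mesh_shape)
    pos = positions.reshape(-1, 3)
    n = pos.shape[0]
    w = np.broadcast_to(np.asarray(weight, dtype=np.float64).reshape(-1), (n,))
    base = np.floor(pos)
    frac = pos - base  # offset inside the cell, in [0, 1)
    base = base.astype(np.int64)
    shape = np.asarray(mesh_shape)
    # Loop over the 8 neighbouring cells; the trilinear kernels sum to 1 per particle
    for offset in np.ndindex(2, 2, 2):
        off = np.asarray(offset)
        kernel = np.prod(np.where(off == 1, frac, 1.0 - frac), axis=1)
        idx = np.mod(base + off, shape)  # periodic wrap of cell indices
        np.add.at(mesh, (idx[:, 0], idx[:, 1], idx[:, 2]), w * kernel)
    return mesh


mesh_shape = (8, 8, 8)
axes = [np.arange(s, dtype=np.float64) for s in mesh_shape]
grid = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)
rng = np.random.default_rng(0)
positions = grid + rng.uniform(-0.5, 0.5, size=grid.shape)  # jittered particles

# Same selection pattern as center3d: all of axis 0, middle half of axes 1 and 2
center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4)
weights = np.ones(mesh_shape)
weights[:, center, center] *= 3.0

weighted = cic_paint_np(mesh_shape, positions, weights)
unweighted = cic_paint_np(mesh_shape, positions, 1.0)
print(unweighted.sum(), weighted.sum())  # ≈ 512.0 and ≈ 768.0 (= 512 + 2 * 128)
```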
In [ ]:
from jaxpm.plotting import plot_fields_single_projection

center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4)
# Full extent of axis 0, middle half of axes 1 and 2
center3d = (slice(None), center, center)
weights = jnp.ones_like(initial_conditions)
weights = weights.at[center3d].multiply(3)

# ode_solutions[0] holds the displacements at a = 0.5 from the first simulation
weighted = cic_paint_dx(ode_solutions[0], weight=weights)
unweighted = cic_paint_dx(ode_solutions[0], weight=1.0)

plot_fields_single_projection({"Weighted" : weighted , "Unweighted" : unweighted} , project_axis=0)

Weighted Field Projection with Absolute Positions

For the simulation with absolute positions, we apply a weight of 1.3 to the particles in the same central region and paint them with cic_paint. The weights enter exactly as in the displacement case: each particle's contribution to the mesh is multiplied by its weight; only the painting function and the way positions are stored differ.

Note: the weights are attached to particles through their index in the initial grid, not to fixed locations in the box. Since particles move during the simulation, the extra mass ends up wherever the central particles have travelled by the painted snapshot.
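Because the weight array has one entry per particle of the initial grid, the central selection labels a fixed set of particles. A quick NumPy count, assuming the 128^3 mesh and the same slicing as in the cell below, shows how much of the total painted mass those particles carry, whichever position representation is painted.

```python
import numpy as np

mesh_shape = (128, 128, 128)
center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4)  # indices 32 to 95
center3d = (slice(None), center, center)

weights = np.ones(mesh_shape)
weights[center3d] *= 1.3  # one weight per particle of the initial grid

n_particles = weights.size
n_center = np.count_nonzero(weights != 1.0)
mass_fraction = weights.sum() / n_particles
print(n_center, n_center / n_particles)  # 524288 0.25
print(mass_fraction)                     # ≈ 1.075 = 1 + 0.3 * 0.25
```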

In [ ]:
from jaxpm.plotting import plot_fields_single_projection

center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4)
# Full extent of axis 0, middle half of axes 1 and 2
center3d = (slice(None), center, center)
weights = jnp.ones_like(initial_conditions)
weights = weights.at[center3d].multiply(1.3)

# ode_particles[0] holds the absolute positions at a = 0.1
weighted = cic_paint(jnp.zeros(mesh_shape), ode_particles[0], weight=weights)
unweighted = cic_paint(jnp.zeros(mesh_shape), ode_particles[0], weight=1.0)

plot_fields_single_projection({"Weighted" : weighted , "Unweighted" : unweighted} , project_axis=0)