JaxPM/notebooks/03-MultiGPU_PM_Halo.ipynb
2025-06-08 11:25:51 +02:00

Multi-GPU Particle Mesh Simulation with Halo Exchange

In [ ]:
!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git
!pip install diffrax

Note: This notebook requires 8 devices (GPU or TPU).
If you're running on CPU or don't have access to 8 devices,
you can simulate multiple devices by adding the following code at the start BEFORE IMPORTING JAX:

import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

This is recommended only for debugging. If you use it, you will probably need to lower the mesh resolution.

In [2]:
import os
os.environ["EQX_ON_ERROR"] = "nan"
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.debug import visualize_array_sharding

from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
from functools import partial
from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve, Tsit5, PIDController
In [3]:
assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices"

Setting Up Device Mesh and Sharding for Multi-GPU Simulation

This cell configures a 2×4 device mesh across 8 devices and sets up named sharding to distribute data efficiently.

  • Device Mesh: pdims = (2, 4) arranges devices in a 2×4 grid. This logical mesh layout defines how devices are organized for parallelism. For more information on grid mesh layouts, see the jaxDecomp GitHub repository.
  • Sharding with Mesh: jax.make_mesh(pdims, axis_names=('x', 'y')) builds the device mesh and assigns names to its axes, enabling flexible and explicit data mapping across devices.
  • PartitionSpec and NamedSharding: PartitionSpec('x', 'y') specifies how to partition data along the mesh axes. NamedSharding(mesh, P('x', 'y')) applies this scheme, ensuring that arrays are correctly distributed across the 2D device mesh.

More details on sharding and distributed computation are available in the JAX documentation: Distributed arrays and automatic parallelization.
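As a rough illustration of the points above, the following sketch shards a small array over a 2×4 mesh and inspects the per-device shard shape. It assumes the simulated 8-device CPU setup from the note at the top of the notebook, builds the mesh with the `Mesh` constructor instead of `jax.make_mesh`, and uses a 16³ array chosen only to keep the example small.

```python
import os
# Simulate 8 CPU devices; this must run before JAX is imported.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange the 8 devices in a 2x4 grid and name the two mesh axes.
devices = np.array(jax.devices()[:8]).reshape(2, 4)
mesh = Mesh(devices, axis_names=("x", "y"))
sharding = NamedSharding(mesh, P("x", "y"))

# Axis 0 is split over 'x' (2 devices), axis 1 over 'y' (4 devices),
# and axis 2 is left unsplit.
field = jax.device_put(jnp.zeros((16, 16, 16)), sharding)
shard_shapes = {shard.data.shape for shard in field.addressable_shards}
print(shard_shapes)
```

Each device holds one pencil of the array: the first axis is divided by 2, the second by 4, and the third is kept whole.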

In [4]:
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import PartitionSpec as P, NamedSharding

all_gather = partial(process_allgather, tiled=True)

pdims = (2, 4)
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))

Multi-GPU Particle Mesh Simulation with Sharding

This function closely follows the single-GPU simulation logic but extends it to support distributed computation across multiple devices. The key change is the use of the sharding argument in linear_field, lpt, and make_diffrax_ode.

By passing sharding, each stage of the simulation (initial condition generation, LPT displacement calculation, and time integration via the ODE solver) is parallelized across the 2×4 device mesh. Each device then stores and updates only its own portion of the mesh, which lets the simulation scale to sizes that would not fit in the memory of a single device.

In [5]:
mesh_shape = 128
box_size = 256.
halo_size = 64
snapshots = (0.5, 1.0)

@partial(jax.jit, static_argnums=(2, 3, 4, 5))
def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size, snapshots):
    mesh_shape = (mesh_shape,) * 3
    box_size = (box_size,) * 3
    # Create a small function to generate the matter power spectrum
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(
        jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
    pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)

    # Create initial conditions
    initial_conditions = linear_field(mesh_shape,
                                      box_size,
                                      pk_fn,
                                      seed=jax.random.PRNGKey(0),
                                      sharding=sharding)


    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)

    # Initial displacement
    dx, p, f = lpt(cosmo,
                   initial_conditions,
                   a=0.1,
                   order=2,
                   halo_size=halo_size,
                   sharding=sharding)

    # Evolve the simulation forward
    ode_fn = ODETerm(
        make_diffrax_ode(mesh_shape,
                         paint_absolute_pos=False,
                         sharding=sharding,
                         halo_size=halo_size))
    solver = Tsit5()

    # Adaptive step size with relative and absolute tolerances of 1e-3
    stepsize_controller = PIDController(rtol=1e-3, atol=1e-3)
    res = diffeqsolve(ode_fn,
                      solver,
                      t0=0.1,
                      t1=1.,
                      dt0=0.01,
                      y0=jnp.stack([dx, p], axis=0),
                      args=cosmo,
                      saveat=SaveAt(ts=snapshots),
                      stepsize_controller=stepsize_controller)
    ode_solutions = [sol[0] for sol in res.ys]
    return initial_conditions, dx, ode_solutions, res.stats

# First call triggers JIT compilation; wait until the result is ready.
initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(0.25, 0.8, mesh_shape, box_size, halo_size, snapshots)
ode_solutions[-1].block_until_ready()
# Second call reuses the compiled function, so the timing excludes compilation.
%time initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(0.25, 0.8, mesh_shape, box_size, halo_size, snapshots); ode_solutions[-1].block_until_ready()
print(f"Solver Stats : {solver_stats}")
/home/wassim/micromamba/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:122: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
CPU times: user 5min 42s, sys: 3.77 s, total: 5min 46s
Wall time: 23.6 s
Solver Stats : {'max_steps': Array(4096, dtype=int32, weak_type=True), 'num_accepted_steps': Array(8, dtype=int32, weak_type=True), 'num_rejected_steps': Array(1, dtype=int32, weak_type=True), 'num_steps': Array(9, dtype=int32, weak_type=True)}

All fields and particle grids remain distributed at all times (as seen below). JaxPM ensures they are never gathered on a single device. In a forward model scenario, it's the user's responsibility to keep the data distributed to avoid memory bottlenecks.
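One way to keep data distributed in a forward model is to reduce sharded fields to a scalar inside a jitted function, so that only the scalar leaves the devices. The sketch below is illustrative only: the function name `sum_squared_residuals`, the 16³ arrays, and the simulated 8-device CPU setup are assumptions and not part of JaxPM.

```python
import os
# Simulate 8 CPU devices; this must run before JAX is imported.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:8]).reshape(2, 4), axis_names=("x", "y"))
sharding = NamedSharding(mesh, P("x", "y"))

@jax.jit
def sum_squared_residuals(model, data):
    # Both inputs are sharded identically; the result is a single scalar.
    return jnp.sum((model - data) ** 2)

# Stand-ins for a simulated density field and an observed field.
model_field = jax.device_put(jnp.ones((16, 16, 16)), sharding)
observed = jax.device_put(jnp.zeros((16, 16, 16)), sharding)

loss = sum_squared_residuals(model_field, observed)
print(loss.shape, float(loss))
```

No explicit gather is needed here: the inputs stay split into 8 shards and only the scalar result is read back on the host.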

In [6]:
visualize_array_sharding(ode_solutions[-1][:,:,0,0])
                                    
                                    
  CPU 0    CPU 1    CPU 2    CPU 3  
  CPU 4    CPU 5    CPU 6    CPU 7  
                                    
                                    
                                    

⚠️ Warning: Particle arrays usually have a shape of (NPart, 3),
where NPart = Nx * Ny * Nz. However, this shape is not shardable in a distributed setup.
Instead, particle arrays here always have a shape of (Nx, Ny, Nz, 3), so that they remain distributed across devices.
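If a flat (NPart, 3) list of positions is needed for post-processing, it can be rebuilt from a gathered grid-shaped array. The NumPy sketch below assumes that the array stores displacements in mesh-cell units relative to a regular grid (consistent with paint_absolute_pos=False, but an assumption here); the helper name `displacements_to_positions` is hypothetical.

```python
import numpy as np

def displacements_to_positions(dx):
    """Turn (Nx, Ny, Nz, 3) displacements into (NPart, 3) periodic positions.

    Assumes displacements are in mesh-cell units, measured from the grid
    point at which each particle started.
    """
    nx, ny, nz, _ = dx.shape
    # Grid coordinates of every particle's starting cell, shape (Nx, Ny, Nz, 3).
    grid = np.stack(
        np.meshgrid(np.arange(nx), np.arange(ny), np.arange(nz), indexing="ij"),
        axis=-1,
    )
    # Apply the displacement and wrap periodically into the box.
    positions = (grid + dx) % np.array([nx, ny, nz])
    return positions.reshape(-1, 3)

# Tiny example: shift every particle by half a cell in the negative x direction.
dx = np.zeros((2, 3, 4, 3))
dx[..., 0] = -0.5
positions = displacements_to_positions(dx)
print(positions.shape)
```

This should only be done on gathered data or on small problems, since the flattened array is no longer distributed.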

In [7]:
initial_conditions_g = all_gather(initial_conditions)
lpt_displacements_g = all_gather(lpt_displacements)
ode_solutions_g = [all_gather(p) for p in ode_solutions]
In [8]:
from jaxpm.plotting import plot_fields_single_projection

# Plot the gathered arrays computed in the previous cell
fields = {"Initial Conditions": initial_conditions_g, "LPT Field": jnp.log(cic_paint_dx(lpt_displacements_g) + 1)}
for i, field in enumerate(ode_solutions_g):
    fields[f"field_{i}"] = jnp.log10(cic_paint_dx(field) + 1)
plot_fields_single_projection(fields, project_axis=1)
[Figure: projected density maps with panels Initial Conditions, LPT Field, field_0, field_1]

Halo Exchange

Let's start by running a simulation without halo exchange. Here, we set halo_size = 0, which means neighboring device subdomains share no overlapping region. This configuration shows the limitations of simulations without halo regions, especially when computing forces near subdomain boundaries in multi-GPU setups.

In [ ]:
mesh_shape = 128
box_size = 256.
halo_size = 0
snapshots = (0.5, 1.0)

initial_conditions , lpt_displacements_bc , ode_solutions_bc , solver_stats = run_simulation(0.25, 0.8 , mesh_shape, box_size , halo_size , snapshots)

initial_conditions_g = all_gather(initial_conditions)
lpt_displacements_g_bc = all_gather(lpt_displacements_bc)
ode_solutions_g_bc = [all_gather(p) for p in ode_solutions_bc]

from jaxpm.plotting import plot_fields_single_projection

fields = {"Initial Conditions" : initial_conditions_g , "LPT Field" : jnp.log(cic_paint_dx(lpt_displacements_g_bc) + 1)}
for i , field in enumerate(ode_solutions_g_bc):
    fields[f"field_{i}"] = jnp.log10(cic_paint_dx(field) + 1)
plot_fields_single_projection(fields,project_axis=0)
/home/wassim/micromamba/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:122: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
[Figure: projected density maps for halo_size = 0 with panels Initial Conditions, LPT Field, field_0, field_1]

We can clearly observe artifacts in the visualization, most notably horizontal and vertical discontinuities in the evolved density fields (field_0, field_1). These are a direct consequence of running without halo exchange (halo_size = 0): contributions that cross a subdomain boundary are not passed to the neighboring device, so forces near the boundaries are computed incorrectly.

🔍 These artifacts highlight where the simulation fails to maintain physical continuity between neighboring partitions, especially as structures evolve and particles interact across device boundaries.
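The effect can be reproduced in a toy setting. The NumPy sketch below is not JaxPM's implementation; it is a 1D cloud-in-cell deposit split over two subdomains, with hypothetical helper names and a made-up sinusoidal displacement field. Each subdomain paints its own particles onto a local array padded by `halo` cells, and the padded arrays are then summed back into the periodic global mesh, which plays the role of the halo exchange.

```python
import numpy as np

def paint_reference(dx):
    """Periodic 1D cloud-in-cell deposit on a single domain."""
    n = dx.size
    pos = np.arange(n) + dx
    left = np.floor(pos).astype(int)
    frac = pos - left
    out = np.zeros(n)
    np.add.at(out, left % n, 1.0 - frac)
    np.add.at(out, (left + 1) % n, frac)
    return out

def paint_distributed(dx, n_domains, halo):
    """Paint each subdomain on a padded local array, then exchange halos."""
    n = dx.size
    local = n // n_domains
    size = local + 2 * halo
    out = np.zeros(n)
    for d in range(n_domains):
        # Local positions of the particles owned by subdomain d.
        pos = np.arange(local) + dx[d * local:(d + 1) * local] + halo
        left = np.floor(pos).astype(int)
        frac = pos - left
        slab = np.zeros(size)
        for cell, weight in ((left, 1.0 - frac), (left + 1, frac)):
            inside = (cell >= 0) & (cell < size)  # mass outside the slab is lost
            np.add.at(slab, cell[inside], weight[inside])
        # Halo exchange: add the padded slab into the periodic global mesh.
        idx = (d * local - halo + np.arange(size)) % n
        np.add.at(out, idx, slab)
    return out

n = 16
dx = 2.5 * np.sin(2 * np.pi * np.arange(n) / n + 0.5)  # displacements in cells
reference = paint_reference(dx)
no_halo = paint_distributed(dx, n_domains=2, halo=0)
with_halo = paint_distributed(dx, n_domains=2, halo=4)
print(reference.sum(), no_halo.sum(), with_halo.sum())
```

With `halo=0`, mass deposited beyond a subdomain edge is dropped, so the total falls short of the reference. With a halo wider than the largest displacement, the distributed result matches the single-domain deposit.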

Returning Density Fields for Forward Modeling

In typical forward models, the goal is to produce density fields, not raw particle displacements. This function (run_simulation_with_fields) returns pre-painted density fields directly from each device, keeping data distributed and ready for downstream use (e.g. likelihood evaluation, comparison with observations).

In [11]:
@partial(jax.jit, static_argnums=(2, 3, 4, 5))
def run_simulation_with_fields(omega_c, sigma8, mesh_shape, box_size, halo_size, snapshots):
    mesh_shape = (mesh_shape,) * 3
    box_size = (box_size,) * 3
    # Create a small function to generate the matter power spectrum
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(
        jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
    pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)

    # Create initial conditions
    initial_conditions = linear_field(mesh_shape,
                                      box_size,
                                      pk_fn,
                                      seed=jax.random.PRNGKey(0),
                                      sharding=sharding)


    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)

    # Initial displacement
    dx, p, f = lpt(cosmo,
                   initial_conditions,
                   a=0.1,
                   order=2,
                   halo_size=halo_size,
                   sharding=sharding)

    # Evolve the simulation forward
    ode_fn = ODETerm(
        make_diffrax_ode(mesh_shape, paint_absolute_pos=False , sharding=sharding , halo_size=halo_size))
    solver = Tsit5()

    stepsize_controller = PIDController(rtol=1e-3 , atol=1e-3)
    res = diffeqsolve(ode_fn,
                      solver,
                      t0=0.1,
                      t1=1.,
                      dt0=0.01,
                      y0=jnp.stack([dx, p], axis=0),
                      args=cosmo,
                      saveat=SaveAt(ts=snapshots),
                      stepsize_controller=stepsize_controller)
    ode_fields = [cic_paint_dx(sol[0], halo_size=halo_size, sharding=sharding) for sol in res.ys]
    lpt_field = cic_paint_dx(dx , halo_size=halo_size, sharding=sharding)
    return initial_conditions, lpt_field, ode_fields, res.stats

Choosing the Right Halo Size

In some cases, the halo size can be too small, leading to visible artifacts in the snapshots. Here, we see that boundaries are handled well in the first and second snapshots, but the lines become more pronounced with each successive step. This indicates that a larger halo size may be needed to fully capture interactions across device boundaries over time.

In [12]:
from jaxpm.plotting import plot_fields_single_projection

mesh_shape = 128
box_size = 128.
halo_size = 4
snapshots = (0.3 ,0.4, 0.5 , 0.6, 0.8, 1.0)

initial_conditions , lpt_field , ode_fields , solver_stats = run_simulation_with_fields(0.25, 0.8 , mesh_shape, box_size , halo_size , snapshots)

initial_conditions_g = all_gather(initial_conditions)
lpt_field_g = all_gather(lpt_field)
ode_fields_g = [all_gather(p) for p in ode_fields]

# Plot the gathered fields
fields = {"Initial Conditions": initial_conditions_g, "LPT Field": jnp.log(lpt_field_g + 1)}
for i, field in enumerate(ode_fields_g):
    fields[f"field_{i}"] = jnp.log10(field + 1)
plot_fields_single_projection(fields, project_axis=0)
[Figure: projected density maps for box_size = 128, halo_size = 4 with panels Initial Conditions, LPT Field, field_0 to field_5]

Conversely, when the box is large relative to the mesh, each cell spans a greater physical distance, so particle displacements measured in cells are smaller. An undersized halo then produces weaker boundary artifacts.

Explanation

  • Large Box Sizes: In larger simulation boxes, particle displacements are small relative to the cell size. Fewer particles cross into neighboring subdomains, so boundary artifacts are less pronounced even with a small halo.

  • Smaller Box Sizes: In smaller boxes, particles travel a greater distance relative to the cell size, so more of them cross subdomain boundaries. The halo size must then be chosen carefully to capture these crossings and avoid visible artifacts.

In the following run (mesh_shape = 256, box_size = 1000, halo_size = 4), the same small halo does not lead to severe artifacts, because displacements are small compared with the cell size.
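The argument can be made concrete with a short calculation. The sketch below converts a physical displacement into a number of mesh cells for the two configurations used in this notebook. The displacement of 10 box units is a hypothetical value chosen only for illustration, and `cells_crossed` is not a JaxPM function.

```python
def cells_crossed(displacement, box_size, mesh_shape):
    """Express a physical displacement in units of mesh cells."""
    cell_size = box_size / mesh_shape
    return displacement / cell_size

displacement = 10.0  # hypothetical displacement, in the same units as box_size

small_box = cells_crossed(displacement, box_size=128.0, mesh_shape=128)
large_box = cells_crossed(displacement, box_size=1000.0, mesh_shape=256)

print(f"box 128,  mesh 128: {small_box:.2f} cells")
print(f"box 1000, mesh 256: {large_box:.2f} cells")
```

Under this assumption the same displacement spans 10 cells in the small box but fewer than 3 in the large one, which is why halo_size = 4 is less of a limitation there.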

In [13]:
mesh_shape = 256
box_size = 1000.
halo_size = 4
snapshots = (0.3 ,0.4, 0.5 , 0.6, 0.8, 1.0)

initial_conditions , lpt_field , ode_fields , solver_stats = run_simulation_with_fields(0.25, 0.8 , mesh_shape, box_size , halo_size , snapshots)

initial_conditions_g = all_gather(initial_conditions)
lpt_field_g = all_gather(lpt_field)
ode_fields_g = [all_gather(p) for p in ode_fields]

# Plot the gathered fields
fields = {"Initial Conditions": initial_conditions_g, "LPT Field": jnp.log(lpt_field_g + 1)}
for i, field in enumerate(ode_fields_g):
    fields[f"field_{i}"] = jnp.log10(field + 1)
plot_fields_single_projection(fields, project_axis=0)
/home/wassim/micromamba/envs/jax/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:1132: UserWarning: A large amount of constants were captured during lowering (2.42GB total). If this is intentional, disable this warning by setting JAX_CAPTURED_CONSTANTS_WARN_BYTES=-1. To obtain a report of where these constants were encountered, set JAX_CAPTURED_CONSTANTS_REPORT_FRAMES=-1.
  warnings.warn(message)
[Figure: projected density maps for box_size = 1000, halo_size = 4 with panels Initial Conditions, LPT Field, field_0 to field_5]

General Guideline

Start with a halo size that is one-eighth of the box size, then gradually reduce it until lines begin to appear in the visualization, which indicates that the halo has become too small.
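A minimal sketch of this search is shown here. It makes two assumptions that the guideline does not state: the one-eighth rule is applied to the mesh size in cells (since halo_size is passed alongside mesh_shape), and "gradually reduce" is implemented as repeated halving. The helper `candidate_halo_sizes` is hypothetical.

```python
def candidate_halo_sizes(mesh_shape, fraction=8):
    """List halo sizes to try, from mesh_shape // fraction down to 1.

    Assumption: the halo is measured in mesh cells and is halved at each step.
    """
    halo = mesh_shape // fraction
    sizes = []
    while halo >= 1:
        sizes.append(halo)
        halo //= 2
    return sizes

for halo_size in candidate_halo_sizes(128):
    # In practice: rerun the simulation with this halo_size, plot the
    # fields, and stop at the first value that shows boundary lines.
    print(halo_size)
```

Each candidate would be passed as halo_size to the simulation function, and the fields inspected visually as in the cells above.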

Applying Weights in a Distributed Setup

We can apply weights just like before. In general, we want to apply weights on a distributed particle grid.

Note: When using weights in a distributed setting, ensure that the weights have the same sharding as the particle grid. If the sharding is not identical, JAX may perform an all-gather or other collective operations that could significantly impact performance.
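One way to check and enforce this is sketched here, using `jax.device_put` and the `is_equivalent_to` method of JAX shardings. The 16³ grid and the simulated 8-device CPU setup are assumptions made to keep the example small; the variable names are illustrative.

```python
import os
# Simulate 8 CPU devices; this must run before JAX is imported.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:8]).reshape(2, 4), axis_names=("x", "y"))
grid_sharding = NamedSharding(mesh, P("x", "y"))

# Stand-in for a distributed particle grid of shape (Nx, Ny, Nz, 3).
field = jax.device_put(jnp.zeros((16, 16, 16, 3)), grid_sharding)

# Freshly created weights start out on a single device ...
weights = jnp.ones(field.shape[:-1])
# ... so place them explicitly with the same sharding as the particle grid.
weights = jax.device_put(weights, grid_sharding)

same_layout = weights.sharding.is_equivalent_to(grid_sharding, weights.ndim)
print(same_layout)
```

Placing the weights up front keeps the layout decision explicit instead of leaving the resharding to the compiler.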

In [9]:
from jaxpm.plotting import plot_fields_single_projection

field = ode_solutions[0]

center = slice(field.shape[0] // 4, 3 * field.shape[0] // 4 )
center3d = (slice(None), center, center)  # All of X, central half of Y and Z
weights = jnp.ones_like(field[..., 0])
# Triple the weights inside the central region
weights = weights.at[center3d].multiply(3)
visualize_array_sharding(weights[:,:,0])

weighted = cic_paint_dx(field, weight=weights)
unweighted = cic_paint_dx(field, weight=1.0)

plot_fields_single_projection({"Weighted" : weighted , "Unweighted" : unweighted} , project_axis=0)
                                    
                                    
  CPU 0    CPU 1    CPU 2    CPU 3  
  CPU 4    CPU 5    CPU 6    CPU 7  
                                    
                                    
                                    
[Figure: projected density maps with panels Weighted, Unweighted]