
Multi-GPU Particle Mesh Simulation with Halo Exchange


In [ ]:
!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git
!pip install diffrax

Note: This notebook requires 8 devices (GPU or TPU).
If you're running on CPU or don't have access to 8 devices,
you can simulate multiple devices by adding the following code at the start BEFORE IMPORTING JAX:

import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

Recommended only for debugging. If used, you will probably need to lower the mesh resolution.

In [2]:
import os
os.environ["EQX_ON_ERROR"] = "nan"
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.debug import visualize_array_sharding

from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
from functools import partial
from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve, Tsit5, PIDController
In [3]:
assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices"

Setting Up Device Mesh and Sharding for Multi-GPU Simulation

This cell configures a 2x4 device mesh across 8 devices and sets up named sharding to distribute data efficiently.

  • Device Mesh: pdims = (2, 4) arranges devices in a 2x4 grid.
  • Sharding with Mesh: Mesh(devices, axis_names=('x', 'y')) assigns the mesh grid axes, which allows flexible mapping of array data across devices.
  • PartitionSpec and NamedSharding: PartitionSpec defines data partitioning across mesh axes ('x', 'y'), and NamedSharding(mesh, P('x', 'y')) specifies this sharding scheme for arrays in the simulation.

More information on sharding in general can be found in the JAX documentation on Distributed arrays and automatic parallelization.

In [4]:
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import PartitionSpec as P, NamedSharding

all_gather = partial(process_allgather, tiled=True)

pdims = (2, 4)
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
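
As a quick sanity check, you can place a toy array with this sharding and see how it is tiled across the 2x4 mesh. This is a minimal sketch (not needed for the rest of the notebook) that only uses jax.device_put and the visualize_array_sharding helper imported above:

# Sketch: distribute a toy 2D array with the NamedSharding defined above
# and inspect which device holds each tile.
toy = jax.device_put(jnp.zeros((128, 128)), sharding)
visualize_array_sharding(toy)  # expect a 2x4 grid of device tiles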

Multi-GPU Particle Mesh Simulation with Sharding

This function is very similar to the single-GPU implementation, with the key difference that linear_field, lpt, and make_diffrax_ode now take a sharding argument. This allows each stage of the simulation (initial conditions, displacements, and ODE evolution) to be distributed across the configured 2x4 device mesh, ensuring efficient parallel execution.

In [13]:
mesh_shape = 128
box_size = 256.
halo_size = 64
snapshots = (0.5, 1.0)

@partial(jax.jit, static_argnums=(2, 3, 4, 5))
def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size, snapshots):
    mesh_shape = (mesh_shape,) * 3
    box_size = (box_size,) * 3
    # Create a small function to generate the matter power spectrum
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(
        jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
    pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)

    # Create initial conditions
    initial_conditions = linear_field(mesh_shape,
                                      box_size,
                                      pk_fn,
                                      seed=jax.random.PRNGKey(0),
                                      sharding=sharding)


    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)

    # Initial displacement
    dx, p, f = lpt(cosmo,
                   initial_conditions,
                   a=0.1,
                   order=2,
                   halo_size=halo_size,
                   sharding=sharding)

    # Evolve the simulation forward
    ode_fn = ODETerm(
        make_diffrax_ode(mesh_shape, paint_absolute_pos=False, sharding=sharding, halo_size=halo_size))
    solver = Tsit5()

    stepsize_controller = PIDController(rtol=1e-3, atol=1e-3)
    res = diffeqsolve(ode_fn,
                      solver,
                      t0=0.1,
                      t1=1.,
                      dt0=0.01,
                      y0=jnp.stack([dx, p], axis=0),
                      args=cosmo,
                      saveat=SaveAt(ts=snapshots),
                      stepsize_controller=stepsize_controller)
    ode_solutions = [sol[0] for sol in res.ys]
    return initial_conditions, dx, ode_solutions, res.stats

initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(0.25, 0.8, mesh_shape, box_size, halo_size, snapshots)
ode_solutions[-1].block_until_ready()
%time initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(0.25, 0.8, mesh_shape, box_size, halo_size, snapshots); ode_solutions[-1].block_until_ready()
print(f"Solver Stats : {solver_stats}")
CPU times: user 6min 3s, sys: 3.69 s, total: 6min 7s
Wall time: 24.4 s
Solver Stats : {'max_steps': Array(4096, dtype=int32, weak_type=True), 'num_accepted_steps': Array(8, dtype=int32, weak_type=True), 'num_rejected_steps': Array(1, dtype=int32, weak_type=True), 'num_steps': Array(9, dtype=int32, weak_type=True)}

All fields and particle grids remain distributed at all times, as the sharding visualization below shows; JaxPM ensures they are never gathered onto a single device. In a forward-model scenario, it is the user's responsibility to keep the data distributed in order to avoid memory bottlenecks.
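
If you post-process these outputs inside your own jitted forward model, one way to keep derived quantities distributed is jax.lax.with_sharding_constraint. The following is a minimal sketch under that assumption; the painting call mirrors the one used later in this notebook:

# Sketch: keep a derived density field on the same 2x4 mesh instead of
# letting it end up gathered on a single device.
@jax.jit
def forward(displacements):
    field = cic_paint_dx(displacements, halo_size=halo_size, sharding=sharding)
    return jax.lax.with_sharding_constraint(jnp.log10(field + 1), sharding)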

In [30]:
visualize_array_sharding(ode_solutions[-1][:,:,0,0])
[Sharding visualization: a 2x4 grid of tiles, CPU 0-3 across the top row and CPU 4-7 across the bottom row]

⚠️ Warning: One caveat is that particle arrays usually have a shape of (NPart, 3),
where NPart = Nx * Ny * Nz. However, a flat (NPart, 3) array cannot be sharded in this distributed setup,
so particle arrays always keep the shape (Nx, Ny, Nz, 3) to ensure they remain distributed across devices.

In [31]:
initial_conditions_g = all_gather(initial_conditions)
lpt_displacements_g = all_gather(lpt_displacements)
ode_solutions_g = [all_gather(p) for p in ode_solutions]
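
Following the shape convention from the warning above, the gathered displacement grids can be flattened back into a conventional (NPart, 3) array for single-process post-processing. A minimal sketch:

# Sketch: flatten the gathered (Nx, Ny, Nz, 3) grid into a flat (NPart, 3) array.
particles = ode_solutions_g[-1].reshape(-1, 3)
print(particles.shape)  # (Nx * Ny * Nz, 3)
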
In [34]:
from jaxpm.plotting import plot_fields_single_projection

fields = {"Initial Conditions" : initial_conditions , "LPT Field" : jnp.log(cic_paint_dx(lpt_displacements) + 1)}
for i , field in enumerate(ode_solutions):
    fields[f"field_{i}"] = jnp.log10(cic_paint_dx(field) + 1)
plot_fields_single_projection(fields,project_axis=1)
[Figure: single projection of Initial Conditions, LPT Field, field_0, and field_1]

Halo Exchange

Let's start by running a simulation without halo exchange. Here, we set halo_size = 0, which means no overlapping regions between device boundaries. This configuration helps us observe the limitations of simulations without halo regions, especially for calculating forces near boundaries in multi-GPU setups.

In [37]:
mesh_shape = 128
box_size = 256.
halo_size = 0
snapshots = (0.5, 1.0)

initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(0.25, 0.8, mesh_shape, box_size, halo_size, snapshots)

initial_conditions_g = all_gather(initial_conditions)
lpt_displacements_g = all_gather(lpt_displacements)
ode_solutions_g = [all_gather(p) for p in ode_solutions]

from jaxpm.plotting import plot_fields_single_projection

fields = {"Initial Conditions" : initial_conditions , "LPT Field" : jnp.log(cic_paint_dx(lpt_displacements) + 1)}
for i , field in enumerate(ode_solutions):
    fields[f"field_{i}"] = jnp.log10(cic_paint_dx(field) + 1)
plot_fields_single_projection(fields,project_axis=0)
[Figure: projected fields for the halo_size = 0 run; seams are visible between subdomains in the evolved fields]

We can clearly observe artifacts in the visualization, most notably horizontal and vertical discontinuities in the evolved density fields (field_0, field_1). These are a direct consequence of running without halo exchange (halo_size = 0), which makes force computation inaccurate across subdomain boundaries in a multi-device simulation.

🔍 These artifacts highlight where the simulation fails to maintain physical continuity between neighboring partitions, especially as structures evolve and particles interact across device boundaries.

In [5]:
@partial(jax.jit, static_argnums=(2, 3, 4, 5))
def run_simulation_with_fields(omega_c, sigma8, mesh_shape, box_size, halo_size, snapshots):
    mesh_shape = (mesh_shape,) * 3
    box_size = (box_size,) * 3
    # Create a small function to generate the matter power spectrum
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(
        jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
    pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)

    # Create initial conditions
    initial_conditions = linear_field(mesh_shape,
                                      box_size,
                                      pk_fn,
                                      seed=jax.random.PRNGKey(0),
                                      sharding=sharding)


    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)

    # Initial displacement
    dx, p, f = lpt(cosmo,
                   initial_conditions,
                   a=0.1,
                   order=2,
                   halo_size=halo_size,
                   sharding=sharding)

    # Evolve the simulation forward
    ode_fn = ODETerm(
        make_diffrax_ode(mesh_shape, paint_absolute_pos=False, sharding=sharding, halo_size=halo_size))
    solver = Tsit5()

    stepsize_controller = PIDController(rtol=1e-3, atol=1e-3)
    res = diffeqsolve(ode_fn,
                      solver,
                      t0=0.1,
                      t1=1.,
                      dt0=0.01,
                      y0=jnp.stack([dx, p], axis=0),
                      args=cosmo,
                      saveat=SaveAt(ts=snapshots),
                      stepsize_controller=stepsize_controller)
    ode_fields = [cic_paint_dx(sol[0], halo_size=halo_size, sharding=sharding) for sol in res.ys]
    lpt_field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding)
    return initial_conditions, lpt_field, ode_fields, res.stats

Looking at the projections above, there are clearly visible lines between the subdomains of the simulation. These lines are the artifacts that arise when running without halo exchange, since boundary conditions are not handled accurately across device edges.

Choosing the Right Halo Size

In some cases, the halo size can be too small, leading to visible artifacts in the snapshots. Here, we see that boundaries are handled well in the first and second snapshots, but the lines become more pronounced with each successive step. This indicates that a larger halo size may be needed to fully capture interactions across device boundaries over time.

In [9]:
from jaxpm.plotting import plot_fields_single_projection

mesh_shape = 128
box_size = 128.
halo_size = 4
snapshots = (0.3, 0.4, 0.5, 0.6, 0.8, 1.0)

initial_conditions, lpt_field, ode_fields, solver_stats = run_simulation_with_fields(0.25, 0.8, mesh_shape, box_size, halo_size, snapshots)

initial_conditions_g = all_gather(initial_conditions)
lpt_field_g = all_gather(lpt_field)
ode_fields_g = [all_gather(p) for p in ode_fields]

fields = {"Initial Conditions" : initial_conditions , "LPT Field" : jnp.log(lpt_field + 1)}
for i , field in enumerate(ode_fields):
    fields[f"field_{i}"] = jnp.log10(field + 1)
plot_fields_single_projection(fields,project_axis=0)
[Figure: projected fields for halo_size = 4, box_size = 128; boundary lines grow more pronounced in later snapshots]

In other cases, when the box size is large relative to the mesh, each cell spans a larger physical distance, so particle displacements cover fewer cells. This reduces the impact of an insufficient halo size on boundary artifacts.

Explanation

  • Large Box Sizes: In larger simulation boxes, particles tend to have smaller relative displacements (or slower speeds). This reduces the frequency of interactions with particles in neighboring subdomains, making boundary artifacts less pronounced, even if the halo size is smaller.

  • Smaller Box Sizes: In smaller boxes, particles cover a greater relative distance, leading to more frequent interactions with boundary particles. Here, the halo size must be carefully chosen to capture these interactions accurately, reducing visible artifacts in the visualization.
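
A rough way to quantify this is to express the expected displacement in units of mesh cells and compare it with halo_size. The sketch below assumes typical late-time displacements of order 10 Mpc/h, which is only an order-of-magnitude guess:

# Sketch: rough displacement expressed in mesh cells
# (assumes ~10 Mpc/h typical displacement at a ~ 1, an order-of-magnitude guess).
typical_displacement = 10.0  # Mpc/h
for mesh, box in [(128, 128.), (256, 1000.)]:
    cells = typical_displacement * mesh / box
    print(f"mesh={mesh}, box={box}: ~{cells:.1f} cells vs halo_size=4")
# ~10 cells for the 128 Mpc/h box (artifacts expected with halo_size=4),
# ~2.6 cells for the 1000 Mpc/h box (halo_size=4 is sufficient).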

In the following run, with a much larger box, the same small halo size does not lead to severe artifacts, as particles interact far less across subdomain boundaries.

In [10]:
mesh_shape = 256
box_size = 1000.
halo_size = 4
snapshots = (0.3, 0.4, 0.5, 0.6, 0.8, 1.0)

initial_conditions, lpt_field, ode_fields, solver_stats = run_simulation_with_fields(0.25, 0.8, mesh_shape, box_size, halo_size, snapshots)

initial_conditions_g = all_gather(initial_conditions)
lpt_field_g = all_gather(lpt_field)
ode_fields_g = [all_gather(p) for p in ode_fields]

fields = {"Initial Conditions" : initial_conditions , "LPT Field" : jnp.log(lpt_field + 1)}
for i , field in enumerate(ode_fields):
    fields[f"field_{i}"] = jnp.log10(field + 1)
plot_fields_single_projection(fields,project_axis=0)
/home/wassim/micromamba/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:122: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
/home/wassim/micromamba/envs/jax/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:1132: UserWarning: A large amount of constants were captured during lowering (2.42GB total). If this is intentional, disable this warning by setting JAX_CAPTURED_CONSTANTS_WARN_BYTES=-1. To obtain a report of where these constants were encountered, set JAX_CAPTURED_CONSTANTS_REPORT_FRAMES=-1.
  warnings.warn(message)
[Figure: projected fields for halo_size = 4, box_size = 1000; no visible boundary artifacts]

General Guideline

Start with a halo size that is one-eighth of the box size. Gradually reduce it until you begin to notice lines in the visualization, indicating an insufficient halo size.
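
In practice this can be done as a small sweep over candidate halo sizes, reusing the run_simulation_with_fields function defined above. A minimal sketch (the candidate sizes below are just an example):

# Sketch: decrease the halo size until seams appear in the final snapshot.
for halo_size in (32, 16, 8, 4):
    _, _, ode_fields, _ = run_simulation_with_fields(0.25, 0.8, 128, 256., halo_size, (0.5, 1.0))
    last = all_gather(ode_fields[-1])
    plot_fields_single_projection({f"halo={halo_size}": jnp.log10(last + 1)}, project_axis=0)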

Applying Weights in a Distributed Setup

We can apply weights just like before. In general, we want to apply weights on a distributed particle grid.

Note: When using weights in a distributed setting, ensure that the weights have the same sharding as the particle grid. If the sharding is not identical, JAX may perform an all-gather or other collective operations that could significantly impact performance.

In [14]:
from jaxpm.plotting import plot_fields_single_projection

field = ode_solutions[0]

center = slice(field.shape[0] // 4, 3 * field.shape[0] // 4)
center3d = (slice(None), center, center)  # All of X, central half of Y and Z
weights = jnp.ones_like(field[..., 0])
# Boost the weights in the central region by a factor of 3
weights = weights.at[center3d].multiply(3)
visualize_array_sharding(weights[:,:,0])

weighted = cic_paint_dx(field, weight=weights)
unweighted = cic_paint_dx(field, weight=1.0)

plot_fields_single_projection({"Weighted": weighted, "Unweighted": unweighted}, project_axis=0)
[Sharding visualization of the weights: a 2x4 grid of tiles, CPU 0-3 across the top row and CPU 4-7 across the bottom row]
[Figure: weighted vs. unweighted projected density fields]
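
If the weights come from somewhere else and do not already inherit the field's sharding, one way to match it explicitly is jax.device_put with the same NamedSharding before painting. A minimal sketch, assuming externally computed weights:

# Sketch: place externally built weights with the same sharding as the particle
# grid so the painting step does not trigger a gather or resharding collective.
external_weights = jnp.ones(field.shape[:3])                   # hypothetical external weights
external_weights = jax.device_put(external_weights, sharding)  # match the field's sharding
painted = cic_paint_dx(field, weight=external_weights)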