2024-12-21 13:14:42 -05:00

347 KiB
Raw Blame History

Single-GPU Particle Mesh Simulation with JAXPM

In this notebook, we'll run a simple Particle Mesh (PM) simulation on a single GPU using JAXPM. This example provides a hands-on introduction to simulating the evolution of matter in the universe through basic PM techniques, allowing you to explore how cosmological structures form over time.

Open In Colab

In [3]:
jax.config.update("jax_enable_x64", True)
In [1]:
import jax
import jax.numpy as jnp
import jax_cosmo as jc

from jax.experimental.ode import odeint

from jaxpm.painting import cic_paint , cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn
from jaxpm.distributed import uniform_particles

Particle Mesh Simulation Setup

In this example, we initialize particles with uniform positions across the grid. This setup implicitly means that the Cloud-in-Cell (CIC) painting scheme will map absolute particle positions onto the grid.

In [ ]:
mesh_shape = [128, 128, 128]
box_size = [128., 128., 128.]
snapshots = jnp.array([0.1, 0.5, 1.0])

def run_simulation(omega_c, sigma8):
    # Create a small function to generate the matter power spectrum
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
    pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)

    # Create initial conditions
    initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))

    particles = uniform_particles(mesh_shape)
    # Create particles
    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
    # Initial displacement
    dx, p, f = lpt(cosmo, initial_conditions, particles, a=0.1)
    # Evolve the simulation forward
    res = odeint(make_ode_fn(mesh_shape), [particles + dx, p], snapshots, cosmo, rtol=1e-8, atol=1e-8)
    # Return the simulation volume at requested 

    return initial_conditions ,  particles + dx , res[0]
In [15]:
initial_conditions , lpt_particles , ode_particles = run_simulation(0.25, 0.8)
/home/wassim/micromamba/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:118: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
/home/wassim/micromamba/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:118: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
/home/wassim/micromamba/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:118: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
In [16]:
from jaxpm.plotting import plot_fields_single_projection

fields = {"Initial Conditions" : initial_conditions , "LPT Field" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}
for i , field in enumerate(ode_particles[1:]):
    fields[f"field_{i}"] = cic_paint(jnp.zeros(mesh_shape) , field)
No description has been provided for this image

Particle Mesh Simulation Setup - Relative Position Painting

In the second example, we leave the initial particle positions as None, which applies a relative position (displacement-only) CIC painting scheme. This approach assumes uniform particle positions by default, saving memory and improving efficiency, though with the trade-off of assuming uniformity.

In [ ]:
mesh_shape = [128, 128, 128]
box_size = [128., 128., 128.]
snapshots = jnp.array([0.1, 0.5, 1.0])

def run_simulation(omega_c, sigma8):
    # Create a small function to generate the matter power spectrum
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
    pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)

    # Create initial conditions
    initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))

    # Create particles
    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
    # Initial displacement
    dx, p, f = lpt(cosmo, initial_conditions, a=0.1)
    # Evolve the simulation forward
    res = odeint(make_ode_fn(mesh_shape,paint_absolute_pos=False), [dx, p], snapshots, cosmo, rtol=1e-8, atol=1e-8)
    # Return the simulation volume at requested 

    return initial_conditions ,  dx , res[0]
initial_conditions , lpt_displacements , ode_displacements = run_simulation(0.25, 0.8)
/home/wassim/micromamba/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:118: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
/home/wassim/micromamba/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:118: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
/home/wassim/micromamba/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:118: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
In [19]:
from jaxpm.plotting import plot_fields_single_projection

fields = {"Initial Conditions" : initial_conditions , "LPT Field" : cic_paint_dx(lpt_displacements)}
for i , field in enumerate(ode_displacements[1:]):
    fields[f"field_{i}"] = cic_paint_dx(field)
No description has been provided for this image

We note that painting only displacements is slower than painting absolute particle positions.

This slower performance occurs because painting displacements requires processing in smaller chunks or batches. Instead of creating a large array of particle positions with neighbors (e.g., (NParticles, 3, 8)), which consumes a significant amount of memory, we paint the particles in manageable batches.

This trade-off allows for greater memory efficiency but comes at the expense of speed. By reducing the memory footprint, we avoid memory limitations, especially useful in large-scale or distributed PM simulations, even if it means slightly slower painting.

Well see in later notebooks that retaining only the displacement is essential for distributed Particle Mesh (PM) simulations, where memory efficiency and computational speed are key.