Compare commits


No commits in common. "main" and "v0.1.0" have entirely different histories.
main ... v0.1.0

29 changed files with 118 additions and 2911 deletions


@@ -7,37 +7,15 @@ on:
branches: [ "main" ]
jobs:
formatting:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout Source
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Cache pip dependencies
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-formatting-pip-${{ hashFiles('.pre-commit-config.yaml') }}
restore-keys: |
${{ runner.os }}-formatting-pip-
- name: Cache pre-commit
uses: actions/cache@v4
with:
path: ~/.cache/pre-commit
key: ${{ runner.os }}-pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
restore-keys: |
${{ runner.os }}-pre-commit-
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pre-commit isort
python -m pip install --upgrade pip isort
python -m pip install pre-commit
- name: Run pre-commit
run: python -m pre_commit run --all-files


@@ -1,55 +0,0 @@
name: Upload Python Package
on:
release:
types: [published]
permissions:
contents: read
jobs:
release-build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.x"
- name: Build release distributions
run: |
# NOTE: put your own distribution build steps here.
python -m pip install build
python -m build
- name: Upload distributions
uses: actions/upload-artifact@v4
with:
name: release-dists
path: dist/
pypi-publish:
runs-on: ubuntu-latest
needs:
- release-build
permissions:
# IMPORTANT: this permission is mandatory for trusted publishing
id-token: write
environment:
name: pypi
url: https://pypi.org/p/jaxpm
steps:
- name: Retrieve release distributions
uses: actions/download-artifact@v4
with:
name: release-dists
path: dist/
- name: Publish release distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
packages-dir: dist/


@@ -10,63 +10,36 @@ on:
jobs:
run_tests:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]
python-version: ["3.10" , "3.11" , "3.12"]
steps:
- name: Checkout Source
uses: actions/checkout@v4
uses: actions/checkout@v2.3.1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Cache pip dependencies
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements-test.txt', '**/pyproject.toml') }}
restore-keys: |
${{ runner.os }}-pip-${{ matrix.python-version }}-
${{ runner.os }}-pip-
- name: Cache system dependencies
uses: actions/cache@v4
with:
path: /var/cache/apt
key: ${{ runner.os }}-apt-${{ hashFiles('.github/workflows/tests.yml') }}
restore-keys: |
${{ runner.os }}-apt-
- name: Install system dependencies
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install -y libopenmpi-dev
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip setuptools wheel
# Install JAX first as it's a key dependency
pip install jax
# Install build dependencies
pip install setuptools cython mpi4py
# Install test requirements with no-build-isolation for faster builds
pip install -r requirements-test.txt --no-build-isolation
# Install additional test dependencies
pip install pytest diffrax
# Install package in development mode
pip install -e .
echo "numpy version installed:"
python -c "import numpy; print(numpy.__version__)"
python -m pip install --upgrade pip
pip install jax==0.4.35
pip install numpy setuptools cython wheel
pip install git+https://github.com/MP-Gadget/pfft-python
pip install git+https://github.com/MP-Gadget/pmesh
pip install git+https://github.com/ASKabalan/fastpm-python --no-build-isolation
pip install .[test]
- name: Run Single Device Tests
run: |
cd tests
pytest -v -m "not distributed"
- name: Run Distributed tests
run: |
pytest -v tests/test_distributed_pm.py
pytest -v -m distributed

.gitignore vendored

@@ -132,6 +132,3 @@ dmypy.json
# Pyre type checker
.pyre/
# Hide version file
_version.py


@@ -1,6 +1,6 @@
MIT License
Copyright (c) 2021-2025 Differentiable Universe Initiative
Copyright (c) 2021 Differentiable Universe Initiative
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal


@@ -1,2 +0,0 @@
prune notebooks
prune tests


@@ -1,26 +1,9 @@
# JaxPM
[![Notebook](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/01-Introduction.ipynb)
[![PyPI version](https://img.shields.io/pypi/v/jaxpm)](https://pypi.org/project/jaxpm/) [![Tests](https://github.com/DifferentiableUniverseInitiative/JaxPM/actions/workflows/tests.yml/badge.svg)](https://github.com/DifferentiableUniverseInitiative/JaxPM/actions/workflows/tests.yml) <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
[![Tests](https://github.com/DifferentiableUniverseInitiative/JaxPM/actions/workflows/tests.yml/badge.svg)](https://github.com/DifferentiableUniverseInitiative/JaxPM/actions/workflows/tests.yml) <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
[![All Contributors](https://img.shields.io/badge/all_contributors-5-orange.svg?style=flat-square)](#contributors-)
<!-- ALL-CONTRIBUTORS-BADGE:END -->
JAX-powered Cosmological Particle-Mesh N-body Solver
> ### Note
> **The new JaxPM v0.1.xx** supports multi-GPU model distribution while remaining compatible with previous releases. These significant changes are still under development and testing, so please report any issues you encounter.
> For the older but more stable version, install:
> ```bash
> pip install jaxpm==0.0.2
> ```
## Install
Basic installation can be done using pip:
```bash
pip install jaxpm
```
For more advanced installations with optimized distribution on GPU clusters, please install jaxDecomp first. See instructions [here](https://github.com/DifferentiableUniverseInitiative/jaxDecomp).
## Goals
Provide a modern infrastructure to support differentiable PM N-body simulations using JAX:

dev/job_pfft.sh Normal file

@@ -0,0 +1,14 @@
#!/bin/bash
#SBATCH -A m1727
#SBATCH -C gpu
#SBATCH -q debug
#SBATCH -t 0:05:00
#SBATCH -N 2
#SBATCH --ntasks-per-node=4
#SBATCH -c 32
#SBATCH --gpus-per-task=1
#SBATCH --gpu-bind=none
module load python cudnn/8.2.0 nccl/2.11.4 cudatoolkit
export SLURM_CPU_BIND="cores"
srun python test_pfft.py


@@ -166,11 +166,11 @@ def uniform_particles(mesh_shape, sharding=None):
axis=-1)
def normal_field(seed, shape, sharding=None, dtype=float):
def normal_field(mesh_shape, seed, sharding=None):
"""Generate a Gaussian random field with the given power spectrum."""
gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is not None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(shape, sharding)
local_mesh_shape = get_local_shape(mesh_shape, sharding)
size = jax.device_count()
# rank = jax.process_index()
@@ -190,9 +190,9 @@ def normal_field(seed, shape, sharding=None, dtype=float):
return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)
return shard_map(
partial(normal, shape=local_mesh_shape, dtype=dtype),
partial(normal, shape=local_mesh_shape, dtype='float32'),
mesh=gpu_mesh,
in_specs=P(None),
out_specs=spec)(keys) # yapf: disable
else:
return jax.random.normal(shape=shape, key=seed, dtype=dtype)
return jax.random.normal(shape=mesh_shape, key=seed)


@@ -26,7 +26,7 @@ def E(cosmo, a):
where :math:`f(a)` is the Dark Energy evolution parameter computed
by :py:meth:`.f_de`.
"""
return np.sqrt(Esqr(cosmo, a))
return np.power(Esqr(cosmo, a), 0.5)
def df_de(cosmo, a, epsilon=1e-5):


@@ -12,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk
from jaxpm.painting_utils import gather, scatter
def _cic_paint_impl(grid_mesh, positions, weight=1.):
def _cic_paint_impl(grid_mesh, positions, weight=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
displacement field: [nx, ny, nz, 3]
@@ -27,10 +27,12 @@ def _cic_paint_impl(grid_mesh, positions, weight=1.):
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
if jnp.isscalar(weight):
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
else:
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), kernel)
if weight is not None:
if jnp.isscalar(weight):
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
else:
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
kernel)
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
@@ -46,13 +48,7 @@ def _cic_paint_impl(grid_mesh, positions, weight=1.):
@partial(jax.jit, static_argnums=(3, 4))
def cic_paint(grid_mesh, positions, weight=1., halo_size=0, sharding=None):
if sharding is not None:
print("""
WARNING : absolute painting is not recommended in multi-device mode.
Please use relative painting instead.
""")
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
positions = positions.reshape((*grid_mesh.shape, 3))
@@ -61,11 +57,9 @@ def cic_paint(grid_mesh, positions, weight=1., halo_size=0, sharding=None):
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
weight_spec = P() if jnp.isscalar(weight) else spec
grid_mesh = autoshmap(_cic_paint_impl,
gpu_mesh=gpu_mesh,
in_specs=(spec, spec, weight_spec),
in_specs=(spec, spec, P()),
out_specs=spec)(grid_mesh, positions, weight)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,
@@ -134,7 +128,6 @@ def cic_paint_2d(mesh, positions, weight):
positions: [npart, 2]
weight: [npart]
"""
positions = positions.reshape([-1, 2])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
@@ -143,7 +136,7 @@ def cic_paint_2d(mesh, positions, weight):
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1]
if weight is not None:
kernel = kernel * weight.reshape(*positions.shape[:-1])
kernel = kernel * weight[..., jnp.newaxis]
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
@@ -158,16 +151,13 @@ def cic_paint_2d(mesh, positions, weight):
return mesh
def _cic_paint_dx_impl(displacements,
weight=1.,
halo_size=0,
chunk_size=2**24):
def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1]
original_shape = displacements.shape
particle_mesh = jnp.zeros(original_shape[:-1], dtype=displacements.dtype)
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
if not jnp.isscalar(weight):
if weight.shape != original_shape[:-1]:
raise ValueError("Weight shape must match particle shape")
@@ -185,7 +175,7 @@ def _cic_paint_dx_impl(displacements,
return scatter(pmid.reshape([-1, 3]),
displacements.reshape([-1, 3]),
particle_mesh,
chunk_size=chunk_size,
chunk_size=2**24,
val=weight)
@@ -200,13 +190,13 @@ def cic_paint_dx(displacements,
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
weight_spec = P() if jnp.isscalar(weight) else spec
grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
halo_size=halo_size,
weight=weight,
chunk_size=chunk_size),
gpu_mesh=gpu_mesh,
in_specs=(spec, weight_spec),
out_specs=spec)(displacements, weight)
in_specs=spec,
out_specs=spec)(displacements)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,
@@ -240,12 +230,6 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
# Halo size is halved for the read operation
# We only need to read the density field
# while in the painting operation we need to exchange and reduce the halo
# We chose to do that since it is much easier to write a custom jvp rule for exchange
# while it is a bit harder if there is a reduction involved
halo_size = jax.tree.map(lambda x: x // 2, halo_size)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,


@@ -30,14 +30,17 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
"""Multilinear enmeshing."""
base_indices = jnp.asarray(base_indices)
displacements = jnp.asarray(displacements)
cell_size = jnp.array(cell_size, dtype=displacements.dtype)
if base_shape is not None:
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
offset = offset.astype(base_indices.dtype)
if new_cell_size is not None:
new_cell_size = jnp.array(new_cell_size, dtype=displacements.dtype)
if new_shape is not None:
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
with jax.experimental.enable_x64():
cell_size = jnp.float64(
cell_size) if new_cell_size is not None else jnp.array(
cell_size, dtype=displacements.dtype)
if base_shape is not None:
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
offset = jnp.float64(offset)
if new_cell_size is not None:
new_cell_size = jnp.float64(new_cell_size)
if new_shape is not None:
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
spatial_dim = base_indices.shape[1]
neighbor_offsets = (


@@ -131,7 +131,7 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
Generate initial conditions.
"""
# Initialize a random field with one slice on each gpu
field = normal_field(seed=seed, shape=mesh_shape, sharding=sharding)
field = normal_field(mesh_shape, seed=seed, sharding=sharding)
field = fft3d(field)
kvec = fftk(field)
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
@@ -139,7 +139,7 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
box_size[0] * box_size[1] * box_size[2])
field = field * jnp.sqrt(pkmesh)
field = field * (pkmesh)**0.5
field = ifft3d(field)
return field
@@ -172,7 +172,8 @@ def make_ode_fn(mesh_shape,
return nbody_ode
def make_diffrax_ode(mesh_shape,
def make_diffrax_ode(cosmo,
mesh_shape,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
@@ -182,7 +183,6 @@ def make_diffrax_ode(mesh_shape,
state is a tuple (position, velocities)
"""
pos, vel = state
cosmo = args
forces = pm_forces(pos,
mesh_shape=mesh_shape,


@@ -52,7 +52,7 @@ def _initialize_pk(mesh_shape, box_shape, kedges, los):
kshapes = np.eye(len(mesh_shape), dtype=np.int32) * -2 + 1
kvec = [(2 * np.pi * m / l) * np.fft.fftfreq(m).reshape(kshape)
for m, l, kshape in zip(mesh_shape, box_shape, kshapes)]
kmesh = jnp.sqrt(sum(ki**2 for ki in kvec))
kmesh = sum(ki**2 for ki in kvec)**0.5
dig = np.digitize(kmesh.reshape(-1), kedges)
kcount = np.bincount(dig, minlength=len(kedges) + 1)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long


@@ -17,8 +17,9 @@ import jax_cosmo as jc
import numpy as np
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
PIDController, SaveAt, diffeqsolve)
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import NamedSharding
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxpm.kernels import interpolate_power_spectrum
@@ -77,7 +78,7 @@ def parse_arguments():
def create_mesh_and_sharding(pdims):
devices = create_device_mesh(pdims)
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
mesh = Mesh(devices, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
return mesh, sharding
@@ -105,10 +106,7 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(mesh_shape,
paint_absolute_pos=False,
sharding=sharding,
halo_size=halo_size))
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
# Choose solver
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()


@@ -1,320 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# **Animating Particle Mesh density fields**\n",
"\n",
"In this tutorial, we will animate the density field of a particle mesh simulation. We will use the `manim` library to create the animation. \n",
"\n",
"The density fields are created exactly like in the notebook [**05-MultiHost_PM.ipynb**](05-MultiHost_PM.ipynb) using the same script [**05-MultiHost_PM.py**](05-MultiHost_PM.py)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To run a multi-host simulation, you first need to **allocate a job** with `salloc`. This command requests resources on an HPC cluster.\n",
"\n",
"just like in notebook [**05-MultiHost_PM.ipynb**]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!salloc --account=XXX@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:40:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=4 & "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**A few hours later**\n",
"\n",
"Use `!squeue -u $USER -o \"%i %D %b\"` to **check the JOB ID** and verify your resource allocation.\n",
"\n",
"In this example, weve been allocated **32 GPUs split across 4 nodes**.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!squeue -u $USER -o \"%i %D %b\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Unset the following environment variables, as they can cause issues when using JAX in a distributed setting:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"del os.environ['VSCODE_PROXY_URI']\n",
"del os.environ['NO_PROXY']\n",
"del os.environ['no_proxy']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Checking Available Compute Resources\n",
"\n",
"Run the following command to initialize JAX distributed computing and display the devices available for this job:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!srun --jobid=467745 -n 32 python -c \"import jax; jax.distributed.initialize(); print(jax.devices()) if jax.process_index() == 0 else None\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Multi-Host Simulation Script with Arguments (reminder)\n",
"\n",
"This script is nearly identical to the single-host version, with the main addition being the call to `jax.distributed.initialize()` at the start, enabling multi-host parallelism. Heres a breakdown of the key arguments:\n",
"\n",
"- **`--pdims`** (`-p`): Specifies processor grid dimensions as two integers, like `16 2` for 16 x 2 device mesh (default is `[1, jax.devices()]`).\n",
"- **`--mesh_shape`** (`-m`): Defines the simulation mesh shape as three integers (default is `[512, 512, 512]`).\n",
"- **`--box_size`** (`-b`): Sets the physical box size of the simulation as three floating-point values, e.g., `1000. 1000. 1000.` (default is `[500.0, 500.0, 500.0]`).\n",
"- **`--halo_size`** (`-H`): Specifies the halo size for boundary overlap across nodes (default is `64`).\n",
"- **`--solver`** (`-s`): Chooses the ODE solver (`leapfrog` or `dopri8`). The `leapfrog` solver uses a fixed step size, while `dopri8` is an adaptive Runge-Kutta solver with a PID controller (default is `leapfrog`).\n",
"- **`--snapthots`** (`-st`) : Number of snapshots to save (warning, increases memory usage)\n",
"\n",
"### Running the Multi-Host Simulation Script\n",
"\n",
"To create a smooth animation, we need a series of closely spaced snapshots to capture the evolution of the density field over time. In this example, we set the number of snapshots to **10** to ensure smooth transitions in the animation.\n",
"\n",
"Using a larger number of GPUs helps process these snapshots efficiently, especially with a large simulation mesh or high-resolution data. This allows us to achieve both the desired snapshot frequency and the necessary simulation detail without excessive runtime.\n",
"\n",
"The command to run the multi-host simulation with these settings will look something like this:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"\n",
"# Define parameters as variables\n",
"jobid = \"467745\"\n",
"num_processes = 32\n",
"script_name = \"05-MultiHost_PM.py\"\n",
"mesh_shape = (1024, 1024, 1024)\n",
"box_size = (1000., 1000., 1000.)\n",
"halo_size = 128\n",
"solver = \"leapfrog\"\n",
"pdims = (16, 2)\n",
"snapshots = 8\n",
"\n",
"# Build the command as a list, incorporating variables\n",
"command = [\n",
" \"srun\",\n",
" f\"--jobid={jobid}\",\n",
" \"-n\", str(num_processes),\n",
" \"python\", script_name,\n",
" \"--mesh_shape\", str(mesh_shape[0]), str(mesh_shape[1]), str(mesh_shape[2]),\n",
" \"--box_size\", str(box_size[0]), str(box_size[1]), str(box_size[2]),\n",
" \"--halo_size\", str(halo_size),\n",
" \"-s\", solver,\n",
" \"--pdims\", str(pdims[0]), str(pdims[1]),\n",
" \"--snapshots\", str(snapshots)\n",
"]\n",
"\n",
"# Execute the command as a subprocess\n",
"subprocess.run(command)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Projecting the 3D Density Fields to 2D\n",
"\n",
"To visualize the 3D density fields in 2D, we need to create a projection:\n",
"\n",
"- **`project_to_2d` Function**: This function reduces the 3D array to 2D by summing over a portion of one axis.\n",
" - We sum the top one-eighth of the data along the first axis to capture a slice of the density field.\n",
"\n",
"- **Creating 2D Projections**: Apply `project_to_2d` to each 3D field (`initial_conditions`, `lpt_displacements`, `ode_solution_0`, and `ode_solution_1`) to get 2D arrays that represent the density fields.\n",
"\n",
"### Applying the Magma Colormap\n",
"\n",
"To improve visualization, apply the \"magma\" colormap to each 2D projection:\n",
"\n",
"- **`apply_colormap` Function**: This function maps values in the 2D array to colors using the \"magma\" colormap.\n",
" - First, normalize the array to the `[0, 1]` range.\n",
" - Apply the colormap to create RGB images, which will be used for the animation.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import colormaps\n",
"\n",
"# Define a function to project the 3D field to 2D\n",
"def project_to_2d(field):\n",
" sum_over = field.shape[0] // 8\n",
" slicing = [slice(None)] * field.ndim\n",
" slicing[0] = slice(None, sum_over)\n",
" slicing = tuple(slicing)\n",
"\n",
" return field[slicing].sum(axis=0)\n",
"\n",
"\n",
"def apply_colormap(array, cmap_name=\"magma\"):\n",
" cmap = colormaps[cmap_name]\n",
" normalized_array = (array - array.min()) / (array.max() - array.min())\n",
" colored_image = cmap(normalized_array)[:, :, :3] # Drop alpha channel for RGB\n",
" return (colored_image * 255).astype(np.uint8)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Loading and Visualizing Results\n",
"\n",
"After running the multi-host simulation, we load the saved results from disk:\n",
"\n",
"- **`initial_conditions.npy`**: Initial conditions for the simulation.\n",
"- **`lpt_displacements.npy`**: Linear perturbation displacements.\n",
"- **`ode_solution_*.npy`** : Solutions from the ODE solver at each snapshot.\n",
"\n",
"We will now project the fields to 2D maps and apply the color map\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"initial_conditions = apply_colormap(project_to_2d(np.load('fields/initial_conditions.npy')))\n",
"lpt_displacements = apply_colormap(project_to_2d(np.load('fields/lpt_displacements.npy')))\n",
"ode_solutions = []\n",
"for i in range(8):\n",
" ode_solutions.append(apply_colormap(project_to_2d(np.load(f'fields/ode_solution_{i}.npy'))))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Animating with Manim\n",
"\n",
"To create animations with `manim` in a Jupyter notebook, we start by configuring some settings to ensure the output displays correctly and without a background.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from manim import *\n",
"config.media_width = \"100%\"\n",
"config.verbosity = \"WARNING\"\n",
"config.background_color = \"#00000000\" # Transparent background"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Defining the Animation in Manim\n",
"\n",
"This animation class, `FieldTransition`, smoothly transitions through the stages of the particle mesh density field evolution.\n",
"\n",
"- **Setup**: Each density field snapshot is loaded as an image and aligned for smooth transitions.\n",
"- **Animation Sequence**:\n",
" - The animation begins with a fade-in of the initial conditions.\n",
" - It then transitions through the stages in sequence, showing each snapshot of the density field evolution with brief pauses in between.\n",
"\n",
"To run the animation, execute `%manim -v WARNING -qm FieldTransition` to render it in the Jupyter Notebook.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define the animation in Manim\n",
"class FieldTransition(Scene):\n",
" def construct(self):\n",
" init_conditions_img = ImageMobject(initial_conditions).scale(4)\n",
" lpt_img = ImageMobject(lpt_displacements).scale(4)\n",
" snapshots_imgs = [ImageMobject(sol).scale(4) for sol in ode_solutions]\n",
"\n",
"\n",
" # Place the images on top of each other initially\n",
" lpt_img.move_to(init_conditions_img)\n",
" for img in snapshots_imgs:\n",
" img.move_to(init_conditions_img)\n",
"\n",
" # Show initial field and then transform between fields\n",
" self.play(FadeIn(init_conditions_img))\n",
" self.wait(0.2)\n",
" self.play(Transform(init_conditions_img, lpt_img))\n",
" self.wait(0.2)\n",
" self.play(Transform(lpt_img, snapshots_imgs[0]))\n",
" self.wait(0.2)\n",
" for img1, img2 in zip(snapshots_imgs, snapshots_imgs[1:]):\n",
" self.play(Transform(img1, img2))\n",
" self.wait(0.2)\n",
"\n",
"%manim -v WARNING -qm -o anim.gif --format=gif FieldTransition "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@@ -0,0 +1 @@
c4a44973e4f11841a8c14f4d200e7e87887419aa


@@ -37,50 +37,3 @@ Each notebook includes installation instructions and guidelines for configuring
- **SLURM** for job scheduling on clusters (if running multi-host setups)
> **Note**: These notebooks are tested on the **Jean Zay** supercomputer and may require configuration changes for different HPC clusters.
## Caveats
### Cloud-in-Cell (CIC) Painting (Single Device)
There are two ways to perform CIC painting in JaxPM. The first is `cic_paint`, which paints absolute particle positions onto the mesh. The second is `cic_paint_dx`, which paints relative particle positions onto the mesh (using uniform particles). The absolute version is faster at the cost of higher memory usage.
In order to use relative painting you need to:
- Set the `particles` argument in the `lpt` function from `jaxpm.pm` to `None`
- Set `paint_absolute_pos` to `False` in the `make_ode_fn` or `make_diffrax_ode` functions from `jaxpm.pm` (it is `True` by default)
Otherwise, set `particles` to the starting particles of your choice and leave `paint_absolute_pos` as `True` (the default). A minimal sketch of the relative mode follows.
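This sketch follows the `lpt`, `make_diffrax_ode`, and `cic_paint_dx` call patterns from this repository's tests; the mesh shape, cosmology, and random initial field are illustrative placeholders rather than a physical linear density field.
```python
# Minimal sketch of relative painting on a single device, following the
# lpt / make_diffrax_ode / cic_paint_dx call patterns from this repo's
# tests. The mesh shape, cosmology, and random "initial conditions" are
# illustrative placeholders, not a physical linear density field.
import jax
import jax_cosmo as jc

from jaxpm.painting import cic_paint_dx
from jaxpm.pm import lpt, make_diffrax_ode

mesh_shape = (64, 64, 64)
cosmo = jc.Planck15()
initial_conditions = jax.random.normal(jax.random.PRNGKey(0), mesh_shape)

# Relative mode: omit `particles` (leave it None) in `lpt` ...
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=1)
# ... and set paint_absolute_pos=False in the ODE factory; the result
# would be wrapped in diffrax's ODETerm and integrated as in the tests.
nbody_ode = make_diffrax_ode(mesh_shape, paint_absolute_pos=False)

# Displacements are painted directly, without absolute positions.
field = cic_paint_dx(dx)
```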
### Cloud-in-Cell (CIC) Painting (Multi Device)
Both `cic_paint` and `cic_paint_dx` functions are available in multi-device mode.
You need to set the `sharding` and `halo_size` arguments, which are explained in the notebook [03-MultiGPU_PM_Halo.ipynb](03-MultiGPU_PM_Halo.ipynb).
Note that `cic_paint` is not as accurate as `cic_paint_dx` in multi-device mode and is therefore not recommended.
Using relative painting in multi-device mode works just like in single-device mode:\
set the `particles` argument in the `lpt` function from `jaxpm.pm` to `None` and set `paint_absolute_pos` to `False`, as in the sketch below.
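A sketch of the same relative pipeline across devices, following the sharded `lpt` and `cic_paint_dx` calls in this repository's distributed tests; the (2, 4) processor grid, mesh shape, and halo size are illustrative, and 8 available devices are assumed.
```python
# Sketch of relative painting on 8 devices, following the sharded lpt /
# cic_paint_dx calls in this repo's distributed tests. The (2, 4)
# processor grid, mesh shape, and halo size are illustrative; 8 devices
# are assumed to be available.
import jax
import jax_cosmo as jc
from jax import lax
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P

from jaxpm.painting import cic_paint_dx
from jaxpm.pm import lpt

mesh_shape = (64, 64, 64)
mesh = jax.make_mesh((2, 4), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2  # the value used in the distributed tests

cosmo = jc.Planck15()
# Shard an (illustrative) initial field across the device mesh.
initial_conditions = jax.random.normal(jax.random.PRNGKey(0), mesh_shape)
initial_conditions = lax.with_sharding_constraint(initial_conditions, sharding)

# Relative displacements computed with the sharding ...
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=1,
               halo_size=halo_size, sharding=sharding)
# ... and painted with the same sharding and halo size.
field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding)
```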
### Distributed PM
To run a distributed PM, follow the examples in notebook [03](03-MultiGPU_PM_Halo.ipynb) and, for multi-host setups, notebook [05](05-MultiHost_PM.ipynb).
In short, you need to set the `sharding` and `halo_size` arguments in `lpt`, `linear_field`, the `make_ode` functions, and `pm_forces` if you use it.
Mismatching the shardings will give you errors and unexpected results.
You can also use `normal_field` and `uniform_particles` from `jaxpm.distributed` to create the fields and particles with a sharding, as sketched below.
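A short sketch of creating a field and particles that share one sharding, using the `normal_field(seed, shape, sharding)` and `uniform_particles(mesh_shape, sharding)` signatures shown in this diff; the seed, mesh shape, and (1, 8) grid are illustrative.
```python
# Sketch: create a sharded field and sharded particles with one
# consistent sharding, per the warning above about mismatched shardings.
# The seed, mesh shape, and (1, 8) processor grid are illustrative.
import jax
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P

from jaxpm.distributed import normal_field, uniform_particles

mesh_shape = (256, 256, 256)
mesh = jax.make_mesh((1, 8), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))

# Both carry the same layout, so downstream lpt / painting calls that
# receive sharding=sharding see consistent shardings.
field = normal_field(jax.random.PRNGKey(42), mesh_shape, sharding=sharding)
particles = uniform_particles(mesh_shape, sharding=sharding)
```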
### Choosing the right pdims
pdims are processor dimensions,\
explained in more detail in the jaxDecomp paper [here](https://github.com/DifferentiableUniverseInitiative/jaxDecomp).
For 8 devices, three decompositions are possible:
- (1, 8)
- (2, 4), (4, 2)
- (8, 1)
(1, X) should be the fastest; (2, X) or (X, 2) is more accurate but slightly slower,\
and (X, 1) gives the least accurate results for some reason, so it is not recommended. A sketch of the three options follows.
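Written out with `jax.make_mesh` as used elsewhere in this repository (a sketch, assuming 8 available devices):
```python
# The three pdims decompositions for 8 devices (sketch).
import jax

mesh_1x8 = jax.make_mesh((1, 8), ('x', 'y'))  # fastest
mesh_2x4 = jax.make_mesh((2, 4), ('x', 'y'))  # more accurate, slightly slower
mesh_8x1 = jax.make_mesh((8, 1), ('x', 'y'))  # least accurate; not recommended
```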


@@ -3,15 +3,28 @@ requires = ["setuptools", "wheel", "setuptools-scm"]
build-backend = "setuptools.build_meta"
[project]
name = "jaxpm"
name = "JaxPM"
dynamic = ["version"]
description = "A simple Particle-Mesh implementation in JAX"
description = "A dead simple FastPM implementation in JAX"
authors = [{ name = "JaxPM developers" }]
readme = "README.md"
requires-python = ">=3.9"
license = { file = "LICENSE" }
urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/JaxPM" }
dependencies = ["jax_cosmo", "jax>=0.4.35", "jaxdecomp>=0.2.3"]
dependencies = ["jax_cosmo", "jax>=0.4.30", "jaxdecomp>=0.2.2"]
[project.optional-dependencies]
test = [
"jax>=0.4.30",
"numpy",
"jax_cosmo",
"jaxdecomp>=0.2.2",
"pytest>=8.0.0",
"pfft-python @ git+https://github.com/MP-Gadget/pfft-python",
"pmesh @ git+https://github.com/MP-Gadget/pmesh",
"fastpm @ git+https://github.com/ASKabalan/fastpm-python",
"diffrax"
]
[tool.setuptools]
packages = ["jaxpm"]


@@ -1,5 +0,0 @@
pfft-python @ git+https://github.com/MP-Gadget/pfft-python
pmesh @ git+https://github.com/MP-Gadget/pmesh
fastpm @ git+https://github.com/ASKabalan/fastpm-python
numpy==2.2.6
diffrax


@@ -44,7 +44,7 @@ def simulation_config(request):
return request.param
@pytest.fixture(scope="session", params=[0.1, 0.2])
@pytest.fixture(scope="session", params=[0.1, 0.5, 0.8])
def lpt_scale_factor(request):
return request.param
@@ -96,9 +96,8 @@ def fpm_initial_conditions(cosmo, particle_mesh):
whitec = particle_mesh.generate_whitenoise(42,
type='complex',
unitary=False)
lineark = whitec.apply(lambda k, v: jnp.sqrt(
pk_fn(jnp.sqrt(sum(ki**2 for ki in k)))) * v * jnp.sqrt(
(1 / v.BoxSize).prod()))
lineark = whitec.apply(lambda k, v: pk_fn(sum(ki**2 for ki in k)**0.5)**0.5
* v * (1 / v.BoxSize).prod()**0.5)
init_mesh = lineark.c2r().value # XXX
return lineark, grid, init_mesh
@@ -152,7 +151,7 @@ def nbody_from_lpt1(solver, fpm_lpt1, particle_mesh, lpt_scale_factor):
if lpt_scale_factor == 0.8:
pytest.skip("Do not run nbody simulation from scale factor 0.8")
stages = np.linspace(lpt_scale_factor, 1.0, 100, endpoint=True)
stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)
finalstate = solver.nbody(fpm_lpt1, leapfrog(stages))
fpm_mesh = particle_mesh.paint(finalstate.X).value
@@ -168,7 +167,7 @@ def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
if lpt_scale_factor == 0.8:
pytest.skip("Do not run nbody simulation from scale factor 0.8")
stages = np.linspace(lpt_scale_factor, 1.0, 100, endpoint=True)
stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)
finalstate = solver.nbody(fpm_lpt2, leapfrog(stages))
fpm_mesh = particle_mesh.paint(finalstate.X).value


@@ -2,7 +2,6 @@ import pytest
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
from helpers import MSE, MSRE
from jax import numpy as jnp
from numpy.testing import assert_allclose
from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_dx
@@ -11,8 +10,6 @@ from jaxpm.utils import power_spectrum
_TOLERANCE = 1e-4
_PM_TOLERANCE = 1e-3
_FIELD_RTOL = 1e-4
_FIELD_ATOL = 1e-3
@pytest.mark.single_device
@@ -37,10 +34,7 @@ def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert_allclose(lpt_field,
fpm_ref_field,
rtol=_FIELD_RTOL,
atol=_FIELD_ATOL)
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
@@ -61,10 +55,7 @@ def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert_allclose(lpt_field,
fpm_ref_field,
rtol=_FIELD_RTOL,
atol=_FIELD_ATOL)
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
@@ -85,7 +76,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
solver = Dopri5()
controller = PIDController(rtol=1e-8,
@@ -104,7 +95,6 @@ def test_nbody_absolute(simulation_config, initial_conditions,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
@@ -115,10 +105,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert_allclose(final_field,
fpm_ref_field,
rtol=_FIELD_RTOL,
atol=_FIELD_ATOL)
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
@@ -134,7 +121,8 @@ def test_nbody_relative(simulation_config, initial_conditions,
# Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
solver = Dopri5()
controller = PIDController(rtol=1e-9,
@@ -153,7 +141,6 @@ def test_nbody_relative(simulation_config, initial_conditions,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
@@ -164,8 +151,5 @@ def test_nbody_relative(simulation_config, initial_conditions,
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert_allclose(final_field,
fpm_ref_field,
rtol=_FIELD_RTOL,
atol=_FIELD_ATOL)
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE


@@ -2,11 +2,8 @@ from conftest import initialize_distributed
initialize_distributed() # ignore : E402
from functools import partial # noqa : E402
import jax # noqa : E402
import jax.numpy as jnp # noqa : E402
import jax_cosmo as jc # noqa : E402
import pytest # noqa : E402
from diffrax import SaveAt # noqa : E402
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
@@ -15,37 +12,21 @@ from jax import lax # noqa : E402
from jax.experimental.multihost_utils import process_allgather # noqa : E402
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P # noqa : E402
from jaxdecomp import get_fft_output_sharding
from jaxpm.distributed import uniform_particles # noqa : E402
from jaxpm.distributed import fft3d, ifft3d
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode, pm_forces # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
_TOLERANCE = 1e-12 # 🎉🎉🎉
pdims = [(1, 8), (8, 1), (4, 2), (2, 4)]
jax.config.update("jax_enable_x64", True) # Use double precision for accuracy
_TOLERANCE = 3.0 # 🙃🙃
@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("pdims", pdims)
@pytest.mark.parametrize("absolute_painting", [True, False])
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
pdims, absolute_painting):
if absolute_painting:
pytest.skip("Absolute painting is not recommended in distributed mode")
painting_str = "absolute" if absolute_painting else "relative"
print("=" * 50)
absolute_painting):
mesh_shape, box_shape = simulation_config
print(
f"Running with {painting_str} painting and pdims {pdims} and order {order} and mesh shape {mesh_shape}..."
)
# SINGLE DEVICE RUN
cosmo._workspace = {}
if absolute_painting:
@@ -56,12 +37,12 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
particles,
a=0.1,
order=order)
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
y0 = jnp.stack([particles + dx, p])
else:
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order)
ode_fn = ODETerm(make_diffrax_ode(mesh_shape,
paint_absolute_pos=False))
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
y0 = jnp.stack([dx, p])
solver = Dopri5()
@@ -79,7 +60,6 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
@@ -92,7 +72,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
print("Done with single device run")
# MULTI DEVICE RUN
mesh = jax.make_mesh(pdims, ('x', 'y'))
mesh = jax.make_mesh((1, 8), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
@@ -114,7 +94,8 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(mesh_shape,
make_diffrax_ode(cosmo,
mesh_shape,
halo_size=halo_size,
sharding=sharding))
@@ -127,7 +108,8 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(mesh_shape,
make_diffrax_ode(cosmo,
mesh_shape,
paint_absolute_pos=False,
halo_size=halo_size,
sharding=sharding))
@@ -148,23 +130,16 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
final_field = solutions.ys[-1, 0]
print(f"Final field sharding is {final_field.sharding}")
assert final_field.sharding.is_equivalent_to(sharding , ndim=3) \
, f"Final field sharding is not correct .. should be {sharding} it is instead {final_field.sharding}"
if absolute_painting:
multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
final_field,
solutions.ys[-1, 0],
halo_size=halo_size,
sharding=sharding)
else:
multi_device_final_field = cic_paint_dx(final_field,
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
halo_size=halo_size,
sharding=sharding)
@@ -175,230 +150,3 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
print(f"MSE is {mse}")
assert mse < _TOLERANCE
@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("pdims", pdims)
def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
order, nbody_from_lpt1, nbody_from_lpt2, pdims):
mesh_shape, box_shape = simulation_config
# SINGLE DEVICE RUN
cosmo._workspace = {}
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
initial_conditions = lax.with_sharding_constraint(initial_conditions,
sharding)
print(f"sharded initial conditions {initial_conditions.sharding}")
cosmo._workspace = {}
@jax.jit
def forward_model(initial_conditions, cosmo):
dx, p, _ = lpt(cosmo,
initial_conditions,
a=0.1,
order=order,
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(mesh_shape,
paint_absolute_pos=False,
halo_size=halo_size,
sharding=sharding))
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
solver = Dopri5()
controller = PIDController(rtol=1e-8,
atol=1e-8,
pcoeff=0.4,
icoeff=1,
dcoeff=0)
saveat = SaveAt(t1=True)
solutions = diffeqsolve(ode_fn,
solver,
t0=0.1,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
halo_size=halo_size,
sharding=sharding)
return multi_device_final_field
@jax.jit
def model(initial_conditions, cosmo):
final_field = forward_model(initial_conditions, cosmo)
return MSE(final_field,
nbody_from_lpt1 if order == 1 else nbody_from_lpt2)
obs_val = model(initial_conditions, cosmo)
shifted_initial_conditions = initial_conditions + jax.random.normal(
jax.random.key(42), initial_conditions.shape) * 5
good_grads = jax.grad(model)(initial_conditions, cosmo)
off_grads = jax.grad(model)(shifted_initial_conditions, cosmo)
assert good_grads.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
assert off_grads.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
@pytest.mark.distributed
@pytest.mark.parametrize("pdims", pdims)
def test_fwd_rev_gradients(cosmo, pdims):
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
cosmo._workspace = {}
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
initial_conditions = jax.random.normal(jax.random.PRNGKey(42), mesh_shape)
initial_conditions = lax.with_sharding_constraint(initial_conditions,
sharding)
print(f"sharded initial conditions {initial_conditions.sharding}")
cosmo._workspace = {}
@partial(jax.jit, static_argnums=(2, 3, 4))
def compute_forces(initial_conditions,
cosmo,
a=0.5,
halo_size=0,
sharding=None):
paint_absolute_pos = False
particles = jnp.zeros_like(initial_conditions,
shape=(*initial_conditions.shape, 3))
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
initial_conditions = jax.lax.with_sharding_constraint(
initial_conditions, sharding)
delta_k = fft3d(initial_conditions)
out_sharding = get_fft_output_sharding(sharding)
delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
initial_force = pm_forces(particles,
delta=delta_k,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding)
return initial_force[..., 0]
forces = compute_forces(initial_conditions,
cosmo,
halo_size=halo_size,
sharding=sharding)
back_gradient = jax.jacrev(compute_forces)(initial_conditions,
cosmo,
halo_size=halo_size,
sharding=sharding)
fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions,
cosmo,
halo_size=halo_size,
sharding=sharding)
print(f"Forces sharding is {forces.sharding}")
print(f"Backward gradient sharding is {back_gradient.sharding}")
print(f"Forward gradient sharding is {fwd_gradient.sharding}")
assert forces.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
assert back_gradient[0, 0, 0, ...].sharding.is_equivalent_to(
initial_conditions.sharding, ndim=3)
assert fwd_gradient.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
@pytest.mark.distributed
@pytest.mark.parametrize("pdims", pdims)
def test_vmap(cosmo, pdims):
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
cosmo._workspace = {}
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
single_dev_initial_conditions = jax.random.normal(jax.random.PRNGKey(42),
mesh_shape)
initial_conditions = lax.with_sharding_constraint(
single_dev_initial_conditions, sharding)
single_ics = jnp.stack([
single_dev_initial_conditions, single_dev_initial_conditions,
single_dev_initial_conditions
])
sharded_ics = jnp.stack(
[initial_conditions, initial_conditions, initial_conditions])
print(f"unsharded initial conditions batch {single_ics.sharding}")
print(f"sharded initial conditions batch {sharded_ics.sharding}")
cosmo._workspace = {}
@partial(jax.jit, static_argnums=(2, 3, 4))
def compute_forces(initial_conditions,
cosmo,
a=0.5,
halo_size=0,
sharding=None):
paint_absolute_pos = False
particles = jnp.zeros_like(initial_conditions,
shape=(*initial_conditions.shape, 3))
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
initial_conditions = jax.lax.with_sharding_constraint(
initial_conditions, sharding)
delta_k = fft3d(initial_conditions)
out_sharding = get_fft_output_sharding(sharding)
delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
initial_force = pm_forces(particles,
delta=delta_k,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding)
return initial_force[..., 0]
def fn(ic):
return compute_forces(ic,
cosmo,
halo_size=halo_size,
sharding=sharding)
v_compute_forces = jax.vmap(fn)
print(f"single_ics shape {single_ics.shape}")
print(f"sharded_ics shape {sharded_ics.shape}")
single_dev_forces = v_compute_forces(single_ics)
sharded_forces = v_compute_forces(sharded_ics)
assert single_dev_forces.ndim == 4
assert sharded_forces.ndim == 4
print(f"Sharded forces {sharded_forces.sharding}")
assert sharded_forces[0].sharding.is_equivalent_to(
initial_conditions.sharding, ndim=3)
assert sharded_forces.sharding.spec[0] == None


@@ -1,88 +0,0 @@
import jax
import pytest
from diffrax import (BacksolveAdjoint, Dopri5, ODETerm, PIDController,
RecursiveCheckpointAdjoint, SaveAt, diffeqsolve)
from helpers import MSE
from jax import numpy as jnp
from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_dx
from jaxpm.pm import lpt, make_diffrax_ode
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("absolute_painting", [True, False])
@pytest.mark.parametrize("adjoint", ['DTO', 'OTD'])
def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
nbody_from_lpt1, nbody_from_lpt2, cosmo, order,
absolute_painting, adjoint):
mesh_shape, _ = simulation_config
cosmo._workspace = {}
if adjoint == 'OTD':
pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)")
adjoint = RecursiveCheckpointAdjoint(
) if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5())
@jax.jit
@jax.grad
def forward_model(initial_conditions, cosmo):
# Initial displacement
if absolute_painting:
particles = uniform_particles(mesh_shape)
dx, p, _ = lpt(cosmo,
initial_conditions,
particles,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
y0 = jnp.stack([particles + dx, p])
else:
dx, p, _ = lpt(cosmo,
initial_conditions,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm(
make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
y0 = jnp.stack([dx, p])
solver = Dopri5()
controller = PIDController(rtol=1e-7,
atol=1e-7,
pcoeff=0.4,
icoeff=1,
dcoeff=0)
saveat = SaveAt(t1=True)
solutions = diffeqsolve(ode_fn,
solver,
t0=lpt_scale_factor,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
adjoint=adjoint,
stepsize_controller=controller,
saveat=saveat)
if absolute_painting:
final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0])
else:
final_field = cic_paint_dx(solutions.ys[-1, 0])
return MSE(final_field,
nbody_from_lpt1 if order == 1 else nbody_from_lpt2)
bad_initial_conditions = initial_conditions + jax.random.normal(
jax.random.PRNGKey(0), initial_conditions.shape) * 0.5
best_ic = forward_model(initial_conditions, cosmo)
bad_ic = forward_model(bad_initial_conditions, cosmo)
assert jnp.max(best_ic) < 1e-5
assert jnp.max(bad_ic) > 1e-5