Compare commits

..

No commits in common. "main" and "v0.1.0" have entirely different histories.
main ... v0.1.0

17 changed files with 34 additions and 2449 deletions

View file

@ -1,55 +0,0 @@
name: Upload Python Package
on:
release:
types: [published]
permissions:
contents: read
jobs:
release-build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.x"
- name: Build release distributions
run: |
# NOTE: put your own distribution build steps here.
python -m pip install build
python -m build
- name: Upload distributions
uses: actions/upload-artifact@v4
with:
name: release-dists
path: dist/
pypi-publish:
runs-on: ubuntu-latest
needs:
- release-build
permissions:
# IMPORTANT: this permission is mandatory for trusted publishing
id-token: write
environment:
name: pypi
url: https://pypi.org/p/jaxpm
steps:
- name: Retrieve release distributions
uses: actions/download-artifact@v4
with:
name: release-dists
path: dist/
- name: Publish release distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
packages-dir: dist/

View file

@ -34,8 +34,7 @@ jobs:
pip install git+https://github.com/MP-Gadget/pfft-python
pip install git+https://github.com/MP-Gadget/pmesh
pip install git+https://github.com/ASKabalan/fastpm-python --no-build-isolation
pip install -r requirements-test.txt
pip install .
pip install .[test]
- name: Run Single Device Tests
run: |

3
.gitignore vendored
View file

@ -132,6 +132,3 @@ dmypy.json
# Pyre type checker
.pyre/
# Hide version file
_version.py

View file

@ -1,6 +1,6 @@
MIT License
Copyright (c) 2021-2025 Differentiable Universe Initiative
Copyright (c) 2021 Differentiable Universe Initiative
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

View file

@ -1,2 +0,0 @@
prune notebooks
prune tests

View file

@ -1,26 +1,9 @@
# JaxPM
[![Notebook](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/01-Introduction.ipynb)
[![PyPI version](https://img.shields.io/pypi/v/jaxpm)](https://pypi.org/project/jaxpm/) [![Tests](https://github.com/DifferentiableUniverseInitiative/JaxPM/actions/workflows/tests.yml/badge.svg)](https://github.com/DifferentiableUniverseInitiative/JaxPM/actions/workflows/tests.yml) <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
[![Tests](https://github.com/DifferentiableUniverseInitiative/JaxPM/actions/workflows/tests.yml/badge.svg)](https://github.com/DifferentiableUniverseInitiative/JaxPM/actions/workflows/tests.yml) <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
[![All Contributors](https://img.shields.io/badge/all_contributors-5-orange.svg?style=flat-square)](#contributors-)
<!-- ALL-CONTRIBUTORS-BADGE:END -->
JAX-powered Cosmological Particle-Mesh N-body Solver
> ### Note
> **The new JaxPM v0.1.xx** supports multi-GPU model distribution while remaining compatible with previous releases. These significant changes are still under development and testing, so please report any issues you encounter.
> For the older but more stable version, install:
> ```bash
> pip install jaxpm==0.0.2
> ```
## Install
Basic installation can be done using pip:
```bash
pip install jaxpm
```
For more advanced installation for optimized distribution on gpu clusters, please install jaxDecomp first. See instructions [here](https://github.com/DifferentiableUniverseInitiative/jaxDecomp).
## Goals
Provide a modern infrastructure to support differentiable PM N-body simulations using JAX:

14
dev/job_pfft.sh Normal file
View file

@ -0,0 +1,14 @@
#!/bin/bash
#SBATCH -A m1727
#SBATCH -C gpu
#SBATCH -q debug
#SBATCH -t 0:05:00
#SBATCH -N 2
#SBATCH --ntasks-per-node=4
#SBATCH -c 32
#SBATCH --gpus-per-task=1
#SBATCH --gpu-bind=none
module load python cudnn/8.2.0 nccl/2.11.4 cudatoolkit
export SLURM_CPU_BIND="cores"
srun python test_pfft.py

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1,320 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# **Animating Particle Mesh density fields**\n",
"\n",
"In this tutorial, we will animate the density field of a particle mesh simulation. We will use the `manim` library to create the animation. \n",
"\n",
"The density fields are created exactly like in the notebook [**05-MultiHost_PM.ipynb**](05-MultiHost_PM.ipynb) using the same script [**05-MultiHost_PM.py**](05-MultiHost_PM.py)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To run a multi-host simulation, you first need to **allocate a job** with `salloc`. This command requests resources on an HPC cluster.\n",
"\n",
"just like in notebook [**05-MultiHost_PM.ipynb**]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!salloc --account=XXX@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:40:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=4 & "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**A few hours later**\n",
"\n",
"Use `!squeue -u $USER -o \"%i %D %b\"` to **check the JOB ID** and verify your resource allocation.\n",
"\n",
"In this example, weve been allocated **32 GPUs split across 4 nodes**.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!squeue -u $USER -o \"%i %D %b\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Unset the following environment variables, as they can cause issues when using JAX in a distributed setting:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"del os.environ['VSCODE_PROXY_URI']\n",
"del os.environ['NO_PROXY']\n",
"del os.environ['no_proxy']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Checking Available Compute Resources\n",
"\n",
"Run the following command to initialize JAX distributed computing and display the devices available for this job:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!srun --jobid=467745 -n 32 python -c \"import jax; jax.distributed.initialize(); print(jax.devices()) if jax.process_index() == 0 else None\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Multi-Host Simulation Script with Arguments (reminder)\n",
"\n",
"This script is nearly identical to the single-host version, with the main addition being the call to `jax.distributed.initialize()` at the start, enabling multi-host parallelism. Heres a breakdown of the key arguments:\n",
"\n",
"- **`--pdims`** (`-p`): Specifies processor grid dimensions as two integers, like `16 2` for 16 x 2 device mesh (default is `[1, jax.devices()]`).\n",
"- **`--mesh_shape`** (`-m`): Defines the simulation mesh shape as three integers (default is `[512, 512, 512]`).\n",
"- **`--box_size`** (`-b`): Sets the physical box size of the simulation as three floating-point values, e.g., `1000. 1000. 1000.` (default is `[500.0, 500.0, 500.0]`).\n",
"- **`--halo_size`** (`-H`): Specifies the halo size for boundary overlap across nodes (default is `64`).\n",
"- **`--solver`** (`-s`): Chooses the ODE solver (`leapfrog` or `dopri8`). The `leapfrog` solver uses a fixed step size, while `dopri8` is an adaptive Runge-Kutta solver with a PID controller (default is `leapfrog`).\n",
"- **`--snapthots`** (`-st`) : Number of snapshots to save (warning, increases memory usage)\n",
"\n",
"### Running the Multi-Host Simulation Script\n",
"\n",
"To create a smooth animation, we need a series of closely spaced snapshots to capture the evolution of the density field over time. In this example, we set the number of snapshots to **10** to ensure smooth transitions in the animation.\n",
"\n",
"Using a larger number of GPUs helps process these snapshots efficiently, especially with a large simulation mesh or high-resolution data. This allows us to achieve both the desired snapshot frequency and the necessary simulation detail without excessive runtime.\n",
"\n",
"The command to run the multi-host simulation with these settings will look something like this:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"\n",
"# Define parameters as variables\n",
"jobid = \"467745\"\n",
"num_processes = 32\n",
"script_name = \"05-MultiHost_PM.py\"\n",
"mesh_shape = (1024, 1024, 1024)\n",
"box_size = (1000., 1000., 1000.)\n",
"halo_size = 128\n",
"solver = \"leapfrog\"\n",
"pdims = (16, 2)\n",
"snapshots = 8\n",
"\n",
"# Build the command as a list, incorporating variables\n",
"command = [\n",
" \"srun\",\n",
" f\"--jobid={jobid}\",\n",
" \"-n\", str(num_processes),\n",
" \"python\", script_name,\n",
" \"--mesh_shape\", str(mesh_shape[0]), str(mesh_shape[1]), str(mesh_shape[2]),\n",
" \"--box_size\", str(box_size[0]), str(box_size[1]), str(box_size[2]),\n",
" \"--halo_size\", str(halo_size),\n",
" \"-s\", solver,\n",
" \"--pdims\", str(pdims[0]), str(pdims[1]),\n",
" \"--snapshots\", str(snapshots)\n",
"]\n",
"\n",
"# Execute the command as a subprocess\n",
"subprocess.run(command)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Projecting the 3D Density Fields to 2D\n",
"\n",
"To visualize the 3D density fields in 2D, we need to create a projection:\n",
"\n",
"- **`project_to_2d` Function**: This function reduces the 3D array to 2D by summing over a portion of one axis.\n",
" - We sum the top one-eighth of the data along the first axis to capture a slice of the density field.\n",
"\n",
"- **Creating 2D Projections**: Apply `project_to_2d` to each 3D field (`initial_conditions`, `lpt_displacements`, `ode_solution_0`, and `ode_solution_1`) to get 2D arrays that represent the density fields.\n",
"\n",
"### Applying the Magma Colormap\n",
"\n",
"To improve visualization, apply the \"magma\" colormap to each 2D projection:\n",
"\n",
"- **`apply_colormap` Function**: This function maps values in the 2D array to colors using the \"magma\" colormap.\n",
" - First, normalize the array to the `[0, 1]` range.\n",
" - Apply the colormap to create RGB images, which will be used for the animation.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import colormaps\n",
"\n",
"# Define a function to project the 3D field to 2D\n",
"def project_to_2d(field):\n",
" sum_over = field.shape[0] // 8\n",
" slicing = [slice(None)] * field.ndim\n",
" slicing[0] = slice(None, sum_over)\n",
" slicing = tuple(slicing)\n",
"\n",
" return field[slicing].sum(axis=0)\n",
"\n",
"\n",
"def apply_colormap(array, cmap_name=\"magma\"):\n",
" cmap = colormaps[cmap_name]\n",
" normalized_array = (array - array.min()) / (array.max() - array.min())\n",
" colored_image = cmap(normalized_array)[:, :, :3] # Drop alpha channel for RGB\n",
" return (colored_image * 255).astype(np.uint8)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Loading and Visualizing Results\n",
"\n",
"After running the multi-host simulation, we load the saved results from disk:\n",
"\n",
"- **`initial_conditions.npy`**: Initial conditions for the simulation.\n",
"- **`lpt_displacements.npy`**: Linear perturbation displacements.\n",
"- **`ode_solution_*.npy`** : Solutions from the ODE solver at each snapshot.\n",
"\n",
"We will now project the fields to 2D maps and apply the color map\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"initial_conditions = apply_colormap(project_to_2d(np.load('fields/initial_conditions.npy')))\n",
"lpt_displacements = apply_colormap(project_to_2d(np.load('fields/lpt_displacements.npy')))\n",
"ode_solutions = []\n",
"for i in range(8):\n",
" ode_solutions.append(apply_colormap(project_to_2d(np.load(f'fields/ode_solution_{i}.npy'))))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Animating with Manim\n",
"\n",
"To create animations with `manim` in a Jupyter notebook, we start by configuring some settings to ensure the output displays correctly and without a background.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from manim import *\n",
"config.media_width = \"100%\"\n",
"config.verbosity = \"WARNING\"\n",
"config.background_color = \"#00000000\" # Transparent background"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Defining the Animation in Manim\n",
"\n",
"This animation class, `FieldTransition`, smoothly transitions through the stages of the particle mesh density field evolution.\n",
"\n",
"- **Setup**: Each density field snapshot is loaded as an image and aligned for smooth transitions.\n",
"- **Animation Sequence**:\n",
" - The animation begins with a fade-in of the initial conditions.\n",
" - It then transitions through the stages in sequence, showing each snapshot of the density field evolution with brief pauses in between.\n",
"\n",
"To run the animation, execute `%manim -v WARNING -qm FieldTransition` to render it in the Jupyter Notebook.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define the animation in Manim\n",
"class FieldTransition(Scene):\n",
" def construct(self):\n",
" init_conditions_img = ImageMobject(initial_conditions).scale(4)\n",
" lpt_img = ImageMobject(lpt_displacements).scale(4)\n",
" snapshots_imgs = [ImageMobject(sol).scale(4) for sol in ode_solutions]\n",
"\n",
"\n",
" # Place the images on top of each other initially\n",
" lpt_img.move_to(init_conditions_img)\n",
" for img in snapshots_imgs:\n",
" img.move_to(init_conditions_img)\n",
"\n",
" # Show initial field and then transform between fields\n",
" self.play(FadeIn(init_conditions_img))\n",
" self.wait(0.2)\n",
" self.play(Transform(init_conditions_img, lpt_img))\n",
" self.wait(0.2)\n",
" self.play(Transform(lpt_img, snapshots_imgs[0]))\n",
" self.wait(0.2)\n",
" for img1, img2 in zip(snapshots_imgs, snapshots_imgs[1:]):\n",
" self.play(Transform(img1, img2))\n",
" self.wait(0.2)\n",
"\n",
"%manim -v WARNING -qm -o anim.gif --format=gif FieldTransition "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View file

@ -0,0 +1 @@
c4a44973e4f11841a8c14f4d200e7e87887419aa

View file

@ -3,15 +3,28 @@ requires = ["setuptools", "wheel", "setuptools-scm"]
build-backend = "setuptools.build_meta"
[project]
name = "jaxpm"
name = "JaxPM"
dynamic = ["version"]
description = "A simple Particle-Mesh implementation in JAX"
description = "A dead simple FastPM implementation in JAX"
authors = [{ name = "JaxPM developers" }]
readme = "README.md"
requires-python = ">=3.9"
license = { file = "LICENSE" }
urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/JaxPM" }
dependencies = ["jax_cosmo", "jax>=0.4.35", "jaxdecomp>=0.2.3"]
dependencies = ["jax_cosmo", "jax>=0.4.30", "jaxdecomp>=0.2.2"]
[project.optional-dependencies]
test = [
"jax>=0.4.30",
"numpy",
"jax_cosmo",
"jaxdecomp>=0.2.2",
"pytest>=8.0.0",
"pfft-python @ git+https://github.com/MP-Gadget/pfft-python",
"pmesh @ git+https://github.com/MP-Gadget/pmesh",
"fastpm @ git+https://github.com/ASKabalan/fastpm-python",
"diffrax"
]
[tool.setuptools]
packages = ["jaxpm"]

View file

@ -1,5 +0,0 @@
pytest>=8.0.0
diffrax
pfft-python @ git+https://github.com/MP-Gadget/pfft-python
pmesh @ git+https://github.com/MP-Gadget/pmesh
fastpm @ git+https://github.com/ASKabalan/fastpm-python

View file

@ -1,87 +0,0 @@
import jax
import pytest
from diffrax import (BacksolveAdjoint, Dopri5, ODETerm, PIDController,
RecursiveCheckpointAdjoint, SaveAt, diffeqsolve)
from helpers import MSE
from jax import numpy as jnp
from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_dx
from jaxpm.pm import lpt, make_diffrax_ode
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("absolute_painting", [True, False])
@pytest.mark.parametrize("adjoint", ['DTO', 'OTD'])
def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
nbody_from_lpt1, nbody_from_lpt2, cosmo, order,
absolute_painting, adjoint):
mesh_shape, _ = simulation_config
cosmo._workspace = {}
if adjoint == 'OTD':
pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)")
adjoint = RecursiveCheckpointAdjoint(
) if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5())
@jax.jit
@jax.grad
def forward_model(initial_conditions, cosmo):
# Initial displacement
if absolute_painting:
particles = uniform_particles(mesh_shape)
dx, p, _ = lpt(cosmo,
initial_conditions,
particles,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
y0 = jnp.stack([particles + dx, p])
else:
dx, p, _ = lpt(cosmo,
initial_conditions,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
y0 = jnp.stack([dx, p])
solver = Dopri5()
controller = PIDController(rtol=1e-7,
atol=1e-7,
pcoeff=0.4,
icoeff=1,
dcoeff=0)
saveat = SaveAt(t1=True)
solutions = diffeqsolve(ode_fn,
solver,
t0=lpt_scale_factor,
t1=1.0,
dt0=None,
y0=y0,
adjoint=adjoint,
stepsize_controller=controller,
saveat=saveat)
if absolute_painting:
final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0])
else:
final_field = cic_paint_dx(solutions.ys[-1, 0])
return MSE(final_field,
nbody_from_lpt1 if order == 1 else nbody_from_lpt2)
bad_initial_conditions = initial_conditions + jax.random.normal(
jax.random.PRNGKey(0), initial_conditions.shape) * 0.5
best_ic = forward_model(initial_conditions, cosmo)
bad_ic = forward_model(bad_initial_conditions, cosmo)
assert jnp.max(best_ic) < 1e-5
assert jnp.max(bad_ic) > 1e-5