mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-06 03:40:54 +00:00
Compare commits
12 commits
Author | SHA1 | Date | |
---|---|---|---|
|
cb2a7ab17f | ||
|
d81a2529e7 | ||
|
15f2fb1ee6 | ||
|
ae0f439ae4 | ||
|
ea9fbf6aa8 | ||
|
ad16a0659a | ||
|
f245a1f685 | ||
|
160b86eb71 | ||
|
f14f0fe68e | ||
|
70ab9f1931 | ||
|
bc6e57532d | ||
|
4b4450d7d3 |
17 changed files with 2449 additions and 34 deletions
55
.github/workflows/python-publish.yml
vendored
Normal file
55
.github/workflows/python-publish.yml
vendored
Normal file
|
@ -0,0 +1,55 @@
|
|||
name: Upload Python Package
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
release-build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.x"
|
||||
|
||||
- name: Build release distributions
|
||||
run: |
|
||||
# NOTE: put your own distribution build steps here.
|
||||
python -m pip install build
|
||||
python -m build
|
||||
|
||||
- name: Upload distributions
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-dists
|
||||
path: dist/
|
||||
|
||||
pypi-publish:
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- release-build
|
||||
permissions:
|
||||
# IMPORTANT: this permission is mandatory for trusted publishing
|
||||
id-token: write
|
||||
|
||||
environment:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/jaxpm
|
||||
|
||||
steps:
|
||||
- name: Retrieve release distributions
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: release-dists
|
||||
path: dist/
|
||||
|
||||
- name: Publish release distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
packages-dir: dist/
|
3
.github/workflows/tests.yml
vendored
3
.github/workflows/tests.yml
vendored
|
@ -34,7 +34,8 @@ jobs:
|
|||
pip install git+https://github.com/MP-Gadget/pfft-python
|
||||
pip install git+https://github.com/MP-Gadget/pmesh
|
||||
pip install git+https://github.com/ASKabalan/fastpm-python --no-build-isolation
|
||||
pip install .[test]
|
||||
pip install -r requirements-test.txt
|
||||
pip install .
|
||||
|
||||
- name: Run Single Device Tests
|
||||
run: |
|
||||
|
|
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -132,3 +132,6 @@ dmypy.json
|
|||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# Hide version file
|
||||
_version.py
|
||||
|
|
2
LICENSE
2
LICENSE
|
@ -1,6 +1,6 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2021 Differentiable Universe Initiative
|
||||
Copyright (c) 2021-2025 Differentiable Universe Initiative
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
|
2
MANIFEST.in
Normal file
2
MANIFEST.in
Normal file
|
@ -0,0 +1,2 @@
|
|||
prune notebooks
|
||||
prune tests
|
19
README.md
19
README.md
|
@ -1,9 +1,26 @@
|
|||
# JaxPM
|
||||
[](https://github.com/DifferentiableUniverseInitiative/JaxPM/actions/workflows/tests.yml) <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
|
||||
[](https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/01-Introduction.ipynb)
|
||||
[](https://pypi.org/project/jaxpm/) [](https://github.com/DifferentiableUniverseInitiative/JaxPM/actions/workflows/tests.yml) <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
|
||||
[](#contributors-)
|
||||
<!-- ALL-CONTRIBUTORS-BADGE:END -->
|
||||
JAX-powered Cosmological Particle-Mesh N-body Solver
|
||||
|
||||
> ### Note
|
||||
> **The new JaxPM v0.1.xx** supports multi-GPU model distribution while remaining compatible with previous releases. These significant changes are still under development and testing, so please report any issues you encounter.
|
||||
> For the older but more stable version, install:
|
||||
> ```bash
|
||||
> pip install jaxpm==0.0.2
|
||||
> ```
|
||||
|
||||
## Install
|
||||
|
||||
Basic installation can be done using pip:
|
||||
```bash
|
||||
pip install jaxpm
|
||||
```
|
||||
For more advanced installation for optimized distribution on gpu clusters, please install jaxDecomp first. See instructions [here](https://github.com/DifferentiableUniverseInitiative/jaxDecomp).
|
||||
|
||||
|
||||
## Goals
|
||||
|
||||
Provide a modern infrastructure to support differentiable PM N-body simulations using JAX:
|
||||
|
|
|
@ -1,14 +0,0 @@
|
|||
#!/bin/bash
|
||||
#SBATCH -A m1727
|
||||
#SBATCH -C gpu
|
||||
#SBATCH -q debug
|
||||
#SBATCH -t 0:05:00
|
||||
#SBATCH -N 2
|
||||
#SBATCH --ntasks-per-node=4
|
||||
#SBATCH -c 32
|
||||
#SBATCH --gpus-per-task=1
|
||||
#SBATCH --gpu-bind=none
|
||||
|
||||
module load python cudnn/8.2.0 nccl/2.11.4 cudatoolkit
|
||||
export SLURM_CPU_BIND="cores"
|
||||
srun python test_pfft.py
|
164
notebooks/01-Introduction.ipynb
Normal file
164
notebooks/01-Introduction.ipynb
Normal file
File diff suppressed because one or more lines are too long
413
notebooks/02-Advanced_usage.ipynb
Normal file
413
notebooks/02-Advanced_usage.ipynb
Normal file
File diff suppressed because one or more lines are too long
697
notebooks/03-MultiGPU_PM_Halo.ipynb
Normal file
697
notebooks/03-MultiGPU_PM_Halo.ipynb
Normal file
File diff suppressed because one or more lines are too long
379
notebooks/04-MultiGPU_PM_Solvers.ipynb
Normal file
379
notebooks/04-MultiGPU_PM_Solvers.ipynb
Normal file
File diff suppressed because one or more lines are too long
300
notebooks/05-MultiHost_PM.ipynb
Normal file
300
notebooks/05-MultiHost_PM.ipynb
Normal file
File diff suppressed because one or more lines are too long
320
notebooks/06-Animating_PM_Fields.ipynb
Normal file
320
notebooks/06-Animating_PM_Fields.ipynb
Normal file
|
@ -0,0 +1,320 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# **Animating Particle Mesh density fields**\n",
|
||||
"\n",
|
||||
"In this tutorial, we will animate the density field of a particle mesh simulation. We will use the `manim` library to create the animation. \n",
|
||||
"\n",
|
||||
"The density fields are created exactly like in the notebook [**05-MultiHost_PM.ipynb**](05-MultiHost_PM.ipynb) using the same script [**05-MultiHost_PM.py**](05-MultiHost_PM.py)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To run a multi-host simulation, you first need to **allocate a job** with `salloc`. This command requests resources on an HPC cluster.\n",
|
||||
"\n",
|
||||
"just like in notebook [**05-MultiHost_PM.ipynb**]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!salloc --account=XXX@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:40:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=4 & "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**A few hours later**\n",
|
||||
"\n",
|
||||
"Use `!squeue -u $USER -o \"%i %D %b\"` to **check the JOB ID** and verify your resource allocation.\n",
|
||||
"\n",
|
||||
"In this example, we’ve been allocated **32 GPUs split across 4 nodes**.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!squeue -u $USER -o \"%i %D %b\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Unset the following environment variables, as they can cause issues when using JAX in a distributed setting:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"del os.environ['VSCODE_PROXY_URI']\n",
|
||||
"del os.environ['NO_PROXY']\n",
|
||||
"del os.environ['no_proxy']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Checking Available Compute Resources\n",
|
||||
"\n",
|
||||
"Run the following command to initialize JAX distributed computing and display the devices available for this job:\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!srun --jobid=467745 -n 32 python -c \"import jax; jax.distributed.initialize(); print(jax.devices()) if jax.process_index() == 0 else None\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Multi-Host Simulation Script with Arguments (reminder)\n",
|
||||
"\n",
|
||||
"This script is nearly identical to the single-host version, with the main addition being the call to `jax.distributed.initialize()` at the start, enabling multi-host parallelism. Here’s a breakdown of the key arguments:\n",
|
||||
"\n",
|
||||
"- **`--pdims`** (`-p`): Specifies processor grid dimensions as two integers, like `16 2` for 16 x 2 device mesh (default is `[1, jax.devices()]`).\n",
|
||||
"- **`--mesh_shape`** (`-m`): Defines the simulation mesh shape as three integers (default is `[512, 512, 512]`).\n",
|
||||
"- **`--box_size`** (`-b`): Sets the physical box size of the simulation as three floating-point values, e.g., `1000. 1000. 1000.` (default is `[500.0, 500.0, 500.0]`).\n",
|
||||
"- **`--halo_size`** (`-H`): Specifies the halo size for boundary overlap across nodes (default is `64`).\n",
|
||||
"- **`--solver`** (`-s`): Chooses the ODE solver (`leapfrog` or `dopri8`). The `leapfrog` solver uses a fixed step size, while `dopri8` is an adaptive Runge-Kutta solver with a PID controller (default is `leapfrog`).\n",
|
||||
"- **`--snapthots`** (`-st`) : Number of snapshots to save (warning, increases memory usage)\n",
|
||||
"\n",
|
||||
"### Running the Multi-Host Simulation Script\n",
|
||||
"\n",
|
||||
"To create a smooth animation, we need a series of closely spaced snapshots to capture the evolution of the density field over time. In this example, we set the number of snapshots to **10** to ensure smooth transitions in the animation.\n",
|
||||
"\n",
|
||||
"Using a larger number of GPUs helps process these snapshots efficiently, especially with a large simulation mesh or high-resolution data. This allows us to achieve both the desired snapshot frequency and the necessary simulation detail without excessive runtime.\n",
|
||||
"\n",
|
||||
"The command to run the multi-host simulation with these settings will look something like this:\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import subprocess\n",
|
||||
"\n",
|
||||
"# Define parameters as variables\n",
|
||||
"jobid = \"467745\"\n",
|
||||
"num_processes = 32\n",
|
||||
"script_name = \"05-MultiHost_PM.py\"\n",
|
||||
"mesh_shape = (1024, 1024, 1024)\n",
|
||||
"box_size = (1000., 1000., 1000.)\n",
|
||||
"halo_size = 128\n",
|
||||
"solver = \"leapfrog\"\n",
|
||||
"pdims = (16, 2)\n",
|
||||
"snapshots = 8\n",
|
||||
"\n",
|
||||
"# Build the command as a list, incorporating variables\n",
|
||||
"command = [\n",
|
||||
" \"srun\",\n",
|
||||
" f\"--jobid={jobid}\",\n",
|
||||
" \"-n\", str(num_processes),\n",
|
||||
" \"python\", script_name,\n",
|
||||
" \"--mesh_shape\", str(mesh_shape[0]), str(mesh_shape[1]), str(mesh_shape[2]),\n",
|
||||
" \"--box_size\", str(box_size[0]), str(box_size[1]), str(box_size[2]),\n",
|
||||
" \"--halo_size\", str(halo_size),\n",
|
||||
" \"-s\", solver,\n",
|
||||
" \"--pdims\", str(pdims[0]), str(pdims[1]),\n",
|
||||
" \"--snapshots\", str(snapshots)\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Execute the command as a subprocess\n",
|
||||
"subprocess.run(command)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Projecting the 3D Density Fields to 2D\n",
|
||||
"\n",
|
||||
"To visualize the 3D density fields in 2D, we need to create a projection:\n",
|
||||
"\n",
|
||||
"- **`project_to_2d` Function**: This function reduces the 3D array to 2D by summing over a portion of one axis.\n",
|
||||
" - We sum the top one-eighth of the data along the first axis to capture a slice of the density field.\n",
|
||||
"\n",
|
||||
"- **Creating 2D Projections**: Apply `project_to_2d` to each 3D field (`initial_conditions`, `lpt_displacements`, `ode_solution_0`, and `ode_solution_1`) to get 2D arrays that represent the density fields.\n",
|
||||
"\n",
|
||||
"### Applying the Magma Colormap\n",
|
||||
"\n",
|
||||
"To improve visualization, apply the \"magma\" colormap to each 2D projection:\n",
|
||||
"\n",
|
||||
"- **`apply_colormap` Function**: This function maps values in the 2D array to colors using the \"magma\" colormap.\n",
|
||||
" - First, normalize the array to the `[0, 1]` range.\n",
|
||||
" - Apply the colormap to create RGB images, which will be used for the animation.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from matplotlib import colormaps\n",
|
||||
"\n",
|
||||
"# Define a function to project the 3D field to 2D\n",
|
||||
"def project_to_2d(field):\n",
|
||||
" sum_over = field.shape[0] // 8\n",
|
||||
" slicing = [slice(None)] * field.ndim\n",
|
||||
" slicing[0] = slice(None, sum_over)\n",
|
||||
" slicing = tuple(slicing)\n",
|
||||
"\n",
|
||||
" return field[slicing].sum(axis=0)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def apply_colormap(array, cmap_name=\"magma\"):\n",
|
||||
" cmap = colormaps[cmap_name]\n",
|
||||
" normalized_array = (array - array.min()) / (array.max() - array.min())\n",
|
||||
" colored_image = cmap(normalized_array)[:, :, :3] # Drop alpha channel for RGB\n",
|
||||
" return (colored_image * 255).astype(np.uint8)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Loading and Visualizing Results\n",
|
||||
"\n",
|
||||
"After running the multi-host simulation, we load the saved results from disk:\n",
|
||||
"\n",
|
||||
"- **`initial_conditions.npy`**: Initial conditions for the simulation.\n",
|
||||
"- **`lpt_displacements.npy`**: Linear perturbation displacements.\n",
|
||||
"- **`ode_solution_*.npy`** : Solutions from the ODE solver at each snapshot.\n",
|
||||
"\n",
|
||||
"We will now project the fields to 2D maps and apply the color map\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"initial_conditions = apply_colormap(project_to_2d(np.load('fields/initial_conditions.npy')))\n",
|
||||
"lpt_displacements = apply_colormap(project_to_2d(np.load('fields/lpt_displacements.npy')))\n",
|
||||
"ode_solutions = []\n",
|
||||
"for i in range(8):\n",
|
||||
" ode_solutions.append(apply_colormap(project_to_2d(np.load(f'fields/ode_solution_{i}.npy'))))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Animating with Manim\n",
|
||||
"\n",
|
||||
"To create animations with `manim` in a Jupyter notebook, we start by configuring some settings to ensure the output displays correctly and without a background.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from manim import *\n",
|
||||
"config.media_width = \"100%\"\n",
|
||||
"config.verbosity = \"WARNING\"\n",
|
||||
"config.background_color = \"#00000000\" # Transparent background"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Defining the Animation in Manim\n",
|
||||
"\n",
|
||||
"This animation class, `FieldTransition`, smoothly transitions through the stages of the particle mesh density field evolution.\n",
|
||||
"\n",
|
||||
"- **Setup**: Each density field snapshot is loaded as an image and aligned for smooth transitions.\n",
|
||||
"- **Animation Sequence**:\n",
|
||||
" - The animation begins with a fade-in of the initial conditions.\n",
|
||||
" - It then transitions through the stages in sequence, showing each snapshot of the density field evolution with brief pauses in between.\n",
|
||||
"\n",
|
||||
"To run the animation, execute `%manim -v WARNING -qm FieldTransition` to render it in the Jupyter Notebook.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Define the animation in Manim\n",
|
||||
"class FieldTransition(Scene):\n",
|
||||
" def construct(self):\n",
|
||||
" init_conditions_img = ImageMobject(initial_conditions).scale(4)\n",
|
||||
" lpt_img = ImageMobject(lpt_displacements).scale(4)\n",
|
||||
" snapshots_imgs = [ImageMobject(sol).scale(4) for sol in ode_solutions]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" # Place the images on top of each other initially\n",
|
||||
" lpt_img.move_to(init_conditions_img)\n",
|
||||
" for img in snapshots_imgs:\n",
|
||||
" img.move_to(init_conditions_img)\n",
|
||||
"\n",
|
||||
" # Show initial field and then transform between fields\n",
|
||||
" self.play(FadeIn(init_conditions_img))\n",
|
||||
" self.wait(0.2)\n",
|
||||
" self.play(Transform(init_conditions_img, lpt_img))\n",
|
||||
" self.wait(0.2)\n",
|
||||
" self.play(Transform(lpt_img, snapshots_imgs[0]))\n",
|
||||
" self.wait(0.2)\n",
|
||||
" for img1, img2 in zip(snapshots_imgs, snapshots_imgs[1:]):\n",
|
||||
" self.play(Transform(img1, img2))\n",
|
||||
" self.wait(0.2)\n",
|
||||
"\n",
|
||||
"%manim -v WARNING -qm -o anim.gif --format=gif FieldTransition "
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -1 +0,0 @@
|
|||
c4a44973e4f11841a8c14f4d200e7e87887419aa
|
|
@ -3,28 +3,15 @@ requires = ["setuptools", "wheel", "setuptools-scm"]
|
|||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "JaxPM"
|
||||
name = "jaxpm"
|
||||
dynamic = ["version"]
|
||||
description = "A dead simple FastPM implementation in JAX"
|
||||
description = "A simple Particle-Mesh implementation in JAX"
|
||||
authors = [{ name = "JaxPM developers" }]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
license = { file = "LICENSE" }
|
||||
urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/JaxPM" }
|
||||
dependencies = ["jax_cosmo", "jax>=0.4.30", "jaxdecomp>=0.2.2"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = [
|
||||
"jax>=0.4.30",
|
||||
"numpy",
|
||||
"jax_cosmo",
|
||||
"jaxdecomp>=0.2.2",
|
||||
"pytest>=8.0.0",
|
||||
"pfft-python @ git+https://github.com/MP-Gadget/pfft-python",
|
||||
"pmesh @ git+https://github.com/MP-Gadget/pmesh",
|
||||
"fastpm @ git+https://github.com/ASKabalan/fastpm-python",
|
||||
"diffrax"
|
||||
]
|
||||
dependencies = ["jax_cosmo", "jax>=0.4.35", "jaxdecomp>=0.2.3"]
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["jaxpm"]
|
||||
|
|
5
requirements-test.txt
Normal file
5
requirements-test.txt
Normal file
|
@ -0,0 +1,5 @@
|
|||
pytest>=8.0.0
|
||||
diffrax
|
||||
pfft-python @ git+https://github.com/MP-Gadget/pfft-python
|
||||
pmesh @ git+https://github.com/MP-Gadget/pmesh
|
||||
fastpm @ git+https://github.com/ASKabalan/fastpm-python
|
87
tests/test_gradients.py
Normal file
87
tests/test_gradients.py
Normal file
|
@ -0,0 +1,87 @@
|
|||
import jax
|
||||
import pytest
|
||||
from diffrax import (BacksolveAdjoint, Dopri5, ODETerm, PIDController,
|
||||
RecursiveCheckpointAdjoint, SaveAt, diffeqsolve)
|
||||
from helpers import MSE
|
||||
from jax import numpy as jnp
|
||||
|
||||
from jaxpm.distributed import uniform_particles
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx
|
||||
from jaxpm.pm import lpt, make_diffrax_ode
|
||||
|
||||
|
||||
@pytest.mark.single_device
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
@pytest.mark.parametrize("absolute_painting", [True, False])
|
||||
@pytest.mark.parametrize("adjoint", ['DTO', 'OTD'])
|
||||
def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
|
||||
nbody_from_lpt1, nbody_from_lpt2, cosmo, order,
|
||||
absolute_painting, adjoint):
|
||||
|
||||
mesh_shape, _ = simulation_config
|
||||
cosmo._workspace = {}
|
||||
|
||||
if adjoint == 'OTD':
|
||||
pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)")
|
||||
|
||||
adjoint = RecursiveCheckpointAdjoint(
|
||||
) if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5())
|
||||
|
||||
@jax.jit
|
||||
@jax.grad
|
||||
def forward_model(initial_conditions, cosmo):
|
||||
|
||||
# Initial displacement
|
||||
if absolute_painting:
|
||||
particles = uniform_particles(mesh_shape)
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
particles,
|
||||
a=lpt_scale_factor,
|
||||
order=order)
|
||||
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
|
||||
y0 = jnp.stack([particles + dx, p])
|
||||
|
||||
else:
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
a=lpt_scale_factor,
|
||||
order=order)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
y0 = jnp.stack([dx, p])
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-7,
|
||||
atol=1e-7,
|
||||
pcoeff=0.4,
|
||||
icoeff=1,
|
||||
dcoeff=0)
|
||||
|
||||
saveat = SaveAt(t1=True)
|
||||
|
||||
solutions = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
t0=lpt_scale_factor,
|
||||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
adjoint=adjoint,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
if absolute_painting:
|
||||
final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0])
|
||||
else:
|
||||
final_field = cic_paint_dx(solutions.ys[-1, 0])
|
||||
|
||||
return MSE(final_field,
|
||||
nbody_from_lpt1 if order == 1 else nbody_from_lpt2)
|
||||
|
||||
bad_initial_conditions = initial_conditions + jax.random.normal(
|
||||
jax.random.PRNGKey(0), initial_conditions.shape) * 0.5
|
||||
best_ic = forward_model(initial_conditions, cosmo)
|
||||
bad_ic = forward_model(bad_initial_conditions, cosmo)
|
||||
|
||||
assert jnp.max(best_ic) < 1e-5
|
||||
assert jnp.max(bad_ic) > 1e-5
|
Loading…
Add table
Reference in a new issue