Compare commits

...

12 commits
v0.1.0 ... main

Author SHA1 Message Date
Wassim KABALAN
cb2a7ab17f
Fix pfft gradients ()
* update jaxdecomp version and test gradients

* Prepare for DTO tests

* format
2024-12-22 12:47:42 -05:00
Francois Lanusse
d81a2529e7 minor typo fix 2024-12-21 15:28:20 -05:00
Francois Lanusse
15f2fb1ee6 adding notice 2024-12-21 15:26:53 -05:00
Francois Lanusse
ae0f439ae4 fixing formatting of notebook 2024-12-21 13:14:42 -05:00
Francois Lanusse
ea9fbf6aa8
Update README.md 2024-12-21 13:13:37 -05:00
Francois Lanusse
ad16a0659a Created using Colab 2024-12-21 13:10:15 -05:00
Francois Lanusse
f245a1f685
Pypi upload compatible version ()
* moving test dependencies separately

* adding manifest to remove unecessary files

* updating name of project

* Fixing formatting

* Adding badge for pypi version

* Adding very simple install instructions
2024-12-21 11:47:13 -05:00
Francois Lanusse
160b86eb71
Create python-publish.yml 2024-12-21 11:38:24 -05:00
Francois Lanusse
f14f0fe68e
Changing description () 2024-12-21 10:45:12 -05:00
Francois Lanusse
70ab9f1931 remove deprecated stuff
Former-commit-id: da6fdea0556407c968dd31d44af832c997f2645c
2024-12-21 09:58:42 -05:00
Francois Lanusse
bc6e57532d remove massive file
Former-commit-id: 69f49ba2ed23c16bcea70024ea3bcff4d71b8a5b
2024-12-21 09:58:18 -05:00
Wassim KABALAN
4b4450d7d3 Remove huge notebook from history ()
* jaxdecomp proto ()

* adding example of distributed solution

* put back old functgion

* update formatting

* add halo exchange and slice pad

* apply formatting

* implement distributed optimized cic_paint

* Use new cic_paint with halo

* Fix seed for distributed normal

* Wrap interpolation function to avoid all gather

* Return normal order frequencies for single GPU

* add example

* format

* add optimised bench script

* times in ms

* add lpt2

* update benchmark and add slurm

* Visualize only final field

* Update scripts/distributed_pm.py

Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>

* Adjust pencil type for frequencies

* fix painting issue with slabs

* Shared operation in fourrier space now take inverted sharding axis for
slabs

* add assert to make pyright happy

* adjust test for hpc-plotter

* add PMWD test

* bench

* format

* added github workflow

* fix formatting from main

* Update for jaxDecomp pure JAX

* revert single halo extent change

* update for latest jaxDecomp

* remove fourrier_space in autoshmap

* make normal_field work with single controller

* format

* make distributed pm work in single controller

* merge bench_pm

* update to leapfrog

* add a strict dependency on jaxdecomp

* global mesh no longer needed

* kernels.py no longer uses global mesh

* quick fix in distributed

* pm.py no longer uses global mesh

* painting.py no longer uses global mesh

* update demo script

* quick fix in kernels

* quick fix in distributed

* update demo

* merge hugos LPT2 code

* format

* Small fix

* format

* remove duplicate get_ode_fn

* update visualizer

* update compensate CIC

* By default check_rep is false for shard_map

* remove experimental distributed code

* update PGDCorrection and neural ode to use new fft3d

* jaxDecomp pfft3d promotes to complex automatically

* remove deprecated stuff

* fix painting issue with read_cic

* use jnp interp instead of jc interp

* delete old slurms

* add notebook examples

* apply formatting

* add distributed zeros

* fix code in LPT2

* jit cic_paint

* update notebooks

* apply formating

* get local shape and zeros can be used by users

* add a user facing function to create uniform particle grid

* use jax interp instead of jax_cosmo

* use float64 for enmeshing

* Allow applying weights with relative cic paint

* Weights can be traced

* remove script folder

* update example notebooks

* delete outdated design file

* add readme for tutorials

* update readme

* fix small error

* forgot particles in multi host

* clarifying why cic_paint_dx is slower

* clarifying the halo size dependence on the box size

* ability to choose snapshots number with MultiHost script

* Adding animation notebook

* Put plotting in package

* Add finite difference laplace kernel + powerspec functions from Hugo

Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>

* Put plotting utils in package

* By default use absoulute painting with

* update code

* update notebooks

* add tests

* Upgrade setup.py to pyproject

* Format

* format tests

* update test dependencies

* add test workflow

* fix deprecated FftType in jaxpm.kernels

* Add aboucaud comments

* JAX version is 0.4.35 until Diffrax new release

* add numpy explicitly as dependency for tests

* fix install order for tests

* add numpy to be installed

* enforce no build isolation for fastpm

* pip install jaxpm test without build isolation

* bump jaxdecomp version

* revert test workflow

* remove outdated tests

---------

Co-authored-by: EiffL <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr>
Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>

* Update README.md

* Deleting Animating notebook

* Update pyproject.toml

---------

Co-authored-by: EiffL <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr>
Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>
Former-commit-id: 960cbf28ebcaad2ef0624c92e8f7f0729b75dceb
2024-12-20 08:43:43 -05:00
17 changed files with 2449 additions and 34 deletions

55
.github/workflows/python-publish.yml vendored Normal file
View file

@ -0,0 +1,55 @@
name: Upload Python Package
on:
release:
types: [published]
permissions:
contents: read
jobs:
release-build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.x"
- name: Build release distributions
run: |
# NOTE: put your own distribution build steps here.
python -m pip install build
python -m build
- name: Upload distributions
uses: actions/upload-artifact@v4
with:
name: release-dists
path: dist/
pypi-publish:
runs-on: ubuntu-latest
needs:
- release-build
permissions:
# IMPORTANT: this permission is mandatory for trusted publishing
id-token: write
environment:
name: pypi
url: https://pypi.org/p/jaxpm
steps:
- name: Retrieve release distributions
uses: actions/download-artifact@v4
with:
name: release-dists
path: dist/
- name: Publish release distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
packages-dir: dist/

View file

@ -34,7 +34,8 @@ jobs:
pip install git+https://github.com/MP-Gadget/pfft-python
pip install git+https://github.com/MP-Gadget/pmesh
pip install git+https://github.com/ASKabalan/fastpm-python --no-build-isolation
pip install .[test]
pip install -r requirements-test.txt
pip install .
- name: Run Single Device Tests
run: |

3
.gitignore vendored
View file

@ -132,3 +132,6 @@ dmypy.json
# Pyre type checker
.pyre/
# Hide version file
_version.py

View file

@ -1,6 +1,6 @@
MIT License
Copyright (c) 2021 Differentiable Universe Initiative
Copyright (c) 2021-2025 Differentiable Universe Initiative
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

2
MANIFEST.in Normal file
View file

@ -0,0 +1,2 @@
prune notebooks
prune tests

View file

@ -1,9 +1,26 @@
# JaxPM
[![Tests](https://github.com/DifferentiableUniverseInitiative/JaxPM/actions/workflows/tests.yml/badge.svg)](https://github.com/DifferentiableUniverseInitiative/JaxPM/actions/workflows/tests.yml) <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
[![Notebook](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/01-Introduction.ipynb)
[![PyPI version](https://img.shields.io/pypi/v/jaxpm)](https://pypi.org/project/jaxpm/) [![Tests](https://github.com/DifferentiableUniverseInitiative/JaxPM/actions/workflows/tests.yml/badge.svg)](https://github.com/DifferentiableUniverseInitiative/JaxPM/actions/workflows/tests.yml) <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
[![All Contributors](https://img.shields.io/badge/all_contributors-5-orange.svg?style=flat-square)](#contributors-)
<!-- ALL-CONTRIBUTORS-BADGE:END -->
JAX-powered Cosmological Particle-Mesh N-body Solver
> ### Note
> **The new JaxPM v0.1.xx** supports multi-GPU model distribution while remaining compatible with previous releases. These significant changes are still under development and testing, so please report any issues you encounter.
> For the older but more stable version, install:
> ```bash
> pip install jaxpm==0.0.2
> ```
## Install
Basic installation can be done using pip:
```bash
pip install jaxpm
```
For more advanced installation for optimized distribution on gpu clusters, please install jaxDecomp first. See instructions [here](https://github.com/DifferentiableUniverseInitiative/jaxDecomp).
## Goals
Provide a modern infrastructure to support differentiable PM N-body simulations using JAX:

View file

@ -1,14 +0,0 @@
#!/bin/bash
#SBATCH -A m1727
#SBATCH -C gpu
#SBATCH -q debug
#SBATCH -t 0:05:00
#SBATCH -N 2
#SBATCH --ntasks-per-node=4
#SBATCH -c 32
#SBATCH --gpus-per-task=1
#SBATCH --gpu-bind=none
module load python cudnn/8.2.0 nccl/2.11.4 cudatoolkit
export SLURM_CPU_BIND="cores"
srun python test_pfft.py

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,320 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# **Animating Particle Mesh density fields**\n",
"\n",
"In this tutorial, we will animate the density field of a particle mesh simulation. We will use the `manim` library to create the animation. \n",
"\n",
"The density fields are created exactly like in the notebook [**05-MultiHost_PM.ipynb**](05-MultiHost_PM.ipynb) using the same script [**05-MultiHost_PM.py**](05-MultiHost_PM.py)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To run a multi-host simulation, you first need to **allocate a job** with `salloc`. This command requests resources on an HPC cluster.\n",
"\n",
"just like in notebook [**05-MultiHost_PM.ipynb**]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!salloc --account=XXX@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:40:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=4 & "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**A few hours later**\n",
"\n",
"Use `!squeue -u $USER -o \"%i %D %b\"` to **check the JOB ID** and verify your resource allocation.\n",
"\n",
"In this example, weve been allocated **32 GPUs split across 4 nodes**.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!squeue -u $USER -o \"%i %D %b\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Unset the following environment variables, as they can cause issues when using JAX in a distributed setting:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"del os.environ['VSCODE_PROXY_URI']\n",
"del os.environ['NO_PROXY']\n",
"del os.environ['no_proxy']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Checking Available Compute Resources\n",
"\n",
"Run the following command to initialize JAX distributed computing and display the devices available for this job:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!srun --jobid=467745 -n 32 python -c \"import jax; jax.distributed.initialize(); print(jax.devices()) if jax.process_index() == 0 else None\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Multi-Host Simulation Script with Arguments (reminder)\n",
"\n",
"This script is nearly identical to the single-host version, with the main addition being the call to `jax.distributed.initialize()` at the start, enabling multi-host parallelism. Heres a breakdown of the key arguments:\n",
"\n",
"- **`--pdims`** (`-p`): Specifies processor grid dimensions as two integers, like `16 2` for 16 x 2 device mesh (default is `[1, jax.devices()]`).\n",
"- **`--mesh_shape`** (`-m`): Defines the simulation mesh shape as three integers (default is `[512, 512, 512]`).\n",
"- **`--box_size`** (`-b`): Sets the physical box size of the simulation as three floating-point values, e.g., `1000. 1000. 1000.` (default is `[500.0, 500.0, 500.0]`).\n",
"- **`--halo_size`** (`-H`): Specifies the halo size for boundary overlap across nodes (default is `64`).\n",
"- **`--solver`** (`-s`): Chooses the ODE solver (`leapfrog` or `dopri8`). The `leapfrog` solver uses a fixed step size, while `dopri8` is an adaptive Runge-Kutta solver with a PID controller (default is `leapfrog`).\n",
"- **`--snapthots`** (`-st`) : Number of snapshots to save (warning, increases memory usage)\n",
"\n",
"### Running the Multi-Host Simulation Script\n",
"\n",
"To create a smooth animation, we need a series of closely spaced snapshots to capture the evolution of the density field over time. In this example, we set the number of snapshots to **10** to ensure smooth transitions in the animation.\n",
"\n",
"Using a larger number of GPUs helps process these snapshots efficiently, especially with a large simulation mesh or high-resolution data. This allows us to achieve both the desired snapshot frequency and the necessary simulation detail without excessive runtime.\n",
"\n",
"The command to run the multi-host simulation with these settings will look something like this:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"\n",
"# Define parameters as variables\n",
"jobid = \"467745\"\n",
"num_processes = 32\n",
"script_name = \"05-MultiHost_PM.py\"\n",
"mesh_shape = (1024, 1024, 1024)\n",
"box_size = (1000., 1000., 1000.)\n",
"halo_size = 128\n",
"solver = \"leapfrog\"\n",
"pdims = (16, 2)\n",
"snapshots = 8\n",
"\n",
"# Build the command as a list, incorporating variables\n",
"command = [\n",
" \"srun\",\n",
" f\"--jobid={jobid}\",\n",
" \"-n\", str(num_processes),\n",
" \"python\", script_name,\n",
" \"--mesh_shape\", str(mesh_shape[0]), str(mesh_shape[1]), str(mesh_shape[2]),\n",
" \"--box_size\", str(box_size[0]), str(box_size[1]), str(box_size[2]),\n",
" \"--halo_size\", str(halo_size),\n",
" \"-s\", solver,\n",
" \"--pdims\", str(pdims[0]), str(pdims[1]),\n",
" \"--snapshots\", str(snapshots)\n",
"]\n",
"\n",
"# Execute the command as a subprocess\n",
"subprocess.run(command)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Projecting the 3D Density Fields to 2D\n",
"\n",
"To visualize the 3D density fields in 2D, we need to create a projection:\n",
"\n",
"- **`project_to_2d` Function**: This function reduces the 3D array to 2D by summing over a portion of one axis.\n",
" - We sum the top one-eighth of the data along the first axis to capture a slice of the density field.\n",
"\n",
"- **Creating 2D Projections**: Apply `project_to_2d` to each 3D field (`initial_conditions`, `lpt_displacements`, `ode_solution_0`, and `ode_solution_1`) to get 2D arrays that represent the density fields.\n",
"\n",
"### Applying the Magma Colormap\n",
"\n",
"To improve visualization, apply the \"magma\" colormap to each 2D projection:\n",
"\n",
"- **`apply_colormap` Function**: This function maps values in the 2D array to colors using the \"magma\" colormap.\n",
" - First, normalize the array to the `[0, 1]` range.\n",
" - Apply the colormap to create RGB images, which will be used for the animation.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import colormaps\n",
"\n",
"# Define a function to project the 3D field to 2D\n",
"def project_to_2d(field):\n",
" sum_over = field.shape[0] // 8\n",
" slicing = [slice(None)] * field.ndim\n",
" slicing[0] = slice(None, sum_over)\n",
" slicing = tuple(slicing)\n",
"\n",
" return field[slicing].sum(axis=0)\n",
"\n",
"\n",
"def apply_colormap(array, cmap_name=\"magma\"):\n",
" cmap = colormaps[cmap_name]\n",
" normalized_array = (array - array.min()) / (array.max() - array.min())\n",
" colored_image = cmap(normalized_array)[:, :, :3] # Drop alpha channel for RGB\n",
" return (colored_image * 255).astype(np.uint8)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Loading and Visualizing Results\n",
"\n",
"After running the multi-host simulation, we load the saved results from disk:\n",
"\n",
"- **`initial_conditions.npy`**: Initial conditions for the simulation.\n",
"- **`lpt_displacements.npy`**: Linear perturbation displacements.\n",
"- **`ode_solution_*.npy`** : Solutions from the ODE solver at each snapshot.\n",
"\n",
"We will now project the fields to 2D maps and apply the color map\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"initial_conditions = apply_colormap(project_to_2d(np.load('fields/initial_conditions.npy')))\n",
"lpt_displacements = apply_colormap(project_to_2d(np.load('fields/lpt_displacements.npy')))\n",
"ode_solutions = []\n",
"for i in range(8):\n",
" ode_solutions.append(apply_colormap(project_to_2d(np.load(f'fields/ode_solution_{i}.npy'))))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Animating with Manim\n",
"\n",
"To create animations with `manim` in a Jupyter notebook, we start by configuring some settings to ensure the output displays correctly and without a background.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from manim import *\n",
"config.media_width = \"100%\"\n",
"config.verbosity = \"WARNING\"\n",
"config.background_color = \"#00000000\" # Transparent background"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Defining the Animation in Manim\n",
"\n",
"This animation class, `FieldTransition`, smoothly transitions through the stages of the particle mesh density field evolution.\n",
"\n",
"- **Setup**: Each density field snapshot is loaded as an image and aligned for smooth transitions.\n",
"- **Animation Sequence**:\n",
" - The animation begins with a fade-in of the initial conditions.\n",
" - It then transitions through the stages in sequence, showing each snapshot of the density field evolution with brief pauses in between.\n",
"\n",
"To run the animation, execute `%manim -v WARNING -qm FieldTransition` to render it in the Jupyter Notebook.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define the animation in Manim\n",
"class FieldTransition(Scene):\n",
" def construct(self):\n",
" init_conditions_img = ImageMobject(initial_conditions).scale(4)\n",
" lpt_img = ImageMobject(lpt_displacements).scale(4)\n",
" snapshots_imgs = [ImageMobject(sol).scale(4) for sol in ode_solutions]\n",
"\n",
"\n",
" # Place the images on top of each other initially\n",
" lpt_img.move_to(init_conditions_img)\n",
" for img in snapshots_imgs:\n",
" img.move_to(init_conditions_img)\n",
"\n",
" # Show initial field and then transform between fields\n",
" self.play(FadeIn(init_conditions_img))\n",
" self.wait(0.2)\n",
" self.play(Transform(init_conditions_img, lpt_img))\n",
" self.wait(0.2)\n",
" self.play(Transform(lpt_img, snapshots_imgs[0]))\n",
" self.wait(0.2)\n",
" for img1, img2 in zip(snapshots_imgs, snapshots_imgs[1:]):\n",
" self.play(Transform(img1, img2))\n",
" self.wait(0.2)\n",
"\n",
"%manim -v WARNING -qm -o anim.gif --format=gif FieldTransition "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View file

@ -1 +0,0 @@
c4a44973e4f11841a8c14f4d200e7e87887419aa

View file

@ -3,28 +3,15 @@ requires = ["setuptools", "wheel", "setuptools-scm"]
build-backend = "setuptools.build_meta"
[project]
name = "JaxPM"
name = "jaxpm"
dynamic = ["version"]
description = "A dead simple FastPM implementation in JAX"
description = "A simple Particle-Mesh implementation in JAX"
authors = [{ name = "JaxPM developers" }]
readme = "README.md"
requires-python = ">=3.9"
license = { file = "LICENSE" }
urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/JaxPM" }
dependencies = ["jax_cosmo", "jax>=0.4.30", "jaxdecomp>=0.2.2"]
[project.optional-dependencies]
test = [
"jax>=0.4.30",
"numpy",
"jax_cosmo",
"jaxdecomp>=0.2.2",
"pytest>=8.0.0",
"pfft-python @ git+https://github.com/MP-Gadget/pfft-python",
"pmesh @ git+https://github.com/MP-Gadget/pmesh",
"fastpm @ git+https://github.com/ASKabalan/fastpm-python",
"diffrax"
]
dependencies = ["jax_cosmo", "jax>=0.4.35", "jaxdecomp>=0.2.3"]
[tool.setuptools]
packages = ["jaxpm"]

5
requirements-test.txt Normal file
View file

@ -0,0 +1,5 @@
pytest>=8.0.0
diffrax
pfft-python @ git+https://github.com/MP-Gadget/pfft-python
pmesh @ git+https://github.com/MP-Gadget/pmesh
fastpm @ git+https://github.com/ASKabalan/fastpm-python

87
tests/test_gradients.py Normal file
View file

@ -0,0 +1,87 @@
import jax
import pytest
from diffrax import (BacksolveAdjoint, Dopri5, ODETerm, PIDController,
RecursiveCheckpointAdjoint, SaveAt, diffeqsolve)
from helpers import MSE
from jax import numpy as jnp
from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_dx
from jaxpm.pm import lpt, make_diffrax_ode
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("absolute_painting", [True, False])
@pytest.mark.parametrize("adjoint", ['DTO', 'OTD'])
def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
nbody_from_lpt1, nbody_from_lpt2, cosmo, order,
absolute_painting, adjoint):
mesh_shape, _ = simulation_config
cosmo._workspace = {}
if adjoint == 'OTD':
pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)")
adjoint = RecursiveCheckpointAdjoint(
) if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5())
@jax.jit
@jax.grad
def forward_model(initial_conditions, cosmo):
# Initial displacement
if absolute_painting:
particles = uniform_particles(mesh_shape)
dx, p, _ = lpt(cosmo,
initial_conditions,
particles,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
y0 = jnp.stack([particles + dx, p])
else:
dx, p, _ = lpt(cosmo,
initial_conditions,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
y0 = jnp.stack([dx, p])
solver = Dopri5()
controller = PIDController(rtol=1e-7,
atol=1e-7,
pcoeff=0.4,
icoeff=1,
dcoeff=0)
saveat = SaveAt(t1=True)
solutions = diffeqsolve(ode_fn,
solver,
t0=lpt_scale_factor,
t1=1.0,
dt0=None,
y0=y0,
adjoint=adjoint,
stepsize_controller=controller,
saveat=saveat)
if absolute_painting:
final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0])
else:
final_field = cic_paint_dx(solutions.ys[-1, 0])
return MSE(final_field,
nbody_from_lpt1 if order == 1 else nbody_from_lpt2)
bad_initial_conditions = initial_conditions + jax.random.normal(
jax.random.PRNGKey(0), initial_conditions.shape) * 0.5
best_ic = forward_model(initial_conditions, cosmo)
bad_ic = forward_model(bad_initial_conditions, cosmo)
assert jnp.max(best_ic) < 1e-5
assert jnp.max(bad_ic) > 1e-5