mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-07-07 04:21:12 +00:00
Compare commits
14 commits
Author | SHA1 | Date | |
---|---|---|---|
|
7d76573701 | ||
|
6693e5c725 | ||
|
cb2a7ab17f | ||
|
d81a2529e7 | ||
|
15f2fb1ee6 | ||
|
ae0f439ae4 | ||
|
ea9fbf6aa8 | ||
|
ad16a0659a | ||
|
f245a1f685 | ||
|
160b86eb71 | ||
|
f14f0fe68e | ||
|
70ab9f1931 | ||
|
bc6e57532d | ||
|
4b4450d7d3 |
29 changed files with 2911 additions and 118 deletions
34
.github/workflows/formatting.yml
vendored
34
.github/workflows/formatting.yml
vendored
|
@ -7,15 +7,37 @@ on:
|
|||
branches: [ "main" ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
formatting:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v3
|
||||
- name: Checkout Source
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Cache pip dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-formatting-pip-${{ hashFiles('.pre-commit-config.yaml') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-formatting-pip-
|
||||
|
||||
- name: Cache pre-commit
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pre-commit
|
||||
key: ${{ runner.os }}-pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pre-commit-
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip isort
|
||||
python -m pip install pre-commit
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install pre-commit isort
|
||||
|
||||
- name: Run pre-commit
|
||||
run: python -m pre_commit run --all-files
|
||||
|
|
55
.github/workflows/python-publish.yml
vendored
Normal file
55
.github/workflows/python-publish.yml
vendored
Normal file
|
@ -0,0 +1,55 @@
|
|||
name: Upload Python Package
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
release-build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.x"
|
||||
|
||||
- name: Build release distributions
|
||||
run: |
|
||||
# NOTE: put your own distribution build steps here.
|
||||
python -m pip install build
|
||||
python -m build
|
||||
|
||||
- name: Upload distributions
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-dists
|
||||
path: dist/
|
||||
|
||||
pypi-publish:
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- release-build
|
||||
permissions:
|
||||
# IMPORTANT: this permission is mandatory for trusted publishing
|
||||
id-token: write
|
||||
|
||||
environment:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/jaxpm
|
||||
|
||||
steps:
|
||||
- name: Retrieve release distributions
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: release-dists
|
||||
path: dist/
|
||||
|
||||
- name: Publish release distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
packages-dir: dist/
|
53
.github/workflows/tests.yml
vendored
53
.github/workflows/tests.yml
vendored
|
@ -10,36 +10,63 @@ on:
|
|||
|
||||
jobs:
|
||||
run_tests:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.10" , "3.11" , "3.12"]
|
||||
python-version: ["3.10", "3.11", "3.12"]
|
||||
|
||||
steps:
|
||||
- name: Checkout Source
|
||||
uses: actions/checkout@v2.3.1
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
- name: Cache pip dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements-test.txt', '**/pyproject.toml') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pip-${{ matrix.python-version }}-
|
||||
${{ runner.os }}-pip-
|
||||
|
||||
- name: Cache system dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: /var/cache/apt
|
||||
key: ${{ runner.os }}-apt-${{ hashFiles('.github/workflows/tests.yml') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-apt-
|
||||
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libopenmpi-dev
|
||||
python -m pip install --upgrade pip
|
||||
pip install jax==0.4.35
|
||||
pip install numpy setuptools cython wheel
|
||||
pip install git+https://github.com/MP-Gadget/pfft-python
|
||||
pip install git+https://github.com/MP-Gadget/pmesh
|
||||
pip install git+https://github.com/ASKabalan/fastpm-python --no-build-isolation
|
||||
pip install .[test]
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip setuptools wheel
|
||||
# Install JAX first as it's a key dependency
|
||||
pip install jax
|
||||
# Install build dependencies
|
||||
pip install setuptools cython mpi4py
|
||||
# Install test requirements with no-build-isolation for faster builds
|
||||
pip install -r requirements-test.txt --no-build-isolation
|
||||
# Install additional test dependencies
|
||||
pip install pytest diffrax
|
||||
# Install package in development mode
|
||||
pip install -e .
|
||||
echo "numpy version installed:"
|
||||
python -c "import numpy; print(numpy.__version__)"
|
||||
|
||||
- name: Run Single Device Tests
|
||||
run: |
|
||||
cd tests
|
||||
pytest -v -m "not distributed"
|
||||
|
||||
- name: Run Distributed tests
|
||||
run: |
|
||||
pytest -v -m distributed
|
||||
pytest -v tests/test_distributed_pm.py
|
||||
|
|
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -132,3 +132,6 @@ dmypy.json
|
|||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# Hide version file
|
||||
_version.py
|
||||
|
|
2
LICENSE
2
LICENSE
|
@ -1,6 +1,6 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2021 Differentiable Universe Initiative
|
||||
Copyright (c) 2021-2025 Differentiable Universe Initiative
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
|
2
MANIFEST.in
Normal file
2
MANIFEST.in
Normal file
|
@ -0,0 +1,2 @@
|
|||
prune notebooks
|
||||
prune tests
|
19
README.md
19
README.md
|
@ -1,9 +1,26 @@
|
|||
# JaxPM
|
||||
[](https://github.com/DifferentiableUniverseInitiative/JaxPM/actions/workflows/tests.yml) <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
|
||||
[](https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/01-Introduction.ipynb)
|
||||
[](https://pypi.org/project/jaxpm/) [](https://github.com/DifferentiableUniverseInitiative/JaxPM/actions/workflows/tests.yml) <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
|
||||
[](#contributors-)
|
||||
<!-- ALL-CONTRIBUTORS-BADGE:END -->
|
||||
JAX-powered Cosmological Particle-Mesh N-body Solver
|
||||
|
||||
> ### Note
|
||||
> **The new JaxPM v0.1.xx** supports multi-GPU model distribution while remaining compatible with previous releases. These significant changes are still under development and testing, so please report any issues you encounter.
|
||||
> For the older but more stable version, install:
|
||||
> ```bash
|
||||
> pip install jaxpm==0.0.2
|
||||
> ```
|
||||
|
||||
## Install
|
||||
|
||||
Basic installation can be done using pip:
|
||||
```bash
|
||||
pip install jaxpm
|
||||
```
|
||||
For more advanced installation for optimized distribution on gpu clusters, please install jaxDecomp first. See instructions [here](https://github.com/DifferentiableUniverseInitiative/jaxDecomp).
|
||||
|
||||
|
||||
## Goals
|
||||
|
||||
Provide a modern infrastructure to support differentiable PM N-body simulations using JAX:
|
||||
|
|
|
@ -1,14 +0,0 @@
|
|||
#!/bin/bash
|
||||
#SBATCH -A m1727
|
||||
#SBATCH -C gpu
|
||||
#SBATCH -q debug
|
||||
#SBATCH -t 0:05:00
|
||||
#SBATCH -N 2
|
||||
#SBATCH --ntasks-per-node=4
|
||||
#SBATCH -c 32
|
||||
#SBATCH --gpus-per-task=1
|
||||
#SBATCH --gpu-bind=none
|
||||
|
||||
module load python cudnn/8.2.0 nccl/2.11.4 cudatoolkit
|
||||
export SLURM_CPU_BIND="cores"
|
||||
srun python test_pfft.py
|
|
@ -166,11 +166,11 @@ def uniform_particles(mesh_shape, sharding=None):
|
|||
axis=-1)
|
||||
|
||||
|
||||
def normal_field(mesh_shape, seed, sharding=None):
|
||||
def normal_field(seed, shape, sharding=None, dtype=float):
|
||||
"""Generate a Gaussian random field with the given power spectrum."""
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
if gpu_mesh is not None and not (gpu_mesh.empty):
|
||||
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
||||
local_mesh_shape = get_local_shape(shape, sharding)
|
||||
|
||||
size = jax.device_count()
|
||||
# rank = jax.process_index()
|
||||
|
@ -190,9 +190,9 @@ def normal_field(mesh_shape, seed, sharding=None):
|
|||
return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)
|
||||
|
||||
return shard_map(
|
||||
partial(normal, shape=local_mesh_shape, dtype='float32'),
|
||||
partial(normal, shape=local_mesh_shape, dtype=dtype),
|
||||
mesh=gpu_mesh,
|
||||
in_specs=P(None),
|
||||
out_specs=spec)(keys) # yapf: disable
|
||||
else:
|
||||
return jax.random.normal(shape=mesh_shape, key=seed)
|
||||
return jax.random.normal(shape=shape, key=seed, dtype=dtype)
|
||||
|
|
|
@ -26,7 +26,7 @@ def E(cosmo, a):
|
|||
where :math:`f(a)` is the Dark Energy evolution parameter computed
|
||||
by :py:meth:`.f_de`.
|
||||
"""
|
||||
return np.power(Esqr(cosmo, a), 0.5)
|
||||
return np.sqrt(Esqr(cosmo, a))
|
||||
|
||||
|
||||
def df_de(cosmo, a, epsilon=1e-5):
|
||||
|
|
|
@ -12,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk
|
|||
from jaxpm.painting_utils import gather, scatter
|
||||
|
||||
|
||||
def _cic_paint_impl(grid_mesh, positions, weight=None):
|
||||
def _cic_paint_impl(grid_mesh, positions, weight=1.):
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
displacement field: [nx, ny, nz, 3]
|
||||
|
@ -27,12 +27,10 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
|
|||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
if weight is not None:
|
||||
if jnp.isscalar(weight):
|
||||
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
|
||||
else:
|
||||
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
|
||||
kernel)
|
||||
if jnp.isscalar(weight):
|
||||
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
|
||||
else:
|
||||
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), kernel)
|
||||
|
||||
neighboor_coords = jnp.mod(
|
||||
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
||||
|
@ -48,7 +46,13 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
|
|||
|
||||
|
||||
@partial(jax.jit, static_argnums=(3, 4))
|
||||
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
|
||||
def cic_paint(grid_mesh, positions, weight=1., halo_size=0, sharding=None):
|
||||
|
||||
if sharding is not None:
|
||||
print("""
|
||||
WARNING : absolute painting is not recommended in multi-device mode.
|
||||
Please use relative painting instead.
|
||||
""")
|
||||
|
||||
positions = positions.reshape((*grid_mesh.shape, 3))
|
||||
|
||||
|
@ -57,9 +61,11 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
|
|||
|
||||
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
||||
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
||||
weight_spec = P() if jnp.isscalar(weight) else spec
|
||||
|
||||
grid_mesh = autoshmap(_cic_paint_impl,
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(spec, spec, P()),
|
||||
in_specs=(spec, spec, weight_spec),
|
||||
out_specs=spec)(grid_mesh, positions, weight)
|
||||
grid_mesh = halo_exchange(grid_mesh,
|
||||
halo_extents=halo_extents,
|
||||
|
@ -128,6 +134,7 @@ def cic_paint_2d(mesh, positions, weight):
|
|||
positions: [npart, 2]
|
||||
weight: [npart]
|
||||
"""
|
||||
positions = positions.reshape([-1, 2])
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
|
||||
|
@ -136,7 +143,7 @@ def cic_paint_2d(mesh, positions, weight):
|
|||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1]
|
||||
if weight is not None:
|
||||
kernel = kernel * weight[..., jnp.newaxis]
|
||||
kernel = kernel * weight.reshape(*positions.shape[:-1])
|
||||
|
||||
neighboor_coords = jnp.mod(
|
||||
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
|
||||
|
@ -151,13 +158,16 @@ def cic_paint_2d(mesh, positions, weight):
|
|||
return mesh
|
||||
|
||||
|
||||
def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
|
||||
def _cic_paint_dx_impl(displacements,
|
||||
weight=1.,
|
||||
halo_size=0,
|
||||
chunk_size=2**24):
|
||||
|
||||
halo_x, _ = halo_size[0]
|
||||
halo_y, _ = halo_size[1]
|
||||
|
||||
original_shape = displacements.shape
|
||||
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
|
||||
particle_mesh = jnp.zeros(original_shape[:-1], dtype=displacements.dtype)
|
||||
if not jnp.isscalar(weight):
|
||||
if weight.shape != original_shape[:-1]:
|
||||
raise ValueError("Weight shape must match particle shape")
|
||||
|
@ -175,7 +185,7 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
|
|||
return scatter(pmid.reshape([-1, 3]),
|
||||
displacements.reshape([-1, 3]),
|
||||
particle_mesh,
|
||||
chunk_size=2**24,
|
||||
chunk_size=chunk_size,
|
||||
val=weight)
|
||||
|
||||
|
||||
|
@ -190,13 +200,13 @@ def cic_paint_dx(displacements,
|
|||
|
||||
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
||||
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
||||
weight_spec = P() if jnp.isscalar(weight) else spec
|
||||
grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
|
||||
halo_size=halo_size,
|
||||
weight=weight,
|
||||
chunk_size=chunk_size),
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=spec,
|
||||
out_specs=spec)(displacements)
|
||||
in_specs=(spec, weight_spec),
|
||||
out_specs=spec)(displacements, weight)
|
||||
|
||||
grid_mesh = halo_exchange(grid_mesh,
|
||||
halo_extents=halo_extents,
|
||||
|
@ -230,6 +240,12 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
|
|||
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
|
||||
# Halo size is halved for the read operation
|
||||
# We only need to read the density field
|
||||
# while in the painting operation we need to exchange and reduce the halo
|
||||
# We chose to do that since it is much easier to write a custom jvp rule for exchange
|
||||
# while it is a bit harder if there is a reduction involved
|
||||
halo_size = jax.tree.map(lambda x: x // 2, halo_size)
|
||||
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
|
||||
grid_mesh = halo_exchange(grid_mesh,
|
||||
halo_extents=halo_extents,
|
||||
|
|
|
@ -30,17 +30,14 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
|
|||
"""Multilinear enmeshing."""
|
||||
base_indices = jnp.asarray(base_indices)
|
||||
displacements = jnp.asarray(displacements)
|
||||
with jax.experimental.enable_x64():
|
||||
cell_size = jnp.float64(
|
||||
cell_size) if new_cell_size is not None else jnp.array(
|
||||
cell_size, dtype=displacements.dtype)
|
||||
if base_shape is not None:
|
||||
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
|
||||
offset = jnp.float64(offset)
|
||||
if new_cell_size is not None:
|
||||
new_cell_size = jnp.float64(new_cell_size)
|
||||
if new_shape is not None:
|
||||
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
|
||||
cell_size = jnp.array(cell_size, dtype=displacements.dtype)
|
||||
if base_shape is not None:
|
||||
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
|
||||
offset = offset.astype(base_indices.dtype)
|
||||
if new_cell_size is not None:
|
||||
new_cell_size = jnp.array(new_cell_size, dtype=displacements.dtype)
|
||||
if new_shape is not None:
|
||||
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
|
||||
|
||||
spatial_dim = base_indices.shape[1]
|
||||
neighbor_offsets = (
|
||||
|
|
|
@ -131,7 +131,7 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
|
|||
Generate initial conditions.
|
||||
"""
|
||||
# Initialize a random field with one slice on each gpu
|
||||
field = normal_field(mesh_shape, seed=seed, sharding=sharding)
|
||||
field = normal_field(seed=seed, shape=mesh_shape, sharding=sharding)
|
||||
field = fft3d(field)
|
||||
kvec = fftk(field)
|
||||
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
|
||||
|
@ -139,7 +139,7 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
|
|||
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
|
||||
box_size[0] * box_size[1] * box_size[2])
|
||||
|
||||
field = field * (pkmesh)**0.5
|
||||
field = field * jnp.sqrt(pkmesh)
|
||||
field = ifft3d(field)
|
||||
return field
|
||||
|
||||
|
@ -172,8 +172,7 @@ def make_ode_fn(mesh_shape,
|
|||
return nbody_ode
|
||||
|
||||
|
||||
def make_diffrax_ode(cosmo,
|
||||
mesh_shape,
|
||||
def make_diffrax_ode(mesh_shape,
|
||||
paint_absolute_pos=True,
|
||||
halo_size=0,
|
||||
sharding=None):
|
||||
|
@ -183,6 +182,7 @@ def make_diffrax_ode(cosmo,
|
|||
state is a tuple (position, velocities)
|
||||
"""
|
||||
pos, vel = state
|
||||
cosmo = args
|
||||
|
||||
forces = pm_forces(pos,
|
||||
mesh_shape=mesh_shape,
|
||||
|
|
|
@ -52,7 +52,7 @@ def _initialize_pk(mesh_shape, box_shape, kedges, los):
|
|||
kshapes = np.eye(len(mesh_shape), dtype=np.int32) * -2 + 1
|
||||
kvec = [(2 * np.pi * m / l) * np.fft.fftfreq(m).reshape(kshape)
|
||||
for m, l, kshape in zip(mesh_shape, box_shape, kshapes)]
|
||||
kmesh = sum(ki**2 for ki in kvec)**0.5
|
||||
kmesh = jnp.sqrt(sum(ki**2 for ki in kvec))
|
||||
|
||||
dig = np.digitize(kmesh.reshape(-1), kedges)
|
||||
kcount = np.bincount(dig, minlength=len(kedges) + 1)
|
||||
|
|
164
notebooks/01-Introduction.ipynb
Normal file
164
notebooks/01-Introduction.ipynb
Normal file
File diff suppressed because one or more lines are too long
428
notebooks/02-Advanced_usage.ipynb
Normal file
428
notebooks/02-Advanced_usage.ipynb
Normal file
File diff suppressed because one or more lines are too long
681
notebooks/03-MultiGPU_PM_Halo.ipynb
Normal file
681
notebooks/03-MultiGPU_PM_Halo.ipynb
Normal file
File diff suppressed because one or more lines are too long
378
notebooks/04-MultiGPU_PM_Solvers.ipynb
Normal file
378
notebooks/04-MultiGPU_PM_Solvers.ipynb
Normal file
File diff suppressed because one or more lines are too long
300
notebooks/05-MultiHost_PM.ipynb
Normal file
300
notebooks/05-MultiHost_PM.ipynb
Normal file
File diff suppressed because one or more lines are too long
|
@ -17,9 +17,8 @@ import jax_cosmo as jc
|
|||
import numpy as np
|
||||
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
||||
PIDController, SaveAt, diffeqsolve)
|
||||
from jax.experimental.mesh_utils import create_device_mesh
|
||||
from jax.experimental.multihost_utils import process_allgather
|
||||
from jax.sharding import Mesh, NamedSharding
|
||||
from jax.sharding import NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jaxpm.kernels import interpolate_power_spectrum
|
||||
|
@ -78,7 +77,7 @@ def parse_arguments():
|
|||
|
||||
def create_mesh_and_sharding(pdims):
|
||||
devices = create_device_mesh(pdims)
|
||||
mesh = Mesh(devices, axis_names=('x', 'y'))
|
||||
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
return mesh, sharding
|
||||
|
||||
|
@ -106,7 +105,10 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
|||
sharding=sharding)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
make_diffrax_ode(mesh_shape,
|
||||
paint_absolute_pos=False,
|
||||
sharding=sharding,
|
||||
halo_size=halo_size))
|
||||
|
||||
# Choose solver
|
||||
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
|
||||
|
|
320
notebooks/06-Animating_PM_Fields.ipynb
Normal file
320
notebooks/06-Animating_PM_Fields.ipynb
Normal file
|
@ -0,0 +1,320 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# **Animating Particle Mesh density fields**\n",
|
||||
"\n",
|
||||
"In this tutorial, we will animate the density field of a particle mesh simulation. We will use the `manim` library to create the animation. \n",
|
||||
"\n",
|
||||
"The density fields are created exactly like in the notebook [**05-MultiHost_PM.ipynb**](05-MultiHost_PM.ipynb) using the same script [**05-MultiHost_PM.py**](05-MultiHost_PM.py)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To run a multi-host simulation, you first need to **allocate a job** with `salloc`. This command requests resources on an HPC cluster.\n",
|
||||
"\n",
|
||||
"just like in notebook [**05-MultiHost_PM.ipynb**]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!salloc --account=XXX@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:40:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=4 & "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**A few hours later**\n",
|
||||
"\n",
|
||||
"Use `!squeue -u $USER -o \"%i %D %b\"` to **check the JOB ID** and verify your resource allocation.\n",
|
||||
"\n",
|
||||
"In this example, we’ve been allocated **32 GPUs split across 4 nodes**.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!squeue -u $USER -o \"%i %D %b\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Unset the following environment variables, as they can cause issues when using JAX in a distributed setting:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"del os.environ['VSCODE_PROXY_URI']\n",
|
||||
"del os.environ['NO_PROXY']\n",
|
||||
"del os.environ['no_proxy']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Checking Available Compute Resources\n",
|
||||
"\n",
|
||||
"Run the following command to initialize JAX distributed computing and display the devices available for this job:\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!srun --jobid=467745 -n 32 python -c \"import jax; jax.distributed.initialize(); print(jax.devices()) if jax.process_index() == 0 else None\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Multi-Host Simulation Script with Arguments (reminder)\n",
|
||||
"\n",
|
||||
"This script is nearly identical to the single-host version, with the main addition being the call to `jax.distributed.initialize()` at the start, enabling multi-host parallelism. Here’s a breakdown of the key arguments:\n",
|
||||
"\n",
|
||||
"- **`--pdims`** (`-p`): Specifies processor grid dimensions as two integers, like `16 2` for 16 x 2 device mesh (default is `[1, jax.devices()]`).\n",
|
||||
"- **`--mesh_shape`** (`-m`): Defines the simulation mesh shape as three integers (default is `[512, 512, 512]`).\n",
|
||||
"- **`--box_size`** (`-b`): Sets the physical box size of the simulation as three floating-point values, e.g., `1000. 1000. 1000.` (default is `[500.0, 500.0, 500.0]`).\n",
|
||||
"- **`--halo_size`** (`-H`): Specifies the halo size for boundary overlap across nodes (default is `64`).\n",
|
||||
"- **`--solver`** (`-s`): Chooses the ODE solver (`leapfrog` or `dopri8`). The `leapfrog` solver uses a fixed step size, while `dopri8` is an adaptive Runge-Kutta solver with a PID controller (default is `leapfrog`).\n",
|
||||
"- **`--snapthots`** (`-st`) : Number of snapshots to save (warning, increases memory usage)\n",
|
||||
"\n",
|
||||
"### Running the Multi-Host Simulation Script\n",
|
||||
"\n",
|
||||
"To create a smooth animation, we need a series of closely spaced snapshots to capture the evolution of the density field over time. In this example, we set the number of snapshots to **10** to ensure smooth transitions in the animation.\n",
|
||||
"\n",
|
||||
"Using a larger number of GPUs helps process these snapshots efficiently, especially with a large simulation mesh or high-resolution data. This allows us to achieve both the desired snapshot frequency and the necessary simulation detail without excessive runtime.\n",
|
||||
"\n",
|
||||
"The command to run the multi-host simulation with these settings will look something like this:\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import subprocess\n",
|
||||
"\n",
|
||||
"# Define parameters as variables\n",
|
||||
"jobid = \"467745\"\n",
|
||||
"num_processes = 32\n",
|
||||
"script_name = \"05-MultiHost_PM.py\"\n",
|
||||
"mesh_shape = (1024, 1024, 1024)\n",
|
||||
"box_size = (1000., 1000., 1000.)\n",
|
||||
"halo_size = 128\n",
|
||||
"solver = \"leapfrog\"\n",
|
||||
"pdims = (16, 2)\n",
|
||||
"snapshots = 8\n",
|
||||
"\n",
|
||||
"# Build the command as a list, incorporating variables\n",
|
||||
"command = [\n",
|
||||
" \"srun\",\n",
|
||||
" f\"--jobid={jobid}\",\n",
|
||||
" \"-n\", str(num_processes),\n",
|
||||
" \"python\", script_name,\n",
|
||||
" \"--mesh_shape\", str(mesh_shape[0]), str(mesh_shape[1]), str(mesh_shape[2]),\n",
|
||||
" \"--box_size\", str(box_size[0]), str(box_size[1]), str(box_size[2]),\n",
|
||||
" \"--halo_size\", str(halo_size),\n",
|
||||
" \"-s\", solver,\n",
|
||||
" \"--pdims\", str(pdims[0]), str(pdims[1]),\n",
|
||||
" \"--snapshots\", str(snapshots)\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Execute the command as a subprocess\n",
|
||||
"subprocess.run(command)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Projecting the 3D Density Fields to 2D\n",
|
||||
"\n",
|
||||
"To visualize the 3D density fields in 2D, we need to create a projection:\n",
|
||||
"\n",
|
||||
"- **`project_to_2d` Function**: This function reduces the 3D array to 2D by summing over a portion of one axis.\n",
|
||||
" - We sum the top one-eighth of the data along the first axis to capture a slice of the density field.\n",
|
||||
"\n",
|
||||
"- **Creating 2D Projections**: Apply `project_to_2d` to each 3D field (`initial_conditions`, `lpt_displacements`, `ode_solution_0`, and `ode_solution_1`) to get 2D arrays that represent the density fields.\n",
|
||||
"\n",
|
||||
"### Applying the Magma Colormap\n",
|
||||
"\n",
|
||||
"To improve visualization, apply the \"magma\" colormap to each 2D projection:\n",
|
||||
"\n",
|
||||
"- **`apply_colormap` Function**: This function maps values in the 2D array to colors using the \"magma\" colormap.\n",
|
||||
" - First, normalize the array to the `[0, 1]` range.\n",
|
||||
" - Apply the colormap to create RGB images, which will be used for the animation.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from matplotlib import colormaps\n",
|
||||
"\n",
|
||||
"# Define a function to project the 3D field to 2D\n",
|
||||
"def project_to_2d(field):\n",
|
||||
" sum_over = field.shape[0] // 8\n",
|
||||
" slicing = [slice(None)] * field.ndim\n",
|
||||
" slicing[0] = slice(None, sum_over)\n",
|
||||
" slicing = tuple(slicing)\n",
|
||||
"\n",
|
||||
" return field[slicing].sum(axis=0)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def apply_colormap(array, cmap_name=\"magma\"):\n",
|
||||
" cmap = colormaps[cmap_name]\n",
|
||||
" normalized_array = (array - array.min()) / (array.max() - array.min())\n",
|
||||
" colored_image = cmap(normalized_array)[:, :, :3] # Drop alpha channel for RGB\n",
|
||||
" return (colored_image * 255).astype(np.uint8)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Loading and Visualizing Results\n",
|
||||
"\n",
|
||||
"After running the multi-host simulation, we load the saved results from disk:\n",
|
||||
"\n",
|
||||
"- **`initial_conditions.npy`**: Initial conditions for the simulation.\n",
|
||||
"- **`lpt_displacements.npy`**: Linear perturbation displacements.\n",
|
||||
"- **`ode_solution_*.npy`** : Solutions from the ODE solver at each snapshot.\n",
|
||||
"\n",
|
||||
"We will now project the fields to 2D maps and apply the color map\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"initial_conditions = apply_colormap(project_to_2d(np.load('fields/initial_conditions.npy')))\n",
|
||||
"lpt_displacements = apply_colormap(project_to_2d(np.load('fields/lpt_displacements.npy')))\n",
|
||||
"ode_solutions = []\n",
|
||||
"for i in range(8):\n",
|
||||
" ode_solutions.append(apply_colormap(project_to_2d(np.load(f'fields/ode_solution_{i}.npy'))))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Animating with Manim\n",
|
||||
"\n",
|
||||
"To create animations with `manim` in a Jupyter notebook, we start by configuring some settings to ensure the output displays correctly and without a background.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from manim import *\n",
|
||||
"config.media_width = \"100%\"\n",
|
||||
"config.verbosity = \"WARNING\"\n",
|
||||
"config.background_color = \"#00000000\" # Transparent background"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Defining the Animation in Manim\n",
|
||||
"\n",
|
||||
"This animation class, `FieldTransition`, smoothly transitions through the stages of the particle mesh density field evolution.\n",
|
||||
"\n",
|
||||
"- **Setup**: Each density field snapshot is loaded as an image and aligned for smooth transitions.\n",
|
||||
"- **Animation Sequence**:\n",
|
||||
" - The animation begins with a fade-in of the initial conditions.\n",
|
||||
" - It then transitions through the stages in sequence, showing each snapshot of the density field evolution with brief pauses in between.\n",
|
||||
"\n",
|
||||
"To run the animation, execute `%manim -v WARNING -qm FieldTransition` to render it in the Jupyter Notebook.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Define the animation in Manim\n",
|
||||
"class FieldTransition(Scene):\n",
|
||||
" def construct(self):\n",
|
||||
" init_conditions_img = ImageMobject(initial_conditions).scale(4)\n",
|
||||
" lpt_img = ImageMobject(lpt_displacements).scale(4)\n",
|
||||
" snapshots_imgs = [ImageMobject(sol).scale(4) for sol in ode_solutions]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" # Place the images on top of each other initially\n",
|
||||
" lpt_img.move_to(init_conditions_img)\n",
|
||||
" for img in snapshots_imgs:\n",
|
||||
" img.move_to(init_conditions_img)\n",
|
||||
"\n",
|
||||
" # Show initial field and then transform between fields\n",
|
||||
" self.play(FadeIn(init_conditions_img))\n",
|
||||
" self.wait(0.2)\n",
|
||||
" self.play(Transform(init_conditions_img, lpt_img))\n",
|
||||
" self.wait(0.2)\n",
|
||||
" self.play(Transform(lpt_img, snapshots_imgs[0]))\n",
|
||||
" self.wait(0.2)\n",
|
||||
" for img1, img2 in zip(snapshots_imgs, snapshots_imgs[1:]):\n",
|
||||
" self.play(Transform(img1, img2))\n",
|
||||
" self.wait(0.2)\n",
|
||||
"\n",
|
||||
"%manim -v WARNING -qm -o anim.gif --format=gif FieldTransition "
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -1 +0,0 @@
|
|||
c4a44973e4f11841a8c14f4d200e7e87887419aa
|
|
@ -37,3 +37,50 @@ Each notebook includes installation instructions and guidelines for configuring
|
|||
- **SLURM** for job scheduling on clusters (if running multi-host setups)
|
||||
|
||||
> **Note**: These notebooks are tested on the **Jean Zay** supercomputer and may require configuration changes for different HPC clusters.
|
||||
|
||||
## Caveats
|
||||
|
||||
### Cloud-in-Cell (CIC) Painting (Single Device)
|
||||
|
||||
There is two ways to perform the CIC painting in JAXPM. The first one is to use the `cic_paint` which paints absolute particle positions to the mesh. The second one is to use the `cic_paint_dx` which paints relative particle positions to the mesh (using uniform particles). The absolute version is faster at the cost of more memory usage.
|
||||
|
||||
inorder to use relative painting you need to :
|
||||
|
||||
- Set the `particles` argument in `lpt` function from `jaxpm.pm` to `None`
|
||||
- Set `paint_absolute_pos` to `False` in `make_ode_fn` or `make_diffrax_ode` function from `jaxpm.pm` (it is True by default)
|
||||
|
||||
Otherwise you set `particles` to the starting particles of your choice and leave `paint_absolute_pos` to `True` (default value).
|
||||
|
||||
### Cloud-in-Cell (CIC) Painting (Multi Device)
|
||||
|
||||
Both `cic_paint` and `cic_paint_dx` functions are available in multi-device mode.
|
||||
|
||||
You need to set the arguments `sharding` and `halo_size` which is explained in the notebook [03-MultiGPU_PM_Halo.ipynb](03-MultiGPU_PM_Halo.ipynb).
|
||||
|
||||
One thing to note that `cic_paint` is not as accurate as `cic_paint_dx` in multi-device mode and therefor is not recommended.
|
||||
|
||||
Using relative painting in multi-device mode is just like in single device mode.\
|
||||
You need to set the `particles` argument in `lpt` function from `jaxpm.pm` to `None` and set `paint_absolute_pos` to `False`
|
||||
|
||||
### Distributed PM
|
||||
|
||||
To run a distributed PM follow the examples in notebooks [03](03-MultiGPU_PM_Halo.ipynb) and [05](05-MultiHost_PM.ipynb) for multi-host.
|
||||
|
||||
In short you need to set the arguments `sharding` and `halo_size` in `lpt` , `linear_field` the `make_ode` functions and `pm_forces` if you use it.
|
||||
|
||||
Missmatching the shardings will give you errors and unexpected results.
|
||||
|
||||
You can also use `normal_field` and `uniform_particles` from `jaxpm.pm.distributed` to create the fields and particles with a sharding.
|
||||
|
||||
### Choosing the right pdims
|
||||
|
||||
pdims are processor dimensions.\
|
||||
Explained more in the jaxdecomp paper [here](https://github.com/DifferentiableUniverseInitiative/jaxDecomp).
|
||||
|
||||
For 8 devices there are three decompositions that are possible:
|
||||
- (1 , 8)
|
||||
- (2 , 4) , (4 , 2)
|
||||
- (8 , 1)
|
||||
|
||||
(1 , X) should be the fastest (2 , X) or (X , 2) is more accurate but slightly slower.\
|
||||
and (X , 1) is giving the least accurate results for some reason so it is not recommended.
|
||||
|
|
|
@ -3,28 +3,15 @@ requires = ["setuptools", "wheel", "setuptools-scm"]
|
|||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "JaxPM"
|
||||
name = "jaxpm"
|
||||
dynamic = ["version"]
|
||||
description = "A dead simple FastPM implementation in JAX"
|
||||
description = "A simple Particle-Mesh implementation in JAX"
|
||||
authors = [{ name = "JaxPM developers" }]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
license = { file = "LICENSE" }
|
||||
urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/JaxPM" }
|
||||
dependencies = ["jax_cosmo", "jax>=0.4.30", "jaxdecomp>=0.2.2"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = [
|
||||
"jax>=0.4.30",
|
||||
"numpy",
|
||||
"jax_cosmo",
|
||||
"jaxdecomp>=0.2.2",
|
||||
"pytest>=8.0.0",
|
||||
"pfft-python @ git+https://github.com/MP-Gadget/pfft-python",
|
||||
"pmesh @ git+https://github.com/MP-Gadget/pmesh",
|
||||
"fastpm @ git+https://github.com/ASKabalan/fastpm-python",
|
||||
"diffrax"
|
||||
]
|
||||
dependencies = ["jax_cosmo", "jax>=0.4.35", "jaxdecomp>=0.2.3"]
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["jaxpm"]
|
||||
|
|
5
requirements-test.txt
Normal file
5
requirements-test.txt
Normal file
|
@ -0,0 +1,5 @@
|
|||
pfft-python @ git+https://github.com/MP-Gadget/pfft-python
|
||||
pmesh @ git+https://github.com/MP-Gadget/pmesh
|
||||
fastpm @ git+https://github.com/ASKabalan/fastpm-python
|
||||
numpy==2.2.6
|
||||
diffrax
|
|
@ -44,7 +44,7 @@ def simulation_config(request):
|
|||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", params=[0.1, 0.5, 0.8])
|
||||
@pytest.fixture(scope="session", params=[0.1, 0.2])
|
||||
def lpt_scale_factor(request):
|
||||
return request.param
|
||||
|
||||
|
@ -96,8 +96,9 @@ def fpm_initial_conditions(cosmo, particle_mesh):
|
|||
whitec = particle_mesh.generate_whitenoise(42,
|
||||
type='complex',
|
||||
unitary=False)
|
||||
lineark = whitec.apply(lambda k, v: pk_fn(sum(ki**2 for ki in k)**0.5)**0.5
|
||||
* v * (1 / v.BoxSize).prod()**0.5)
|
||||
lineark = whitec.apply(lambda k, v: jnp.sqrt(
|
||||
pk_fn(jnp.sqrt(sum(ki**2 for ki in k)))) * v * jnp.sqrt(
|
||||
(1 / v.BoxSize).prod()))
|
||||
init_mesh = lineark.c2r().value # XXX
|
||||
|
||||
return lineark, grid, init_mesh
|
||||
|
@ -151,7 +152,7 @@ def nbody_from_lpt1(solver, fpm_lpt1, particle_mesh, lpt_scale_factor):
|
|||
if lpt_scale_factor == 0.8:
|
||||
pytest.skip("Do not run nbody simulation from scale factor 0.8")
|
||||
|
||||
stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)
|
||||
stages = np.linspace(lpt_scale_factor, 1.0, 100, endpoint=True)
|
||||
|
||||
finalstate = solver.nbody(fpm_lpt1, leapfrog(stages))
|
||||
fpm_mesh = particle_mesh.paint(finalstate.X).value
|
||||
|
@ -167,7 +168,7 @@ def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
|
|||
if lpt_scale_factor == 0.8:
|
||||
pytest.skip("Do not run nbody simulation from scale factor 0.8")
|
||||
|
||||
stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)
|
||||
stages = np.linspace(lpt_scale_factor, 1.0, 100, endpoint=True)
|
||||
|
||||
finalstate = solver.nbody(fpm_lpt2, leapfrog(stages))
|
||||
fpm_mesh = particle_mesh.paint(finalstate.X).value
|
||||
|
|
|
@ -2,6 +2,7 @@ import pytest
|
|||
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
|
||||
from helpers import MSE, MSRE
|
||||
from jax import numpy as jnp
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
from jaxpm.distributed import uniform_particles
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx
|
||||
|
@ -10,6 +11,8 @@ from jaxpm.utils import power_spectrum
|
|||
|
||||
_TOLERANCE = 1e-4
|
||||
_PM_TOLERANCE = 1e-3
|
||||
_FIELD_RTOL = 1e-4
|
||||
_FIELD_ATOL = 1e-3
|
||||
|
||||
|
||||
@pytest.mark.single_device
|
||||
|
@ -34,7 +37,10 @@ def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
|
|||
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
|
||||
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
||||
|
||||
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
|
||||
assert_allclose(lpt_field,
|
||||
fpm_ref_field,
|
||||
rtol=_FIELD_RTOL,
|
||||
atol=_FIELD_ATOL)
|
||||
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
|
||||
|
||||
|
||||
|
@ -55,7 +61,10 @@ def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
|
|||
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
|
||||
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
||||
|
||||
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
|
||||
assert_allclose(lpt_field,
|
||||
fpm_ref_field,
|
||||
rtol=_FIELD_RTOL,
|
||||
atol=_FIELD_ATOL)
|
||||
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
|
||||
|
||||
|
||||
|
@ -76,7 +85,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
|
|||
a=lpt_scale_factor,
|
||||
order=order)
|
||||
|
||||
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
|
||||
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-8,
|
||||
|
@ -95,6 +104,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
|
|||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
args=cosmo,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
|
@ -105,7 +115,10 @@ def test_nbody_absolute(simulation_config, initial_conditions,
|
|||
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
|
||||
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
||||
|
||||
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
|
||||
assert_allclose(final_field,
|
||||
fpm_ref_field,
|
||||
rtol=_FIELD_RTOL,
|
||||
atol=_FIELD_ATOL)
|
||||
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
|
||||
|
||||
|
||||
|
@ -121,8 +134,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
|
|||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-9,
|
||||
|
@ -141,6 +153,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
|
|||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
args=cosmo,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
|
@ -151,5 +164,8 @@ def test_nbody_relative(simulation_config, initial_conditions,
|
|||
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
|
||||
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
||||
|
||||
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
|
||||
assert_allclose(final_field,
|
||||
fpm_ref_field,
|
||||
rtol=_FIELD_RTOL,
|
||||
atol=_FIELD_ATOL)
|
||||
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
|
||||
|
|
|
@ -2,8 +2,11 @@ from conftest import initialize_distributed
|
|||
|
||||
initialize_distributed() # ignore : E402
|
||||
|
||||
from functools import partial # noqa : E402
|
||||
|
||||
import jax # noqa : E402
|
||||
import jax.numpy as jnp # noqa : E402
|
||||
import jax_cosmo as jc # noqa : E402
|
||||
import pytest # noqa : E402
|
||||
from diffrax import SaveAt # noqa : E402
|
||||
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
|
||||
|
@ -12,21 +15,37 @@ from jax import lax # noqa : E402
|
|||
from jax.experimental.multihost_utils import process_allgather # noqa : E402
|
||||
from jax.sharding import NamedSharding
|
||||
from jax.sharding import PartitionSpec as P # noqa : E402
|
||||
from jaxdecomp import get_fft_output_sharding
|
||||
|
||||
from jaxpm.distributed import uniform_particles # noqa : E402
|
||||
from jaxpm.distributed import fft3d, ifft3d
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
|
||||
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
|
||||
from jaxpm.pm import lpt, make_diffrax_ode, pm_forces # noqa : E402
|
||||
|
||||
_TOLERANCE = 3.0 # 🙃🙃
|
||||
_TOLERANCE = 1e-12 # 🎉🎉🎉
|
||||
|
||||
pdims = [(1, 8), (8, 1), (4, 2), (2, 4)]
|
||||
|
||||
jax.config.update("jax_enable_x64", True) # Use double precision for accuracy
|
||||
|
||||
|
||||
@pytest.mark.distributed
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
@pytest.mark.parametrize("pdims", pdims)
|
||||
@pytest.mark.parametrize("absolute_painting", [True, False])
|
||||
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
||||
absolute_painting):
|
||||
pdims, absolute_painting):
|
||||
|
||||
if absolute_painting:
|
||||
pytest.skip("Absolute painting is not recommended in distributed mode")
|
||||
|
||||
painting_str = "absolute" if absolute_painting else "relative"
|
||||
print("=" * 50)
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
print(
|
||||
f"Running with {painting_str} painting and pdims {pdims} and order {order} and mesh shape {mesh_shape}..."
|
||||
)
|
||||
# SINGLE DEVICE RUN
|
||||
cosmo._workspace = {}
|
||||
if absolute_painting:
|
||||
|
@ -37,12 +56,12 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
|||
particles,
|
||||
a=0.1,
|
||||
order=order)
|
||||
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
|
||||
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
|
||||
y0 = jnp.stack([particles + dx, p])
|
||||
else:
|
||||
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
ode_fn = ODETerm(make_diffrax_ode(mesh_shape,
|
||||
paint_absolute_pos=False))
|
||||
y0 = jnp.stack([dx, p])
|
||||
|
||||
solver = Dopri5()
|
||||
|
@ -60,6 +79,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
|||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
args=cosmo,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
|
@ -72,7 +92,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
|||
print("Done with single device run")
|
||||
# MULTI DEVICE RUN
|
||||
|
||||
mesh = jax.make_mesh((1, 8), ('x', 'y'))
|
||||
mesh = jax.make_mesh(pdims, ('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
halo_size = mesh_shape[0] // 2
|
||||
|
||||
|
@ -94,8 +114,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
|||
sharding=sharding)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo,
|
||||
mesh_shape,
|
||||
make_diffrax_ode(mesh_shape,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
|
||||
|
@ -108,8 +127,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
|||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo,
|
||||
mesh_shape,
|
||||
make_diffrax_ode(mesh_shape,
|
||||
paint_absolute_pos=False,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
|
@ -130,16 +148,23 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
|||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
args=cosmo,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
final_field = solutions.ys[-1, 0]
|
||||
print(f"Final field sharding is {final_field.sharding}")
|
||||
|
||||
assert final_field.sharding.is_equivalent_to(sharding , ndim=3) \
|
||||
, f"Final field sharding is not correct .. should be {sharding} it is instead {final_field.sharding}"
|
||||
|
||||
if absolute_painting:
|
||||
multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
|
||||
solutions.ys[-1, 0],
|
||||
final_field,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
else:
|
||||
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
|
||||
multi_device_final_field = cic_paint_dx(final_field,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
|
@ -150,3 +175,230 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
|||
print(f"MSE is {mse}")
|
||||
|
||||
assert mse < _TOLERANCE
|
||||
|
||||
|
||||
@pytest.mark.distributed
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
@pytest.mark.parametrize("pdims", pdims)
|
||||
def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
|
||||
order, nbody_from_lpt1, nbody_from_lpt2, pdims):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
# SINGLE DEVICE RUN
|
||||
cosmo._workspace = {}
|
||||
|
||||
mesh = jax.make_mesh(pdims, ('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
halo_size = mesh_shape[0] // 2
|
||||
|
||||
initial_conditions = lax.with_sharding_constraint(initial_conditions,
|
||||
sharding)
|
||||
|
||||
print(f"sharded initial conditions {initial_conditions.sharding}")
|
||||
cosmo._workspace = {}
|
||||
|
||||
@jax.jit
|
||||
def forward_model(initial_conditions, cosmo):
|
||||
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
a=0.1,
|
||||
order=order,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(mesh_shape,
|
||||
paint_absolute_pos=False,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-8,
|
||||
atol=1e-8,
|
||||
pcoeff=0.4,
|
||||
icoeff=1,
|
||||
dcoeff=0)
|
||||
|
||||
saveat = SaveAt(t1=True)
|
||||
solutions = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
t0=0.1,
|
||||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
args=cosmo,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
return multi_device_final_field
|
||||
|
||||
@jax.jit
|
||||
def model(initial_conditions, cosmo):
|
||||
final_field = forward_model(initial_conditions, cosmo)
|
||||
return MSE(final_field,
|
||||
nbody_from_lpt1 if order == 1 else nbody_from_lpt2)
|
||||
|
||||
obs_val = model(initial_conditions, cosmo)
|
||||
|
||||
shifted_initial_conditions = initial_conditions + jax.random.normal(
|
||||
jax.random.key(42), initial_conditions.shape) * 5
|
||||
|
||||
good_grads = jax.grad(model)(initial_conditions, cosmo)
|
||||
off_grads = jax.grad(model)(shifted_initial_conditions, cosmo)
|
||||
|
||||
assert good_grads.sharding.is_equivalent_to(initial_conditions.sharding,
|
||||
ndim=3)
|
||||
assert off_grads.sharding.is_equivalent_to(initial_conditions.sharding,
|
||||
ndim=3)
|
||||
|
||||
|
||||
@pytest.mark.distributed
|
||||
@pytest.mark.parametrize("pdims", pdims)
|
||||
def test_fwd_rev_gradients(cosmo, pdims):
|
||||
|
||||
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
|
||||
cosmo._workspace = {}
|
||||
|
||||
mesh = jax.make_mesh(pdims, ('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
halo_size = mesh_shape[0] // 2
|
||||
|
||||
initial_conditions = jax.random.normal(jax.random.PRNGKey(42), mesh_shape)
|
||||
initial_conditions = lax.with_sharding_constraint(initial_conditions,
|
||||
sharding)
|
||||
print(f"sharded initial conditions {initial_conditions.sharding}")
|
||||
cosmo._workspace = {}
|
||||
|
||||
@partial(jax.jit, static_argnums=(2, 3, 4))
|
||||
def compute_forces(initial_conditions,
|
||||
cosmo,
|
||||
a=0.5,
|
||||
halo_size=0,
|
||||
sharding=None):
|
||||
|
||||
paint_absolute_pos = False
|
||||
particles = jnp.zeros_like(initial_conditions,
|
||||
shape=(*initial_conditions.shape, 3))
|
||||
|
||||
a = jnp.atleast_1d(a)
|
||||
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
||||
|
||||
initial_conditions = jax.lax.with_sharding_constraint(
|
||||
initial_conditions, sharding)
|
||||
delta_k = fft3d(initial_conditions)
|
||||
out_sharding = get_fft_output_sharding(sharding)
|
||||
delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
|
||||
|
||||
initial_force = pm_forces(particles,
|
||||
delta=delta_k,
|
||||
paint_absolute_pos=paint_absolute_pos,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
return initial_force[..., 0]
|
||||
|
||||
forces = compute_forces(initial_conditions,
|
||||
cosmo,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
back_gradient = jax.jacrev(compute_forces)(initial_conditions,
|
||||
cosmo,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions,
|
||||
cosmo,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
print(f"Forces sharding is {forces.sharding}")
|
||||
print(f"Backward gradient sharding is {back_gradient.sharding}")
|
||||
print(f"Forward gradient sharding is {fwd_gradient.sharding}")
|
||||
assert forces.sharding.is_equivalent_to(initial_conditions.sharding,
|
||||
ndim=3)
|
||||
assert back_gradient[0, 0, 0, ...].sharding.is_equivalent_to(
|
||||
initial_conditions.sharding, ndim=3)
|
||||
assert fwd_gradient.sharding.is_equivalent_to(initial_conditions.sharding,
|
||||
ndim=3)
|
||||
|
||||
|
||||
@pytest.mark.distributed
|
||||
@pytest.mark.parametrize("pdims", pdims)
|
||||
def test_vmap(cosmo, pdims):
|
||||
|
||||
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
|
||||
cosmo._workspace = {}
|
||||
|
||||
mesh = jax.make_mesh(pdims, ('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
halo_size = mesh_shape[0] // 2
|
||||
|
||||
single_dev_initial_conditions = jax.random.normal(jax.random.PRNGKey(42),
|
||||
mesh_shape)
|
||||
initial_conditions = lax.with_sharding_constraint(
|
||||
single_dev_initial_conditions, sharding)
|
||||
|
||||
single_ics = jnp.stack([
|
||||
single_dev_initial_conditions, single_dev_initial_conditions,
|
||||
single_dev_initial_conditions
|
||||
])
|
||||
sharded_ics = jnp.stack(
|
||||
[initial_conditions, initial_conditions, initial_conditions])
|
||||
print(f"unsharded initial conditions batch {single_ics.sharding}")
|
||||
print(f"sharded initial conditions batch {sharded_ics.sharding}")
|
||||
cosmo._workspace = {}
|
||||
|
||||
@partial(jax.jit, static_argnums=(2, 3, 4))
|
||||
def compute_forces(initial_conditions,
|
||||
cosmo,
|
||||
a=0.5,
|
||||
halo_size=0,
|
||||
sharding=None):
|
||||
|
||||
paint_absolute_pos = False
|
||||
particles = jnp.zeros_like(initial_conditions,
|
||||
shape=(*initial_conditions.shape, 3))
|
||||
|
||||
a = jnp.atleast_1d(a)
|
||||
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
||||
|
||||
initial_conditions = jax.lax.with_sharding_constraint(
|
||||
initial_conditions, sharding)
|
||||
delta_k = fft3d(initial_conditions)
|
||||
out_sharding = get_fft_output_sharding(sharding)
|
||||
delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
|
||||
|
||||
initial_force = pm_forces(particles,
|
||||
delta=delta_k,
|
||||
paint_absolute_pos=paint_absolute_pos,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
return initial_force[..., 0]
|
||||
|
||||
def fn(ic):
|
||||
return compute_forces(ic,
|
||||
cosmo,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
v_compute_forces = jax.vmap(fn)
|
||||
|
||||
print(f"single_ics shape {single_ics.shape}")
|
||||
print(f"sharded_ics shape {sharded_ics.shape}")
|
||||
|
||||
single_dev_forces = v_compute_forces(single_ics)
|
||||
sharded_forces = v_compute_forces(sharded_ics)
|
||||
|
||||
assert single_dev_forces.ndim == 4
|
||||
assert sharded_forces.ndim == 4
|
||||
|
||||
print(f"Sharded forces {sharded_forces.sharding}")
|
||||
|
||||
assert sharded_forces[0].sharding.is_equivalent_to(
|
||||
initial_conditions.sharding, ndim=3)
|
||||
assert sharded_forces.sharding.spec[0] == None
|
||||
|
|
88
tests/test_gradients.py
Normal file
88
tests/test_gradients.py
Normal file
|
@ -0,0 +1,88 @@
|
|||
import jax
|
||||
import pytest
|
||||
from diffrax import (BacksolveAdjoint, Dopri5, ODETerm, PIDController,
|
||||
RecursiveCheckpointAdjoint, SaveAt, diffeqsolve)
|
||||
from helpers import MSE
|
||||
from jax import numpy as jnp
|
||||
|
||||
from jaxpm.distributed import uniform_particles
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx
|
||||
from jaxpm.pm import lpt, make_diffrax_ode
|
||||
|
||||
|
||||
@pytest.mark.single_device
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
@pytest.mark.parametrize("absolute_painting", [True, False])
|
||||
@pytest.mark.parametrize("adjoint", ['DTO', 'OTD'])
|
||||
def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
|
||||
nbody_from_lpt1, nbody_from_lpt2, cosmo, order,
|
||||
absolute_painting, adjoint):
|
||||
|
||||
mesh_shape, _ = simulation_config
|
||||
cosmo._workspace = {}
|
||||
|
||||
if adjoint == 'OTD':
|
||||
pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)")
|
||||
|
||||
adjoint = RecursiveCheckpointAdjoint(
|
||||
) if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5())
|
||||
|
||||
@jax.jit
|
||||
@jax.grad
|
||||
def forward_model(initial_conditions, cosmo):
|
||||
|
||||
# Initial displacement
|
||||
if absolute_painting:
|
||||
particles = uniform_particles(mesh_shape)
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
particles,
|
||||
a=lpt_scale_factor,
|
||||
order=order)
|
||||
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
|
||||
y0 = jnp.stack([particles + dx, p])
|
||||
|
||||
else:
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
a=lpt_scale_factor,
|
||||
order=order)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
|
||||
y0 = jnp.stack([dx, p])
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-7,
|
||||
atol=1e-7,
|
||||
pcoeff=0.4,
|
||||
icoeff=1,
|
||||
dcoeff=0)
|
||||
|
||||
saveat = SaveAt(t1=True)
|
||||
|
||||
solutions = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
t0=lpt_scale_factor,
|
||||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
args=cosmo,
|
||||
adjoint=adjoint,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
if absolute_painting:
|
||||
final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0])
|
||||
else:
|
||||
final_field = cic_paint_dx(solutions.ys[-1, 0])
|
||||
|
||||
return MSE(final_field,
|
||||
nbody_from_lpt1 if order == 1 else nbody_from_lpt2)
|
||||
|
||||
bad_initial_conditions = initial_conditions + jax.random.normal(
|
||||
jax.random.PRNGKey(0), initial_conditions.shape) * 0.5
|
||||
best_ic = forward_model(initial_conditions, cosmo)
|
||||
bad_ic = forward_model(bad_initial_conditions, cosmo)
|
||||
|
||||
assert jnp.max(best_ic) < 1e-5
|
||||
assert jnp.max(bad_ic) > 1e-5
|
Loading…
Add table
Add a link
Reference in a new issue