jaxdecomp proto (#21)

* adding example of distributed solution

* put back old function

* update formatting

* add halo exchange and slice pad

* apply formatting

* implement distributed optimized cic_paint

* Use new cic_paint with halo

* Fix seed for distributed normal

* Wrap interpolation function to avoid all gather

* Return normal order frequencies for single GPU

* add example

* format

* add optimized bench script

* times in ms

* add lpt2

* update benchmark and add slurm

* Visualize only final field

* Update scripts/distributed_pm.py

Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>

* Adjust pencil type for frequencies

* fix painting issue with slabs

* Sharded operations in Fourier space now take the inverted sharding axis for slabs

* add assert to make pyright happy

* adjust test for hpc-plotter

* add PMWD test

* bench

* format

* added github workflow

* fix formatting from main

* Update for jaxDecomp pure JAX

* revert single halo extent change

* update for latest jaxDecomp

* remove fourrier_space in autoshmap

* make normal_field work with single controller

* format

* make distributed pm work in single controller

* merge bench_pm

* update to leapfrog

* add a strict dependency on jaxdecomp

* global mesh no longer needed

* kernels.py no longer uses global mesh

* quick fix in distributed

* pm.py no longer uses global mesh

* painting.py no longer uses global mesh

* update demo script

* quick fix in kernels

* quick fix in distributed

* update demo

* merge Hugo's LPT2 code

* format

* Small fix

* format

* remove duplicate get_ode_fn

* update visualizer

* update compensate CIC

* By default check_rep is false for shard_map

* remove experimental distributed code

* update PGDCorrection and neural ode to use new fft3d

* jaxDecomp pfft3d promotes to complex automatically

* remove deprecated stuff

* fix painting issue with read_cic

* use jnp interp instead of jc interp

* delete old slurms

* add notebook examples

* apply formatting

* add distributed zeros

* fix code in LPT2

* jit cic_paint

* update notebooks

* apply formatting

* get_local_shape and zeros can be used by users

* add a user facing function to create uniform particle grid

* use jax interp instead of jax_cosmo

* use float64 for enmeshing

* Allow applying weights with relative cic paint

* Weights can be traced

* remove script folder

* update example notebooks

* delete outdated design file

* add readme for tutorials

* update readme

* fix small error

* fix forgotten particles in multi-host

* clarifying why cic_paint_dx is slower

* clarifying the halo size dependence on the box size

* ability to choose snapshots number with MultiHost script

* Adding animation notebook

* Put plotting in package

* Add finite difference laplace kernel + powerspec functions from Hugo

Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>

* Put plotting utils in package

* By default use absolute painting

* update code

* update notebooks

* add tests

* Upgrade setup.py to pyproject

* Format

* format tests

* update test dependencies

* add test workflow

* fix deprecated FftType in jaxpm.kernels

* Address aboucaud's review comments

* JAX version is 0.4.35 until Diffrax new release

* add numpy explicitly as dependency for tests

* fix install order for tests

* add numpy to be installed

* enforce no build isolation for fastpm

* pip install jaxpm test without build isolation

* bump jaxdecomp version

* revert test workflow

* remove outdated tests

---------

Co-authored-by: EiffL <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr>
Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>
Former-commit-id: 8c2e823d4669eac712089bf7f85ffb7912e8232d
Committed by Wassim KABALAN on 2024-12-20 11:44:02 +01:00 via GitHub
parent a0a79277e5
commit df8602b318
26 changed files with 1871 additions and 434 deletions

.github/workflows/formatting.yml (new file)

@@ -0,0 +1,21 @@
name: Code Formatting
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
- name: Install dependencies
run: |
python -m pip install --upgrade pip isort
python -m pip install pre-commit
- name: Run pre-commit
run: python -m pre_commit run --all-files

.github/workflows/tests.yml (new file)

@@ -0,0 +1,45 @@
name: Tests
on:
push:
branches:
- main
pull_request:
branches:
- main
jobs:
run_tests:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10" , "3.11" , "3.12"]
steps:
- name: Checkout Source
uses: actions/checkout@v2.3.1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo apt-get install -y libopenmpi-dev
python -m pip install --upgrade pip
pip install jax==0.4.35
pip install numpy setuptools cython wheel
pip install git+https://github.com/MP-Gadget/pfft-python
pip install git+https://github.com/MP-Gadget/pmesh
pip install git+https://github.com/ASKabalan/fastpm-python --no-build-isolation
pip install .[test]
- name: Run Single Device Tests
run: |
cd tests
pytest -v -m "not distributed"
- name: Run Distributed tests
run: |
pytest -v -m distributed

.gitignore

@@ -98,6 +98,11 @@ __pypackages__/
celerybeat-schedule
celerybeat.pid
out
traces
*.npy
*.out
# SageMath parsed files
*.sage.py

README.md

@@ -4,16 +4,15 @@
<!-- ALL-CONTRIBUTORS-BADGE:END -->
JAX-powered Cosmological Particle-Mesh N-body Solver
**This project is currently in an early design phase. All inputs are welcome on the [design document](https://github.com/DifferentiableUniverseInitiative/JaxPM/blob/main/design.md)**
## Goals
Provide a modern infrastructure to support differentiable PM N-body simulations using JAX:
- Keep implementation simple and readable, in pure NumPy API
- Transparent distribution using builtin `xmap`
- Any order forward and backward automatic differentiation
- Support automated batching using `vmap`
- Compatibility with external optimizer libraries like `optax`
- Now fully distributable on **multi-GPU and multi-node** systems using [jaxDecomp](https://github.com/DifferentiableUniverseInitiative/jaxDecomp) working with `JAX v0.4.35`
## Open development and use
@@ -23,6 +22,10 @@ Current expectations are:
- Everyone is welcome to contribute, and can join the JOSS publication (until it is submitted to the journal).
- Anyone (including main contributors) can use this code as a framework to build and publish their own applications, with no expectation that they *need* to extend authorship to all jaxpm developers.
## Getting Started
To dive into JaxPM's capabilities, please explore the **notebook section** for detailed tutorials and examples on various setups, from single-device simulations to multi-host configurations. You can find the notebooks' [README here](notebooks/README.md) for a structured guide through each tutorial.
## Contributors ✨

design.md (deleted)

@@ -1,52 +0,0 @@
# Design Document for JaxPM
This document aims to detail some of the API, implementation choices, and internal mechanism.
## Objective
Provide a user-friendly framework for distributed Particle-Mesh N-body simulations.
## Related Work
This project would be the latest iteration of a number of past libraries that have provided differentiable N-body models.
- [FlowPM](https://github.com/DifferentiableUniverseInitiative/flowpm): TensorFlow
- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD
- Borg
In addition, a number of fast N-body simulation projects exist out there:
- [FastPM](https://github.com/fastpm/fastpm)
- ...
## Design Overview
### Coding principles
Following recent trends and JAX philosophy, the library should have a functional programming type of interface.
### Illustration of API
Here is a potential illustration of what the user interface could be for the simulation code:
```python
import jaxpm as jpm
import jax_cosmo as jc
# Instantiate differentiable cosmology object
cosmo = jc.Planck()
# Creates initial conditions
initial_conditions = jpm.generate_ic(cosmo, boxsize, nmesh, dtype='float32')
# Create a particular solver
solver = jpm.solvers.fastpm(cosmo, B=1)
# Initialize and run the simulation
state = solver.init(initial_conditions)
state = solver.nbody(state)
# Painting the results
density = jpm.zeros(boxsize, nmesh)
density = jpm.paint(density, state.positions)
```

test_pfft.py (deleted)

@@ -1,96 +0,0 @@
# Can be executed with:
# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py
from functools import partial
import jax
import jax.lax as lax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import Mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit
jax.distributed.initialize()
cube_size = 2048
@partial(xmap,
in_axes=[...],
out_axes=['x', 'y', ...],
axis_sizes={
'x': cube_size,
'y': cube_size
},
axis_resources={
'x': 'nx',
'y': 'ny',
'key_x': 'nx',
'key_y': 'ny'
})
def pnormal(key):
return jax.random.normal(key, shape=[cube_size])
@partial(xmap,
in_axes={
0: 'x',
1: 'y'
},
out_axes=['x', 'y', ...],
axis_resources={
'x': 'nx',
'y': 'ny'
})
@jax.jit
def pfft3d(mesh):
# [x, y, z]
mesh = jnp.fft.fft(mesh) # Transform on z
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x]
mesh = jnp.fft.fft(mesh) # Transform on x
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y]
mesh = jnp.fft.fft(mesh) # Transform on y
# [z, x, y]
return mesh
@partial(xmap,
in_axes={
0: 'x',
1: 'y'
},
out_axes=['x', 'y', ...],
axis_resources={
'x': 'nx',
'y': 'ny'
})
@jax.jit
def pifft3d(mesh):
# [z, x, y]
mesh = jnp.fft.ifft(mesh) # Transform on y
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x]
mesh = jnp.fft.ifft(mesh) # Transform on x
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z]
mesh = jnp.fft.ifft(mesh) # Transform on z
# [x, y, z]
return mesh
key = jax.random.PRNGKey(42)
# keys = jax.random.split(key, 4).reshape((2,2,2))
# We reshape all our devices to the mesh shape we want
devices = np.array(jax.devices()).reshape((2, 4))
with Mesh(devices, ('nx', 'ny')):
mesh = pnormal(key)
kmesh = pfft3d(mesh)
kmesh.block_until_ready()
# jax.profiler.start_trace("tensorboard")
# with Mesh(devices, ('nx', 'ny')):
# mesh = pnormal(key)
# kmesh = pfft3d(mesh)
# kmesh.block_until_ready()
# jax.profiler.stop_trace()
print('Done')

test_script.py (deleted)

@@ -1,68 +0,0 @@
# Start this script with:
# mpirun -np 4 python test_script.py
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
import jax
import jax.lax as lax
import jax.numpy as jnp
import matplotlib.pylab as plt
import numpy as np
import tensorflow_probability as tfp
from jax.experimental.maps import mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit
tfp = tfp.substrates.jax
tfd = tfp.distributions
def cic_paint(mesh, positions):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1,
2))
mesh = lax.scatter_add(
mesh,
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
kernel.reshape([-1, 8]), dnums)
return mesh
# And let's draw some points from some 3D distribution
dist = tfd.MultivariateNormalDiag(loc=[16., 16., 16.],
scale_identity_multiplier=3.)
pos = dist.sample(1e4, seed=jax.random.PRNGKey(0))
f = pjit(lambda x: cic_paint(x, pos),
in_axis_resources=PartitionSpec('x', 'y', 'z'),
out_axis_resources=None)
devices = np.array(jax.devices()).reshape((2, 2, 1))
# Let's import the mesh
m = jnp.zeros([32, 32, 32])
with mesh(devices, ('x', 'y', 'z')):
# Shard the mesh, I'm not sure this is absolutely necessary
m = pjit(lambda x: x,
in_axis_resources=None,
out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)
# Apply the sharded CiC function
res = f(m)
plt.imshow(res.sum(axis=2))
plt.show()

jaxpm/distributed.py (new file)

@@ -0,0 +1,198 @@
from typing import Any, Callable, Hashable
Specs = Any
AxisName = Hashable
from functools import partial
import jax
import jax.numpy as jnp
import jaxdecomp
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import AbstractMesh, Mesh
from jax.sharding import PartitionSpec as P
def autoshmap(
f: Callable,
gpu_mesh: Mesh | AbstractMesh | None,
in_specs: Specs,
out_specs: Specs,
check_rep: bool = False,
auto: frozenset[AxisName] = frozenset()) -> Callable:
"""Helper function to wrap the provided function in a shard map if
the code is being executed in a mesh context."""
if gpu_mesh is None or gpu_mesh.empty:
return f
else:
return shard_map(f, gpu_mesh, in_specs, out_specs, check_rep, auto)
def fft3d(x):
return jaxdecomp.pfft3d(x)
def ifft3d(x):
return jaxdecomp.pifft3d(x).real
def get_halo_size(halo_size, sharding):
gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is None or gpu_mesh.empty:
zero_ext = (0, 0)
zero_tuple = (0, 0)
return (zero_tuple, zero_tuple, zero_tuple), zero_ext
else:
pdims = gpu_mesh.devices.shape
halo_x = (0, 0) if pdims[0] == 1 else (halo_size, halo_size)
halo_y = (0, 0) if pdims[1] == 1 else (halo_size, halo_size)
halo_x_ext = 0 if pdims[0] == 1 else halo_size // 2
halo_y_ext = 0 if pdims[1] == 1 else halo_size // 2
return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext))
def halo_exchange(x, halo_extents, halo_periods=(True, True)):
if (halo_extents[0] > 0 or halo_extents[1] > 0):
return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
else:
return x
def slice_unpad_impl(x, pad_width):
halo_x, _ = pad_width[0]
halo_y, _ = pad_width[1]
# Apply corrections along x
x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2])
x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:])
# Apply corrections along y
x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])
unpad_slice = [slice(None)] * 3
if halo_x > 0:
unpad_slice[0] = slice(halo_x, -halo_x)
if halo_y > 0:
unpad_slice[1] = slice(halo_y, -halo_y)
return x[tuple(unpad_slice)]
def slice_pad(x, pad_width, sharding):
gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is not None and not (gpu_mesh.empty) and (
pad_width[0][0] > 0 or pad_width[1][0] > 0):
assert sharding is not None
spec = sharding.spec
return shard_map((partial(jnp.pad, pad_width=pad_width)),
mesh=gpu_mesh,
in_specs=spec,
out_specs=spec)(x)
else:
return x
def slice_unpad(x, pad_width, sharding):
mesh = sharding.mesh if sharding is not None else None
if mesh is not None and not (mesh.empty) and (pad_width[0][0] > 0
or pad_width[1][0] > 0):
assert sharding is not None
spec = sharding.spec
return shard_map(partial(slice_unpad_impl, pad_width=pad_width),
mesh=mesh,
in_specs=spec,
out_specs=spec)(x)
else:
return x
def get_local_shape(mesh_shape, sharding=None):
""" Helper function to get the local size of a mesh given the global size.
"""
gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is None or gpu_mesh.empty:
return mesh_shape
else:
pdims = gpu_mesh.devices.shape
return [
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1],
*mesh_shape[2:]
]
def _axis_names(spec):
if len(spec) == 1:
x_axis, = spec
y_axis = None
single_axis = True
elif len(spec) == 2:
x_axis, y_axis = spec
if y_axis == None:
single_axis = True
elif x_axis == None:
x_axis = y_axis
single_axis = True
else:
single_axis = False
else:
raise ValueError("Only 1 or 2 axis sharding is supported")
return x_axis, y_axis, single_axis
def uniform_particles(mesh_shape, sharding=None):
gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is not None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(mesh_shape, sharding)
spec = sharding.spec
x_axis, y_axis, single_axis = _axis_names(spec)
def particles():
x_indx = lax.axis_index(x_axis)
y_indx = 0 if single_axis else lax.axis_index(y_axis)
x = jnp.arange(local_mesh_shape[0]) + x_indx * local_mesh_shape[0]
y = jnp.arange(local_mesh_shape[1]) + y_indx * local_mesh_shape[1]
z = jnp.arange(local_mesh_shape[2])
return jnp.stack(jnp.meshgrid(x, y, z, indexing='ij'), axis=-1)
return shard_map(particles, mesh=gpu_mesh, in_specs=(),
out_specs=spec)()
else:
return jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape],
indexing='ij'),
axis=-1)
def normal_field(mesh_shape, seed, sharding=None):
"""Generate a Gaussian random field with the given power spectrum."""
gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is not None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(mesh_shape, sharding)
size = jax.device_count()
# rank = jax.process_index()
# process_index is multi_host only
# to make the code work both in multi host and single controller we can do this trick
keys = jax.random.split(seed, size)
spec = sharding.spec
x_axis, y_axis, single_axis = _axis_names(spec)
def normal(keys, shape, dtype):
idx = lax.axis_index(x_axis)
if not single_axis:
y_index = lax.axis_index(y_axis)
x_size = lax.psum(1, axis_name=x_axis)
idx += y_index * x_size
return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)
return shard_map(
partial(normal, shape=local_mesh_shape, dtype='float32'),
mesh=gpu_mesh,
in_specs=P(None),
out_specs=spec)(keys) # yapf: disable
else:
return jax.random.normal(shape=mesh_shape, key=seed)
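
The helpers above cover the whole distribution story: mesh-aware shard maps, FFTs, halo handling, and seeded field creation. Below is a minimal usage sketch (not part of the diff), assuming 8 local devices arranged as a 2x4 processor grid; `pdims` must divide the first two mesh axes.

```python
import jax
from jax.experimental.mesh_utils import create_device_mesh
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

from jaxpm.distributed import get_local_shape, normal_field, uniform_particles

mesh_shape = (256, 256, 256)
devices = create_device_mesh((2, 4))  # assumes 8 local devices
sharding = NamedSharding(Mesh(devices, axis_names=('x', 'y')), P('x', 'y'))

# Per-device slab of the global field: [128, 64, 256]
print(get_local_shape(mesh_shape, sharding))

# One independently seeded normal slab per device, from a single key
field = normal_field(mesh_shape, seed=jax.random.PRNGKey(0), sharding=sharding)

# A regular particle grid laid out with the same sharding
particles = uniform_particles(mesh_shape, sharding=sharding)

# sharding=None falls back to the plain single-device path
single = normal_field(mesh_shape, seed=jax.random.PRNGKey(0), sharding=None)
```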

jaxpm/growth.py

@@ -1,6 +1,6 @@
import jax.numpy as np
from jax.numpy import interp
from jax_cosmo.background import *
from jax_cosmo.scipy.interpolate import interp
from jax_cosmo.scipy.ode import odeint
@@ -587,5 +587,6 @@ def dGf2a(cosmo, a):
cache = cosmo._workspace['background.growth_factor']
f2p = cache['h2'] / cache['a'] * cache['g2']
f2p = interp(np.log(a), np.log(cache['a']), f2p)
E = E(cosmo, a)
return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f)
E_a = E(cosmo, a)
return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) +
3 * a**2 * E_a * D2f)

jaxpm/kernels.py

@@ -1,30 +1,46 @@
import jax.numpy as jnp
import numpy as np
from jax.lax import FftType
from jax.sharding import PartitionSpec as P
from jaxdecomp import fftfreq3d, get_output_specs
from jaxpm.distributed import autoshmap
def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
"""
Return wave-vectors for a given shape
"""
k = []
for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d])
kd *= 2 * np.pi
kdshape = np.ones(len(shape), dtype='int')
if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1]
kdshape[d] = len(kd)
kd = kd.reshape(kdshape)
k.append(kd.astype(dtype))
del kd, kdshape
return k
def fftk(k_array):
"""
Generate Fourier transform wave numbers for a given mesh.
Args:
k_array: Fourier-space array whose layout determines the frequencies.
Returns:
list: List of wave number arrays for each dimension in
the order [kx, ky, kz].
"""
kx, ky, kz = fftfreq3d(k_array)
# frequencies follow the order of dimensions of the transposed FFT
return kx, ky, kz
def interpolate_power_spectrum(input, k, pk, sharding=None):
pk_fn = lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape)
gpu_mesh = sharding.mesh if sharding is not None else None
specs = sharding.spec if sharding is not None else P()
out_specs = P(*get_output_specs(
FftType.FFT, specs, mesh=gpu_mesh)) if gpu_mesh is not None else P()
return autoshmap(pk_fn,
gpu_mesh=gpu_mesh,
in_specs=out_specs,
out_specs=out_specs)(input)
def gradient_kernel(kvec, direction, order=1):
"""
Computes the gradient kernel in the requested direction
Parameters
-----------
kvec: list
@@ -50,20 +66,27 @@ def gradient_kernel(kvec, direction, order=1):
return wts
def invlaplace_kernel(kvec):
def invlaplace_kernel(kvec, fd=False):
"""
Compute the inverse Laplace kernel
Compute the inverse Laplace kernel.
cf. [Feng+2016](https://arxiv.org/pdf/1603.00476)
Parameters
-----------
kvec: list
List of wave-vectors
fd: bool
Finite difference kernel
Returns
--------
wts: array
Complex kernel values
"""
if fd:
kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
else:
kk = sum(ki**2 for ki in kvec)
kk_nozeros = jnp.where(kk == 0, 1, kk)
return -jnp.where(kk == 0, 0, 1 / kk_nozeros)
@@ -79,12 +102,10 @@ def longrange_kernel(kvec, r_split):
List of wave-vectors
r_split: float
Splitting radius
Returns
--------
wts: array
Complex kernel values
TODO: @modichirag add documentation
"""
if r_split != 0:
@@ -105,13 +126,12 @@ def cic_compensation(kvec):
-----------
kvec: list
List of wave-vectors
Returns:
--------
wts: array
Complex kernel values
"""
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
kwts = [jnp.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
return wts
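
A hedged single-device sketch of the new kernel workflow: `fftk` now takes the transformed field itself (so the frequencies can follow the transposed, sharded layout) instead of a shape, and per the commit notes it returns normal-order frequencies on a single GPU.

```python
import jax

from jaxpm.distributed import fft3d, ifft3d
from jaxpm.kernels import fftk, gradient_kernel, invlaplace_kernel

field = jax.random.normal(jax.random.PRNGKey(0), (64, 64, 64))

delta_k = fft3d(field)  # jaxdecomp.pfft3d; promotes the real input to complex
kvec = fftk(delta_k)    # [kx, ky, kz] with broadcastable shapes

# Gravitational potential and one force component, as done in pm.py
pot_k = delta_k * invlaplace_kernel(kvec)
force_x = ifft3d(-gradient_kernel(kvec, 0) * pot_k)
print(force_x.shape)    # (64, 64, 64)
```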

jaxpm/painting.py

@@ -1,15 +1,24 @@
from functools import partial
import jax
import jax.lax as lax
import jax.numpy as jnp
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
from jaxpm.distributed import (autoshmap, fft3d, get_halo_size, halo_exchange,
ifft3d, slice_pad, slice_unpad)
from jaxpm.kernels import cic_compensation, fftk
from jaxpm.painting_utils import gather, scatter
def cic_paint(mesh, positions, weight=None):
def _cic_paint_impl(grid_mesh, positions, weight=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
displacement field: [nx, ny, nz, 3]
"""
positions = positions.reshape([-1, 3])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
@@ -19,40 +28,98 @@ def cic_paint(mesh, positions, weight=None):
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
if weight is not None:
if jnp.isscalar(weight):
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
else:
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
kernel)
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
jnp.array(mesh.shape))
jnp.array(grid_mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1,
2))
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 8]),
dnums)
mesh = lax.scatter_add(grid_mesh, neighboor_coords,
kernel.reshape([-1, 8]), dnums)
return mesh
def cic_read(mesh, positions):
@partial(jax.jit, static_argnums=(3, 4))
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
positions = positions.reshape((*grid_mesh.shape, 3))
halo_size, halo_extents = get_halo_size(halo_size, sharding)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding)
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
grid_mesh = autoshmap(_cic_paint_impl,
gpu_mesh=gpu_mesh,
in_specs=(spec, spec, P()),
out_specs=spec)(grid_mesh, positions, weight)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,
halo_periods=(True, True))
grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)
return grid_mesh
def _cic_read_impl(grid_mesh, positions):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
positions: [nx,ny,nz, 3]
"""
# Save original shape for reshaping output later
original_shape = positions.shape
# Reshape positions to a flat list of 3D coordinates
positions = positions.reshape([-1, 3])
# Expand dimensions to calculate neighbor coordinates
positions = jnp.expand_dims(positions, 1)
# Floor the positions to get the base grid cell for each particle
floor = jnp.floor(positions)
# Define connections to calculate all neighbor coordinates
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
# Calculate the 8 neighboring coordinates
neighboor_coords = floor + connection
# Calculate kernel weights based on distance from each neighboring coordinate
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
# Modulo operation to wrap around edges if necessary
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
jnp.array(mesh.shape))
jnp.array(grid_mesh.shape))
# Ensure grid_mesh shape is as expected
# Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel
return (grid_mesh[neighboor_coords[..., 0],
neighboor_coords[..., 1],
neighboor_coords[..., 2]] * kernel).sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
@partial(jax.jit, static_argnums=(2, 3))
def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
original_shape = positions.shape
positions = positions.reshape((*grid_mesh.shape, 3))
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,
halo_periods=(True, True))
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
displacement = autoshmap(_cic_read_impl,
gpu_mesh=gpu_mesh,
in_specs=(spec, spec),
out_specs=spec)(grid_mesh, positions)
return displacement.reshape(original_shape[:-1])
def cic_paint_2d(mesh, positions, weight):
@@ -84,6 +151,99 @@ def cic_paint_2d(mesh, positions, weight):
return mesh
def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1]
original_shape = displacements.shape
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
if not jnp.isscalar(weight):
if weight.shape != original_shape[:-1]:
raise ValueError("Weight shape must match particle shape")
else:
weight = weight.flatten()
# Padding is forced to be zero in a single gpu run
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
jnp.arange(particle_mesh.shape[1]),
jnp.arange(particle_mesh.shape[2]),
indexing='ij')
particle_mesh = jnp.pad(particle_mesh, halo_size)
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
return scatter(pmid.reshape([-1, 3]),
displacements.reshape([-1, 3]),
particle_mesh,
chunk_size=2**24,
val=weight)
@partial(jax.jit, static_argnums=(1, 2, 4))
def cic_paint_dx(displacements,
halo_size=0,
sharding=None,
weight=1.0,
chunk_size=2**24):
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
halo_size=halo_size,
weight=weight,
chunk_size=chunk_size),
gpu_mesh=gpu_mesh,
in_specs=spec,
out_specs=spec)(displacements)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,
halo_periods=(True, True))
grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)
return grid_mesh
def _cic_read_dx_impl(grid_mesh, disp, halo_size):
halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1]
original_shape = [
dim - 2 * halo[0] for dim, halo in zip(grid_mesh.shape, halo_size)
]
a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]),
jnp.arange(original_shape[1]),
jnp.arange(original_shape[2]),
indexing='ij')
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
pmid = pmid.reshape([-1, 3])
disp = disp.reshape([-1, 3])
return gather(pmid, disp, grid_mesh).reshape(original_shape)
@partial(jax.jit, static_argnums=(2, 3))
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,
halo_periods=(True, True))
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
displacements = autoshmap(partial(_cic_read_dx_impl, halo_size=halo_size),
gpu_mesh=gpu_mesh,
in_specs=(spec),
out_specs=spec)(grid_mesh, disp)
return displacements
def compensate_cic(field):
"""
Compensate for CiC painting
@@ -92,9 +252,8 @@ def compensate_cic(field):
Returns:
compensated_field
"""
nc = field.shape
kvec = fftk(nc)
delta_k = fft3d(field)
delta_k = jnp.fft.rfftn(field)
kvec = fftk(delta_k)
delta_k = cic_compensation(kvec) * delta_k
return jnp.fft.irfftn(delta_k)
return ifft3d(delta_k)
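
A short single-device sketch of the new wrappers (`sharding=None`, `halo_size=0`), showing both the absolute-position and the displacement-based paths; note that `cic_paint` expects positions reshapeable to `(*grid_mesh.shape, 3)`.

```python
import jax.numpy as jnp

from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx

mesh_shape = (32, 32, 32)
# One particle per cell, sitting on the cell corners
positions = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape],
                                   indexing='ij'),
                      axis=-1).astype('float32')

field = cic_paint(jnp.zeros(mesh_shape), positions)
print(field.sum())                   # total mass equals the particle count

values = cic_read(field, positions)  # trilinear read-back, shape (32, 32, 32)

# Displacement-based variants: particles are implicit, only offsets are stored
disp = jnp.zeros((*mesh_shape, 3))
field_dx = cic_paint_dx(disp)
values_dx = cic_read_dx(field_dx, disp)
```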

jaxpm/painting_utils.py (new file)

@@ -0,0 +1,190 @@
import jax
import jax.numpy as jnp
from jax.lax import scan
def _chunk_split(ptcl_num, chunk_size, *arrays):
"""Split and reshape particle arrays into chunks and remainders, with the remainders
preceding the chunks. 0D ones are duplicated as full arrays in the chunks."""
chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
remainder_size = ptcl_num % chunk_size
chunk_num = ptcl_num // chunk_size
remainder = None
chunks = arrays
if remainder_size:
remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays]
chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays]
# `scan` triggers errors in scatter and gather without the `full`
chunks = [
x.reshape(chunk_num, chunk_size, *x.shape[1:])
if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks
]
return remainder, chunks
def enmesh(base_indices, displacements, cell_size, base_shape, offset,
new_cell_size, new_shape):
"""Multilinear enmeshing."""
base_indices = jnp.asarray(base_indices)
displacements = jnp.asarray(displacements)
with jax.experimental.enable_x64():
cell_size = jnp.float64(
cell_size) if new_cell_size is not None else jnp.array(
cell_size, dtype=displacements.dtype)
if base_shape is not None:
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
offset = jnp.float64(offset)
if new_cell_size is not None:
new_cell_size = jnp.float64(new_cell_size)
if new_shape is not None:
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
spatial_dim = base_indices.shape[1]
neighbor_offsets = (
jnp.arange(2**spatial_dim, dtype=base_indices.dtype)[:, jnp.newaxis] >>
jnp.arange(spatial_dim, dtype=base_indices.dtype)) & 1
if new_cell_size is not None:
particle_positions = base_indices * cell_size + displacements - offset
particle_positions = particle_positions[:, jnp.
newaxis] # insert neighbor axis
new_indices = particle_positions + neighbor_offsets * new_cell_size # multilinear
if base_shape is not None:
grid_length = base_shape * cell_size
new_indices %= grid_length
new_indices //= new_cell_size
new_displacements = particle_positions - new_indices * new_cell_size
if base_shape is not None:
new_displacements -= jnp.rint(
new_displacements / grid_length
) * grid_length # also abs(new_displacements) < new_cell_size is expected
new_indices = new_indices.astype(base_indices.dtype)
new_displacements = new_displacements.astype(displacements.dtype)
new_cell_size = new_cell_size.astype(displacements.dtype)
new_displacements /= new_cell_size
else:
offset_indices, offset_displacements = jnp.divmod(offset, cell_size)
base_indices -= offset_indices.astype(base_indices.dtype)
displacements -= offset_displacements.astype(displacements.dtype)
# insert neighbor axis
base_indices = base_indices[:, jnp.newaxis]
displacements = displacements[:, jnp.newaxis]
# multilinear
displacements /= cell_size
new_indices = jnp.floor(displacements).astype(base_indices.dtype)
new_indices += neighbor_offsets
new_displacements = displacements - new_indices
new_indices += base_indices
if base_shape is not None:
new_indices %= base_shape
weights = 1 - jnp.abs(new_displacements)
if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None
new_indices = jnp.where(new_indices < 0, new_shape, new_indices)
weights = weights.prod(axis=-1)
return new_indices, weights
def _scatter_chunk(carry, chunk):
mesh, offset, cell_size, mesh_shape = carry
pmid, disp, val = chunk
spatial_ndim = pmid.shape[1]
spatial_shape = mesh.shape
# multilinear mesh indices and fractions
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
spatial_shape)
# scatter
ind = tuple(ind[..., i] for i in range(spatial_ndim))
mesh = mesh.at[ind].add(jnp.multiply(jnp.expand_dims(val, axis=-1), frac))
carry = mesh, offset, cell_size, mesh_shape
return carry, None
def scatter(pmid,
disp,
mesh,
chunk_size=2**24,
val=1.,
offset=0,
cell_size=1.):
ptcl_num, spatial_ndim = pmid.shape
val = jnp.asarray(val)
mesh = jnp.asarray(mesh)
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
carry = mesh, offset, cell_size, mesh.shape
if remainder is not None:
carry = _scatter_chunk(carry, remainder)[0]
carry = scan(_scatter_chunk, carry, chunks)[0]
mesh = carry[0]
return mesh
def _chunk_cat(remainder_array, chunked_array):
"""Reshape and concatenate one remainder and one chunked particle arrays."""
array = chunked_array.reshape(-1, *chunked_array.shape[2:])
if remainder_array is not None:
array = jnp.concatenate((remainder_array, array), axis=0)
return array
def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
ptcl_num, spatial_ndim = pmid.shape
mesh = jnp.asarray(mesh)
val = jnp.asarray(val)
if mesh.shape[spatial_ndim:] != val.shape[1:]:
raise ValueError('channel shape mismatch: '
f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}')
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
carry = mesh, offset, cell_size, mesh.shape
val_0 = None
if remainder is not None:
val_0 = _gather_chunk(carry, remainder)[1]
val = scan(_gather_chunk, carry, chunks)[1]
val = _chunk_cat(val_0, val)
return val
def _gather_chunk(carry, chunk):
mesh, offset, cell_size, mesh_shape = carry
pmid, disp, val = chunk
spatial_ndim = pmid.shape[1]
spatial_shape = mesh.shape[:spatial_ndim]
chan_ndim = mesh.ndim - spatial_ndim
chan_axis = tuple(range(-chan_ndim, 0))
# multilinear mesh indices and fractions
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
spatial_shape)
# gather
ind = tuple(ind[..., i] for i in range(spatial_ndim))
frac = jnp.expand_dims(frac, chan_axis)
val += (mesh.at[ind].get(mode='drop', fill_value=0) * frac).sum(axis=1)
return carry, val
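
The `scatter`/`gather` pair above is the chunked engine behind `cic_paint_dx` and `cic_read_dx`: `pmid` holds integer base cells and `disp` holds displacements in cell units. A small self-contained sketch (not part of the diff):

```python
import jax.numpy as jnp

from jaxpm.painting_utils import gather, scatter

nc = 16
mesh = jnp.zeros((nc, nc, nc))

# One particle per cell, displaced by half a cell along x
grid = jnp.stack(jnp.meshgrid(*[jnp.arange(nc)] * 3, indexing='ij'), axis=-1)
pmid = grid.reshape(-1, 3)
disp = jnp.zeros_like(pmid, dtype='float32').at[:, 0].set(0.5)

mesh = scatter(pmid, disp, mesh, chunk_size=2**10)  # mass-conserving CIC deposit
vals = gather(pmid, disp, mesh, chunk_size=2**10)   # trilinear read at same points
print(mesh.sum(), vals.shape)                       # ~ nc**3, (nc**3,)
```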

jaxpm/plotting.py (new file)

@@ -0,0 +1,129 @@
import matplotlib.pyplot as plt
import numpy as np
def plot_fields(fields_dict, sum_over=None):
"""
Plots sum projections of 3D fields along different axes,
slicing only the first `sum_over` elements along each axis.
Args:
- fields: list of 3D arrays representing fields to plot
- names: list of names for each field, used in titles
- sum_over: number of slices to sum along each axis (default: fields[0].shape[0] // 8)
"""
sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
nb_rows = len(fields_dict)
nb_cols = 3
fig, axes = plt.subplots(nb_rows, nb_cols, figsize=(15, 5 * nb_rows))
def plot_subplots(proj_axis, field, row, title):
slicing = [slice(None)] * field.ndim
slicing[proj_axis] = slice(None, sum_over)
slicing = tuple(slicing)
# Sum projection over the specified axis and plot
axes[row, proj_axis].imshow(
field[slicing].sum(axis=proj_axis) + 1,
cmap='magma',
extent=[0, field.shape[proj_axis], 0, field.shape[proj_axis]])
axes[row, proj_axis].set_xlabel('Mpc/h')
axes[row, proj_axis].set_ylabel('Mpc/h')
axes[row, proj_axis].set_title(title)
# Plot each field across the three axes
for i, (name, field) in enumerate(fields_dict.items()):
for proj_axis in range(3):
plot_subplots(proj_axis, field, i,
f"{name} projection {proj_axis}")
plt.tight_layout()
plt.show()
def plot_fields_single_projection(fields_dict,
sum_over=None,
project_axis=0,
vmin=None,
vmax=None,
colorbar=False):
"""
Plots a single projection (along axis 0) of 3D fields in a grid,
summing over the first `sum_over` elements along the 0-axis, with 4 images per row.
Args:
- fields_dict: dictionary where keys are field names and values are 3D arrays
- sum_over: number of slices to sum along the projection axis (default: fields[0].shape[0] // 8)
"""
sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
nb_fields = len(fields_dict)
nb_cols = 4 # Set number of images per row
nb_rows = (nb_fields + nb_cols - 1) // nb_cols # Calculate required rows
fig, axes = plt.subplots(nb_rows,
nb_cols,
figsize=(5 * nb_cols, 5 * nb_rows))
axes = np.atleast_2d(axes) # Ensure axes is always a 2D array
for i, (name, field) in enumerate(fields_dict.items()):
row, col = divmod(i, nb_cols)
# Define the slice for the 0-axis projection
slicing = [slice(None)] * field.ndim
slicing[project_axis] = slice(None, sum_over)
slicing = tuple(slicing)
# Sum projection over axis 0 and plot
a = axes[row,
col].imshow(field[slicing].sum(axis=project_axis) + 1,
cmap='magma',
extent=[0, field.shape[1], 0, field.shape[2]],
vmin=vmin,
vmax=vmax)
axes[row, col].set_xlabel('Mpc/h')
axes[row, col].set_ylabel('Mpc/h')
axes[row, col].set_title(f"{name} projection 0")
if colorbar:
fig.colorbar(a, ax=axes[row, col], shrink=0.7)
# Remove any empty subplots
for j in range(i + 1, nb_rows * nb_cols):
fig.delaxes(axes.flatten()[j])
plt.tight_layout()
plt.show()
def stack_slices(array):
"""
Stacks 2D slices of an array into a single array based on provided partition dimensions.
Args:
- array_slices: a 2D list of array slices (list of lists format) where
array_slices[i][j] is the slice located at row i, column j in the grid.
- pdims: a tuple representing the grid dimensions (rows, columns).
Returns:
- A single array constructed by stacking the slices.
"""
# Initialize an empty list to store the vertically stacked rows
pdims = array.sharding.mesh.devices.shape
field_slices = []
# Iterate over rows in pdims[0]
for i in range(pdims[0]):
row_slices = []
# Iterate over columns in pdims[1]
for j in range(pdims[1]):
slice_index = i * pdims[0] + j
row_slices.append(array.addressable_data(slice_index))
# Stack the current row of slices horizontally
stacked_row = np.hstack(row_slices)
field_slices.append(stacked_row)
# Stack all rows vertically to form the full array
full_array = np.vstack(field_slices)
return full_array
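
For example, after a run one might compare the initial and evolved fields like this (random arrays stand in for real simulation outputs; this sketch is not part of the diff):

```python
import jax

from jaxpm.plotting import plot_fields, plot_fields_single_projection

key = jax.random.PRNGKey(0)
fields = {
    "initial conditions": jax.random.normal(key, (64, 64, 64)),
    "lpt": jax.random.normal(jax.random.split(key)[0], (64, 64, 64)),
}
plot_fields(fields)                    # three axis projections per field
plot_fields_single_projection(fields)  # one projection, four panels per row
```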

jaxpm/pm.py

@@ -1,50 +1,92 @@
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jax_cosmo import Cosmology
from jaxpm.growth import growth_factor, growth_rate, dGfa, growth_factor_second, growth_rate_second, dGf2a
from jaxpm.kernels import PGD_kernel, fftk, gradient_kernel, invlaplace_kernel, longrange_kernel
from jaxpm.painting import cic_paint, cic_read
from jaxpm.distributed import fft3d, ifft3d, normal_field
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
growth_rate, growth_rate_second)
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
invlaplace_kernel, longrange_kernel)
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
def pm_forces(positions, mesh_shape, delta=None, r_split=0):
def pm_forces(positions,
mesh_shape=None,
delta=None,
r_split=0,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
"""
Computes gravitational forces on particles using a PM scheme
"""
if mesh_shape is None:
assert (delta is not None),\
"If mesh_shape is not provided, delta should be provided"
mesh_shape = delta.shape
if paint_absolute_pos:
paint_fn = lambda pos: cic_paint(jnp.zeros(shape=mesh_shape,
device=sharding),
pos,
halo_size=halo_size,
sharding=sharding)
read_fn = lambda grid_mesh, pos: cic_read(
grid_mesh, pos, halo_size=halo_size, sharding=sharding)
else:
paint_fn = lambda disp: cic_paint_dx(
disp, halo_size=halo_size, sharding=sharding)
read_fn = lambda grid_mesh, disp: cic_read_dx(
grid_mesh, disp, halo_size=halo_size, sharding=sharding)
if delta is None:
delta_k = jnp.fft.rfftn(cic_paint(jnp.zeros(mesh_shape), positions))
field = paint_fn(positions)
delta_k = fft3d(field)
elif jnp.isrealobj(delta):
delta_k = jnp.fft.rfftn(delta)
delta_k = fft3d(delta)
else:
delta_k = delta
kvec = fftk(delta_k)
# Computes gravitational potential
kvec = fftk(mesh_shape)
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split)
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(
kvec, r_split=r_split)
# Computes gravitational forces
return jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i) * pot_k), positions)
for i in range(3)], axis=-1)
forces = jnp.stack([
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions
) for i in range(3)], axis=-1) # yapf: disable
return forces
def lpt(cosmo:Cosmology, init_mesh, positions, a, order=1):
def lpt(cosmo,
initial_conditions,
particles=None,
a=0.1,
halo_size=0,
sharding=None,
order=1):
"""
Computes first and second order LPT displacement and momentum,
e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
"""
paint_absolute_pos = particles is not None
if particles is None:
particles = jnp.zeros_like(initial_conditions,
shape=(*initial_conditions.shape, 3))
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
delta_k = jnp.fft.rfftn(init_mesh) # TODO: pass the modes directly to save one or two fft?
mesh_shape = init_mesh.shape
init_force = pm_forces(positions, mesh_shape, delta=delta_k)
dx = growth_factor(cosmo, a) * init_force
delta_k = fft3d(initial_conditions)
initial_force = pm_forces(particles,
delta=delta_k,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding)
dx = growth_factor(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * E * dx
f = a**2 * E * dGfa(cosmo, a) * init_force
f = a**2 * E * dGfa(cosmo, a) * initial_force
if order == 2:
kvec = fftk(mesh_shape)
kvec = fftk(delta_k)
pot_k = delta_k * invlaplace_kernel(kvec)
delta2 = 0
@@ -54,7 +96,7 @@ def lpt(cosmo:Cosmology, init_mesh, positions, a, order=1):
# Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
# shear_ii = jnp.fft.irfftn(- ki**2 * pot_k)
nabla_i_nabla_i = gradient_kernel(kvec, i)**2
shear_ii = jnp.fft.irfftn(nabla_i_nabla_i * pot_k)
shear_ii = ifft3d(nabla_i_nabla_i * pot_k)
delta2 += shear_ii * shear_acc
shear_acc += shear_ii
@@ -62,10 +104,16 @@
for j in range(i + 1, 3):
# Substract squared strict-up-triangle terms
# delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(kvec, j)
delta2 -= jnp.fft.irfftn(nabla_i_nabla_j * pot_k)**2
nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(
kvec, j)
delta2 -= ifft3d(nabla_i_nabla_j * pot_k)**2
init_force2 = pm_forces(positions, mesh_shape, delta=jnp.fft.rfftn(delta2))
delta_k2 = fft3d(delta2)
init_force2 = pm_forces(particles,
delta=delta_k2,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding)
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
dx2 = 3 / 7 * growth_factor_second(cosmo, a) * init_force2
p2 = a**2 * growth_rate_second(cosmo, a) * E * dx2
@@ -78,23 +126,28 @@ def lpt(cosmo:Cosmology, init_mesh, positions, a, order=1):
return dx, p, f
def linear_field(mesh_shape, box_size, pk, seed):
def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
"""
Generate initial conditions.
"""
kvec = fftk(mesh_shape)
# Initialize a random field with one slice on each gpu
field = normal_field(mesh_shape, seed=seed, sharding=sharding)
field = fft3d(field)
kvec = fftk(field)
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
for i, kk in enumerate(kvec))**0.5
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
box_size[0] * box_size[1] * box_size[2])
field = jax.random.normal(seed, mesh_shape)
field = jnp.fft.rfftn(field) * pkmesh**0.5
field = jnp.fft.irfftn(field)
field = field * (pkmesh)**0.5
field = ifft3d(field)
return field
def make_ode_fn(mesh_shape):
def make_ode_fn(mesh_shape,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
def nbody_ode(state, a, cosmo):
"""
@@ -102,7 +155,11 @@ def make_ode_fn(mesh_shape):
"""
pos, vel = state
forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m
forces = pm_forces(pos,
mesh_shape=mesh_shape,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding) * 1.5 * cosmo.Omega_m
# Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
@@ -114,16 +171,24 @@
return nbody_ode
def get_ode_fn(cosmo:Cosmology, mesh_shape):
def make_diffrax_ode(cosmo,
mesh_shape,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
def nbody_ode(a, state, args):
"""
State is an array [position, velocities]
Compatible with [Diffrax API](https://docs.kidger.site/diffrax/)
state is a tuple (position, velocities)
"""
pos, vel = state
forces = pm_forces(pos, mesh_shape) * 1.5 * cosmo.Omega_m
forces = pm_forces(pos,
mesh_shape=mesh_shape,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding) * 1.5 * cosmo.Omega_m
# Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
@@ -145,16 +210,19 @@ def pgd_correction(pos, mesh_shape, params):
pos: particle positions [npart, 3]
params: [alpha, kl, ks] pgd parameters
"""
kvec = fftk(mesh_shape)
delta = cic_paint(jnp.zeros(mesh_shape), pos)
delta_k = fft3d(delta)
kvec = fftk(delta_k)
alpha, kl, ks = params
delta_k = jnp.fft.rfftn(delta)
PGD_range = PGD_kernel(kvec, kl, ks)
pot_k_pgd = (delta_k * invlaplace_kernel(kvec)) * PGD_range
forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i)*pot_k_pgd), pos)
for i in range(3)],axis=-1)
forces_pgd = jnp.stack([
cic_read(ifft3d(-gradient_kernel(kvec, i) * pot_k_pgd), pos)
for i in range(3)
],
axis=-1)
dpos_pgd = forces_pgd * alpha
@@ -162,27 +230,30 @@
def make_neural_ode_fn(model, mesh_shape):
def neural_nbody_ode(state, a, cosmo: Cosmology, params):
"""
state is a tuple (position, velocities)
"""
pos, vel = state
kvec = fftk(mesh_shape)
delta = cic_paint(jnp.zeros(mesh_shape), pos)
delta_k = jnp.fft.rfftn(delta)
delta_k = fft3d(delta)
kvec = fftk(delta_k)
# Computes gravitational potential
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, r_split=0)
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec,
r_split=0)
# Apply a correction filter
kk = jnp.sqrt(sum((ki / jnp.pi)**2 for ki in kvec))
pot_k = pot_k * (1. + model.apply(params, kk, jnp.atleast_1d(a)))
# Computes gravitational forces
forces = jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i)*pot_k), pos)
for i in range(3)],axis=-1)
forces = jnp.stack([
cic_read(ifft3d(-gradient_kernel(kvec, i) * pot_k), pos)
for i in range(3)
],
axis=-1)
forces = forces * 1.5 * cosmo.Omega_m
@@ -193,4 +264,5 @@ def make_neural_ode_fn(model, mesh_shape):
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
return dpos, dvel
return neural_nbody_ode
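
The multi-host script later in this diff drives this API across nodes; condensed to a single device (`sharding=None`, `halo_size=0`) the same flow looks roughly like this (a sketch, not part of the diff):

```python
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve

from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_diffrax_ode

mesh_shape = (64, 64, 64)
box_size = (64., 64., 64.)
cosmo = jc.Planck15(Omega_c=0.25, sigma8=0.8)

k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(cosmo, k)
pk_fn = lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape)

ic = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))

# particles=None selects the displacement-based (paint_absolute_pos=False) path
dx, p, _ = lpt(cosmo, ic, a=0.1, order=1)

term = ODETerm(make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
res = diffeqsolve(term, Dopri5(), t0=0.1, t1=1., dt0=0.01,
                  y0=jnp.stack([dx, p], axis=0), args=cosmo,
                  saveat=SaveAt(t1=True),
                  stepsize_controller=PIDController(rtol=1e-4, atol=1e-4))
final_field = cic_paint_dx(res.ys[-1][0])
```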

jaxpm/utils.py

@@ -1,89 +1,161 @@
from functools import partial
import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm
from scipy.special import legendre
__all__ = ['power_spectrum']
def _initialize_pk(shape, boxsize, kmin, dk):
"""
Helper function to initialize various (fixed) values for powerspectra... not differentiable!
"""
I = np.eye(len(shape), dtype='int') * -2 + 1
W = np.empty(shape, dtype='f4')
W[...] = 2.0
W[..., 0] = 1.0
W[..., -1] = 1.0
kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
kedges = np.arange(kmin, kmax, dk)
k = [
np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
__all__ = [
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
'cross_correlation_coefficients', 'gaussian_smoothing'
]
kmag = sum(ki**2 for ki in k)**0.5
xsum = np.zeros(len(kedges) + 1)
Nsum = np.zeros(len(kedges) + 1)
dig = np.digitize(kmag.flat, kedges)
xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
return dig, Nsum, xsum, W, k, kedges
def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
def _initialize_pk(mesh_shape, box_shape, kedges, los):
"""
Calculate the powerspectra given real space field
Args:
field: real valued field
kmin: minimum k-value for binned powerspectra
dk: differential in each kbin
boxsize: length of each boxlength (can be strangly shaped?)
Returns:
kbins: the central value of the bins for plotting
power: real valued array of power in each bin
Parameters
----------
mesh_shape : tuple of int
Shape of the mesh grid.
box_shape : tuple of float
Physical dimensions of the box.
kedges : None, int, float, or list
If None, set dk to twice the minimum.
If int, specifies number of edges.
If float, specifies dk.
los : array_like
Line-of-sight vector.
Returns
-------
dig : ndarray
Indices of the bins to which each value in input array belongs.
kcount : ndarray
Count of values in each bin.
kedges : ndarray
Edges of the bins.
mumesh : ndarray
Mu values for the mesh grid.
"""
shape = field.shape
nx, ny, nz = shape
kmax = np.pi * np.min(mesh_shape / box_shape) # = knyquist
#initialze values related to powerspectra (mode bins and weights)
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
if isinstance(kedges, None | int | float):
if kedges is None:
dk = 2 * np.pi / np.min(
box_shape) * 2 # twice the minimum wavenumber
if isinstance(kedges, int):
dk = kmax / (kedges + 1) # final number of bins will be kedges-1
elif isinstance(kedges, float):
dk = kedges
kedges = np.arange(dk, kmax, dk) + dk / 2 # from dk/2 to kmax-dk/2
#fast fourier transform
fft_image = jnp.fft.fftn(field)
kshapes = np.eye(len(mesh_shape), dtype=np.int32) * -2 + 1
kvec = [(2 * np.pi * m / l) * np.fft.fftfreq(m).reshape(kshape)
for m, l, kshape in zip(mesh_shape, box_shape, kshapes)]
kmesh = sum(ki**2 for ki in kvec)**0.5
#absolute value of fast fourier transform
pk = jnp.real(fft_image * jnp.conj(fft_image))
dig = np.digitize(kmesh.reshape(-1), kedges)
kcount = np.bincount(dig, minlength=len(kedges) + 1)
#calculating powerspectra
real = jnp.real(pk).reshape([-1])
imag = jnp.imag(pk).reshape([-1])
# Central value of each bin
# kavg = (kedges[1:] + kedges[:-1]) / 2
kavg = np.bincount(
dig, weights=kmesh.reshape(-1), minlength=len(kedges) + 1) / kcount
kavg = kavg[1:-1]
Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
length=xsum.size) * 1j
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
if los is None:
mumesh = 1.
else:
mumesh = sum(ki * losi for ki, losi in zip(kvec, los))
kmesh_nozeros = np.where(kmesh == 0, 1, kmesh)
mumesh = np.where(kmesh == 0, 0, mumesh / kmesh_nozeros)
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
#normalization for powerspectra
norm = np.prod(np.array(shape[:])).astype('float32')**2
#find central values of each bin
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
return kbins, P / norm
return dig, kcount, kavg, mumesh
def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False):
def power_spectrum(mesh,
mesh2=None,
box_shape=None,
kedges: int | float | list = None,
multipoles=0,
los=[0., 0., 1.]):
"""
Compute the auto and cross spectrum of 3D fields, with multipoles.
"""
# Initialize
mesh_shape = np.array(mesh.shape)
if box_shape is None:
box_shape = mesh_shape
else:
box_shape = np.asarray(box_shape)
if multipoles == 0:
los = None
else:
los = np.asarray(los)
los = los / np.linalg.norm(los)
poles = np.atleast_1d(multipoles)
dig, kcount, kavg, mumesh = _initialize_pk(mesh_shape, box_shape, kedges,
los)
n_bins = len(kavg) + 2
# FFTs
meshk = jnp.fft.fftn(mesh, norm='ortho')
if mesh2 is None:
mmk = meshk.real**2 + meshk.imag**2
else:
mmk = meshk * jnp.fft.fftn(mesh2, norm='ortho').conj()
# Sum powers
pk = jnp.empty((len(poles), n_bins))
for i_ell, ell in enumerate(poles):
weights = (mmk * (2 * ell + 1) * legendre(ell)(mumesh)).reshape(-1)
if mesh2 is None:
psum = jnp.bincount(dig, weights=weights, length=n_bins)
else: # XXX: bincount is really slow with complex numbers
psum_real = jnp.bincount(dig, weights=weights.real, length=n_bins)
psum_imag = jnp.bincount(dig, weights=weights.imag, length=n_bins)
psum = (psum_real**2 + psum_imag**2)**.5
pk = pk.at[i_ell].set(psum)
# Normalization and conversion from cell units to [Mpc/h]^3
pk = (pk / kcount)[:, 1:-1] * (box_shape / mesh_shape).prod()
# pk = jnp.concatenate([kavg[None], pk])
if np.ndim(multipoles) == 0:
return kavg, pk[0]
else:
return kavg, pk
def transfer(mesh0, mesh1, box_shape, kedges: int | float | list = None):
pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
ks, pk0 = pk_fn(mesh0)
ks, pk1 = pk_fn(mesh1)
return ks, (pk1 / pk0)**.5
def coherence(mesh0, mesh1, box_shape, kedges: int | float | list = None):
pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
ks, pk01 = pk_fn(mesh0, mesh1)
ks, pk0 = pk_fn(mesh0)
ks, pk1 = pk_fn(mesh1)
return ks, pk01 / (pk0 * pk1)**.5
def pktranscoh(mesh0, mesh1, box_shape, kedges: int | float | list = None):
pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
ks, pk01 = pk_fn(mesh0, mesh1)
ks, pk0 = pk_fn(mesh0)
ks, pk1 = pk_fn(mesh1)
return ks, pk0, pk1, (pk1 / pk0)**.5, pk01 / (pk0 * pk1)**.5
def cross_correlation_coefficients(field_a,
field_b,
kmin=5,
dk=0.5,
boxsize=False):
"""
Calculate the cross correlation coefficients given two real space field
@@ -118,7 +190,8 @@ def cross_correlation_coefficients(field_a, field_b, kmin=5, dk=0.5, boxsize=False):
real = jnp.real(pk).reshape([-1])
imag = jnp.imag(pk).reshape([-1])
Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j
Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
length=xsum.size) * 1j
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
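
A usage sketch for the new spectra utilities (assuming they live in `jaxpm.utils`, as in the import-free diff above; random fields stand in for simulation outputs). A float `kedges` is interpreted as the bin width `dk`:

```python
import jax

from jaxpm.utils import coherence, power_spectrum, transfer

shape = (64, 64, 64)
box = (256., 256., 256.)  # Mpc/h
field_a = jax.random.normal(jax.random.PRNGKey(0), shape)
field_b = field_a + 0.1 * jax.random.normal(jax.random.PRNGKey(1), shape)

kavg, pk_a = power_spectrum(field_a, box_shape=box, kedges=0.05)            # auto
kavg, pk_ab = power_spectrum(field_a, field_b, box_shape=box, kedges=0.05)  # cross

ks, t = transfer(field_a, field_b, box_shape=box, kedges=0.05)   # (P_b/P_a)^1/2
ks, r = coherence(field_a, field_b, box_shape=box, kedges=0.05)  # P_ab/(P_a P_b)^1/2
```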

multi-host PM script (new file)

@@ -0,0 +1,179 @@
import os
os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
import jax
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()
if rank == 0:
print(f"SIZE is {jax.device_count()}")
import argparse
from functools import partial
import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
PIDController, SaveAt, diffeqsolve)
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
all_gather = partial(process_allgather, tiled=True)
def parse_arguments():
parser = argparse.ArgumentParser(
description="Run a cosmological simulation with JAX.")
parser.add_argument(
"-p",
"--pdims",
type=int,
nargs=2,
default=[1, jax.device_count()],
help="Processor grid dimensions as two integers (e.g., 2 4).")
parser.add_argument(
"-m",
"--mesh_shape",
type=int,
nargs=3,
default=[512, 512, 512],
help="Shape of the simulation mesh as three values (e.g., 512 512 512)."
)
parser.add_argument(
"-b",
"--box_size",
type=float,
nargs=3,
default=[500.0, 500.0, 500.0],
help=
"Box size of the simulation as three values (e.g., 500.0 500.0 1000.0)."
)
parser.add_argument(
"-st",
"--snapshots",
type=int,
default=2,
help="Number of snapshots to save during the simulation.")
parser.add_argument("-H",
"--halo_size",
type=int,
default=64,
help="Halo size for the simulation.")
parser.add_argument("-s",
"--solver",
type=str,
choices=['leapfrog', 'dopri8'],
default='leapfrog',
help="ODE solver choice: 'leapfrog' or 'dopri8'.")
return parser.parse_args()
def create_mesh_and_sharding(pdims):
devices = create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
return mesh, sharding
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
solver_choice, nb_snapshots, sharding):
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
initial_conditions = linear_field(mesh_shape,
box_size,
pk_fn,
seed=jax.random.PRNGKey(0),
sharding=sharding)
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
dx, p, _ = lpt(cosmo,
initial_conditions,
a=0.1,
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
# Choose solver
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
stepsize_controller = ConstantStepSize(
) if solver_choice == "leapfrog" else PIDController(rtol=1e-5, atol=1e-5)
res = diffeqsolve(ode_fn,
solver,
t0=0.1,
t1=1.,
dt0=0.01,
y0=jnp.stack([dx, p], axis=0),
args=cosmo,
saveat=SaveAt(ts=jnp.linspace(0.2, 1., nb_snapshots)),
stepsize_controller=stepsize_controller)
ode_fields = [
cic_paint_dx(sol[0], halo_size=halo_size, sharding=sharding)
for sol in res.ys
]
lpt_field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding)
return initial_conditions, lpt_field, ode_fields, res.stats
def main():
args = parse_arguments()
mesh_shape = args.mesh_shape
box_size = args.box_size
halo_size = args.halo_size
solver_choice = args.solver
nb_snapshots = args.snapshots
    _, sharding = create_mesh_and_sharding(args.pdims)
initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(
0.25, 0.8, tuple(mesh_shape), tuple(box_size), halo_size,
solver_choice, nb_snapshots, sharding)
if rank == 0:
os.makedirs("fields", exist_ok=True)
print(f"[{rank}] Simulation done")
print(f"Solver stats: {solver_stats}")
# Save initial conditions
initial_conditions_g = all_gather(initial_conditions)
if rank == 0:
print(f"[{rank}] Saving initial_conditions")
np.save("fields/initial_conditions.npy", initial_conditions_g)
print(f"[{rank}] initial_conditions saved")
del initial_conditions_g, initial_conditions
# Save LPT displacements
lpt_displacements_g = all_gather(lpt_displacements)
if rank == 0:
print(f"[{rank}] Saving lpt_displacements")
np.save("fields/lpt_displacements.npy", lpt_displacements_g)
print(f"[{rank}] lpt_displacements saved")
del lpt_displacements_g, lpt_displacements
# Save each ODE solution separately
for i, sol in enumerate(ode_solutions):
sol_g = all_gather(sol)
if rank == 0:
print(f"[{rank}] Saving ode_solution_{i}")
np.save(f"fields/ode_solution_{i}.npy", sol_g)
print(f"[{rank}] ode_solution_{i} saved")
del sol_g
if __name__ == "__main__":
main()
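Once the job completes, the gathered fields can be inspected on a single machine; a minimal post-processing sketch (matplotlib is an assumption, not a dependency of the script above):

```python
import matplotlib.pyplot as plt
import numpy as np

# Load the fields written by the script (paths as saved above).
ic = np.load("fields/initial_conditions.npy")
final = np.load("fields/ode_solution_1.npy")  # last snapshot with the default -st 2

# Project along one axis to get 2D density maps.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(ic.sum(axis=0))
axes[0].set_title("Initial conditions")
axes[1].imshow(final.sum(axis=0))
axes[1].set_title("Final field")
fig.savefig("fields/projections.png")
```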


@ -0,0 +1 @@
c4a44973e4f11841a8c14f4d200e7e87887419aa

notebooks/README.md

@ -0,0 +1,39 @@
# Particle Mesh Simulation with JAXPM on Multi-GPU and Multi-Host Systems
This collection of notebooks demonstrates how to perform Particle Mesh (PM) simulations using **JAXPM**, leveraging JAX for efficient computation on multi-GPU and multi-host systems. Each notebook progressively covers different setups, from single-GPU simulations to advanced, distributed, multi-host simulations across multiple nodes.
## Table of Contents
1. **[Single-GPU Particle Mesh Simulation](01-Introduction.ipynb)**
- Introduction to basic PM simulations on a single GPU.
- Uses JAXPM to run simulations with absolute particle positions and Cloud-in-Cell (CIC) painting.
2. **[Advanced Particle Mesh Simulation on a Single GPU](02-Advanced_usage.ipynb)**
   - Explores the use of Diffrax solvers in the ODE integration step.
   - Explores second-order Lagrangian Perturbation Theory (LPT) simulations.
   - Introduces weighted density field projections.
3. **[Multi-GPU Particle Mesh Simulation with Halo Exchange](03-MultiGPU_PM_Halo.ipynb)**
- Extends PM simulation to multi-GPU setups with halo exchange.
   - Uses sharding and device mesh configurations to manage distributed data across GPUs (a minimal sketch follows this list).
4. **[Multi-GPU Particle Mesh Simulation with Advanced Solvers](04-MultiGPU_PM_Solvers.ipynb)**
- Compares different ODE solvers (Leapfrog and Dopri5) in multi-GPU simulations.
- Highlights performance, memory considerations, and solver impact on simulation quality.
5. **[Multi-Host Particle Mesh Simulation](05-MultiHost_PM.ipynb)**
- Extends PM simulations to multi-host, multi-GPU setups for large-scale simulations.
- Guides through job submission, device initialization, and retrieving results across nodes.
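A minimal sketch of the device mesh and sharding setup that the multi-GPU notebooks build on (the 1 x N processor grid is an illustrative choice; the notebooks pick their own layouts):

```python
import jax
from jax.experimental.mesh_utils import create_device_mesh
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

# Arrange all local devices in a 1 x N grid and shard arrays over its two axes.
devices = create_device_mesh((1, jax.device_count()))
mesh = Mesh(devices, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
```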
## Getting Started
Each notebook includes installation instructions and guidelines for configuring JAXPM and required dependencies. Follow the setup instructions in each notebook to ensure an optimal environment.
## Requirements
- **JAXPM** (included in the installation commands within notebooks)
- **Diffrax** for ODE solvers
- **JAX** with CUDA support for multi-GPU or TPU setups
- **SLURM** for job scheduling on clusters (if running multi-host setups)
> **Note**: These notebooks are tested on the **Jean Zay** supercomputer and may require configuration changes for different HPC clusters.

pyproject.toml

@ -0,0 +1,30 @@
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "JaxPM"
version = "0.0.1"
description = "A dead simple FastPM implementation in JAX"
authors = [{ name = "JaxPM developers" }]
readme = "README.md"
requires-python = ">=3.9"
license = { file = "LICENSE" }
urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/JaxPM" }
dependencies = ["jax_cosmo", "jax>=0.4.30", "jaxdecomp>=0.2.2"]
[project.optional-dependencies]
test = [
"jax>=0.4.30",
"numpy",
"jax_cosmo",
"jaxdecomp>=0.2.2",
"pytest>=8.0.0",
"pfft-python @ git+https://github.com/MP-Gadget/pfft-python",
"pmesh @ git+https://github.com/MP-Gadget/pmesh",
"fastpm @ git+https://github.com/ASKabalan/fastpm-python",
"diffrax"
]
[tool.setuptools]
packages = ["jaxpm"]

pytest.ini

@ -0,0 +1,4 @@
[pytest]
markers =
distributed: mark a test as distributed
single_device: mark a test as single_device
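These markers let the suite be filtered by environment: `pytest -m single_device` runs the FastPM comparison tests anywhere, while `pytest -m distributed` targets the multi-device test that `tests/conftest.py` configures below.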

setup.py

@ -1,11 +0,0 @@
from setuptools import find_packages, setup
setup(
name='JaxPM',
version='0.0.1',
url='https://github.com/DifferentiableUniverseInitiative/JaxPM',
author='JaxPM developers',
description='A dead simple FastPM implementation in JAX',
packages=find_packages(),
install_requires=['jax', 'jax_cosmo'],
)

tests/conftest.py

@ -0,0 +1,175 @@
# Parameterized fixture for mesh_shape
import os
import pytest
os.environ["EQX_ON_ERROR"] = "nan"
setup_done = False
on_cluster = False
def is_on_cluster():
global on_cluster
return on_cluster
def initialize_distributed():
global setup_done
global on_cluster
if not setup_done:
if "SLURM_JOB_ID" in os.environ:
on_cluster = True
print("Running on cluster")
import jax
jax.distributed.initialize()
setup_done = True
on_cluster = True
else:
print("Running locally")
setup_done = True
on_cluster = False
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ[
"XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
import jax
@pytest.fixture(
scope="session",
params=[
((32, 32, 32), (256., 256., 256.)), # BOX
((32, 32, 64), (256., 256., 512.)), # RECTANGULAR
])
def simulation_config(request):
return request.param
@pytest.fixture(scope="session", params=[0.1, 0.5, 0.8])
def lpt_scale_factor(request):
return request.param
@pytest.fixture(scope="session")
def cosmo():
from functools import partial
from jax_cosmo import Cosmology
Planck18 = partial(
Cosmology,
# Omega_m = 0.3111
Omega_c=0.2607,
Omega_b=0.0490,
Omega_k=0.0,
h=0.6766,
n_s=0.9665,
sigma8=0.8102,
w0=-1.0,
wa=0.0,
)
return Planck18()
@pytest.fixture(scope="session")
def particle_mesh(simulation_config):
from pmesh.pm import ParticleMesh
mesh_shape, box_shape = simulation_config
return ParticleMesh(BoxSize=box_shape, Nmesh=mesh_shape, dtype='f4')
@pytest.fixture(scope="session")
def fpm_initial_conditions(cosmo, particle_mesh):
import jax_cosmo as jc
import numpy as np
from jax import numpy as jnp
# Generate initial particle positions
grid = particle_mesh.generate_uniform_particle_grid(shift=0).astype(
np.float32)
# Interpolate with linear_matter spectrum to get initial density field
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(cosmo, k)
def pk_fn(x):
return jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)
whitec = particle_mesh.generate_whitenoise(42,
type='complex',
unitary=False)
lineark = whitec.apply(lambda k, v: pk_fn(sum(ki**2 for ki in k)**0.5)**0.5
* v * (1 / v.BoxSize).prod()**0.5)
init_mesh = lineark.c2r().value # XXX
return lineark, grid, init_mesh
@pytest.fixture(scope="session")
def initial_conditions(fpm_initial_conditions):
_, _, init_mesh = fpm_initial_conditions
return init_mesh
@pytest.fixture(scope="session")
def solver(cosmo, particle_mesh):
from fastpm.core import Cosmology as FastPMCosmology
from fastpm.core import Solver
ref_cosmo = FastPMCosmology(cosmo)
return Solver(particle_mesh, ref_cosmo, B=1)
@pytest.fixture(scope="session")
def fpm_lpt1(solver, fpm_initial_conditions, lpt_scale_factor):
lineark, grid, _ = fpm_initial_conditions
statelpt = solver.lpt(lineark, grid, lpt_scale_factor, order=1)
return statelpt
@pytest.fixture(scope="session")
def fpm_lpt1_field(fpm_lpt1, particle_mesh):
return particle_mesh.paint(fpm_lpt1.X).value
@pytest.fixture(scope="session")
def fpm_lpt2(solver, fpm_initial_conditions, lpt_scale_factor):
lineark, grid, _ = fpm_initial_conditions
statelpt = solver.lpt(lineark, grid, lpt_scale_factor, order=2)
return statelpt
@pytest.fixture(scope="session")
def fpm_lpt2_field(fpm_lpt2, particle_mesh):
return particle_mesh.paint(fpm_lpt2.X).value
@pytest.fixture(scope="session")
def nbody_from_lpt1(solver, fpm_lpt1, particle_mesh, lpt_scale_factor):
import numpy as np
from fastpm.core import leapfrog
if lpt_scale_factor == 0.8:
pytest.skip("Do not run nbody simulation from scale factor 0.8")
stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)
finalstate = solver.nbody(fpm_lpt1, leapfrog(stages))
fpm_mesh = particle_mesh.paint(finalstate.X).value
return fpm_mesh
@pytest.fixture(scope="session")
def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
import numpy as np
from fastpm.core import leapfrog
if lpt_scale_factor == 0.8:
pytest.skip("Do not run nbody simulation from scale factor 0.8")
stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)
finalstate = solver.nbody(fpm_lpt2, leapfrog(stages))
fpm_mesh = particle_mesh.paint(finalstate.X).value
return fpm_mesh
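All of these fixtures are session-scoped, so the FastPM reference data (initial conditions, LPT states, N-body evolutions) is computed once per mesh shape, box size, and scale factor combination and reused across the comparison tests below.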

tests/helpers.py

@ -0,0 +1,13 @@
import jax.numpy as jnp
def MSE(x, y):
return jnp.mean((x - y)**2)
def MSE_3D(x, y):
return ((x - y)**2).mean(axis=0)
def MSRE(x, y):
return jnp.mean(((x - y) / y)**2)
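A quick sanity check of the metrics (illustrative values):

```python
import jax.numpy as jnp
from helpers import MSE, MSRE

x = jnp.array([1.0, 2.0, 4.0])
y = jnp.array([1.0, 2.0, 2.0])
print(MSE(x, y))   # mean((x - y)**2) = 4/3
print(MSRE(x, y))  # mean(((x - y) / y)**2) = 1/3
```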

tests/test_against_fpm.py

@ -0,0 +1,155 @@
import pytest
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
from helpers import MSE, MSRE
from jax import numpy as jnp
from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_dx
from jaxpm.pm import lpt, make_diffrax_ode
from jaxpm.utils import power_spectrum
_TOLERANCE = 1e-4
_PM_TOLERANCE = 1e-3
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
fpm_lpt1_field, fpm_lpt2_field, cosmo, order):
mesh_shape, box_shape = simulation_config
cosmo._workspace = {}
particles = uniform_particles(mesh_shape)
# Initial displacement
dx, _, _ = lpt(cosmo,
initial_conditions,
particles,
a=lpt_scale_factor,
order=order)
fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field
lpt_field = cic_paint(jnp.zeros(mesh_shape), particles + dx)
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
fpm_lpt1_field, fpm_lpt2_field, cosmo, order):
mesh_shape, box_shape = simulation_config
cosmo._workspace = {}
# Initial displacement
dx, _, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
lpt_field = cic_paint_dx(dx)
fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
def test_nbody_absolute(simulation_config, initial_conditions,
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
cosmo, order):
mesh_shape, box_shape = simulation_config
cosmo._workspace = {}
particles = uniform_particles(mesh_shape)
# Initial displacement
dx, p, _ = lpt(cosmo,
initial_conditions,
particles,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
solver = Dopri5()
controller = PIDController(rtol=1e-8,
atol=1e-8,
pcoeff=0.4,
icoeff=1,
dcoeff=0)
saveat = SaveAt(t1=True)
y0 = jnp.stack([particles + dx, p])
solutions = diffeqsolve(ode_fn,
solver,
t0=lpt_scale_factor,
t1=1.0,
dt0=None,
y0=y0,
stepsize_controller=controller,
saveat=saveat)
final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0])
fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
def test_nbody_relative(simulation_config, initial_conditions,
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
cosmo, order):
mesh_shape, box_shape = simulation_config
cosmo._workspace = {}
# Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
solver = Dopri5()
controller = PIDController(rtol=1e-9,
atol=1e-9,
pcoeff=0.4,
icoeff=1,
dcoeff=0)
saveat = SaveAt(t1=True)
y0 = jnp.stack([dx, p])
solutions = diffeqsolve(ode_fn,
solver,
t0=lpt_scale_factor,
t1=1.0,
dt0=None,
y0=y0,
stepsize_controller=controller,
saveat=saveat)
final_field = cic_paint_dx(solutions.ys[-1, 0])
fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
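The LPT comparisons are held to the tight `_TOLERANCE = 1e-4`, while the full N-body runs use the looser `_PM_TOLERANCE = 1e-3`, presumably to absorb the small differences that accumulate between the Diffrax adaptive integrator and FastPM's leapfrog over the whole evolution.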

tests/test_distributed_pm.py

@ -0,0 +1,152 @@
from conftest import initialize_distributed
initialize_distributed()  # must run before importing jax, hence the noqa: E402 below
import jax # noqa : E402
import jax.numpy as jnp # noqa : E402
import pytest # noqa : E402
from diffrax import SaveAt # noqa : E402
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
from helpers import MSE # noqa : E402
from jax import lax # noqa : E402
from jax.experimental.multihost_utils import process_allgather # noqa : E402
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P # noqa : E402
from jaxpm.distributed import uniform_particles # noqa : E402
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
_TOLERANCE = 3.0 # 🙃🙃
@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("absolute_painting", [True, False])
def test_distributed_pm(simulation_config, initial_conditions, cosmo, order,
                        absolute_painting):
mesh_shape, box_shape = simulation_config
# SINGLE DEVICE RUN
cosmo._workspace = {}
if absolute_painting:
particles = uniform_particles(mesh_shape)
# Initial displacement
dx, p, _ = lpt(cosmo,
initial_conditions,
particles,
a=0.1,
order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
y0 = jnp.stack([particles + dx, p])
else:
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
y0 = jnp.stack([dx, p])
solver = Dopri5()
controller = PIDController(rtol=1e-8,
atol=1e-8,
pcoeff=0.4,
icoeff=1,
dcoeff=0)
saveat = SaveAt(t1=True)
solutions = diffeqsolve(ode_fn,
solver,
t0=0.1,
t1=1.0,
dt0=None,
y0=y0,
stepsize_controller=controller,
saveat=saveat)
if absolute_painting:
single_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
solutions.ys[-1, 0])
else:
single_device_final_field = cic_paint_dx(solutions.ys[-1, 0])
print("Done with single device run")
# MULTI DEVICE RUN
mesh = jax.make_mesh((1, 8), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
initial_conditions = lax.with_sharding_constraint(initial_conditions,
sharding)
print(f"sharded initial conditions {initial_conditions.sharding}")
cosmo._workspace = {}
if absolute_painting:
particles = uniform_particles(mesh_shape, sharding=sharding)
# Initial displacement
dx, p, _ = lpt(cosmo,
initial_conditions,
particles,
a=0.1,
order=order,
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(cosmo,
mesh_shape,
halo_size=halo_size,
sharding=sharding))
y0 = jnp.stack([particles + dx, p])
else:
dx, p, _ = lpt(cosmo,
initial_conditions,
a=0.1,
order=order,
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(cosmo,
mesh_shape,
paint_absolute_pos=False,
halo_size=halo_size,
sharding=sharding))
y0 = jnp.stack([dx, p])
solver = Dopri5()
controller = PIDController(rtol=1e-8,
atol=1e-8,
pcoeff=0.4,
icoeff=1,
dcoeff=0)
saveat = SaveAt(t1=True)
solutions = diffeqsolve(ode_fn,
solver,
t0=0.1,
t1=1.0,
dt0=None,
y0=y0,
stepsize_controller=controller,
saveat=saveat)
if absolute_painting:
multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
solutions.ys[-1, 0],
halo_size=halo_size,
sharding=sharding)
else:
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
halo_size=halo_size,
sharding=sharding)
multi_device_final_field = process_allgather(multi_device_final_field,
tiled=True)
mse = MSE(single_device_final_field, multi_device_final_field)
print(f"MSE is {mse}")
assert mse < _TOLERANCE
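Note that the multi-device half of this test builds a `(1, 8)` device mesh, which matches the eight simulated CPU devices that `tests/conftest.py` requests via `--xla_force_host_platform_device_count=8` when no SLURM allocation is detected, so the test can also run on a single machine.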