jaxdecomp proto (#21)

* adding example of distributed solution

* put back old function

* update formatting

* add halo exchange and slice pad

* apply formatting

* implement distributed optimized cic_paint

* Use new cic_paint with halo

* Fix seed for distributed normal

* Wrap interpolation function to avoid all gather

* Return normal order frequencies for single GPU

* add example

* format

* add optimised bench script

* times in ms

* add lpt2

* update benchmark and add slurm

* Visualize only final field

* Update scripts/distributed_pm.py

Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>

* Adjust pencil type for frequencies

* fix painting issue with slabs

* Sharded operations in Fourier space now take an inverted sharding axis for slabs

* add assert to make pyright happy

* adjust test for hpc-plotter

* add PMWD test

* bench

* format

* added github workflow

* fix formatting from main

* Update for jaxDecomp pure JAX

* revert single halo extent change

* update for latest jaxDecomp

* remove fourrier_space in autoshmap

* make normal_field work with single controller

* format

* make distributed pm work in single controller

* merge bench_pm

* update to leapfrog

* add a strict dependency on jaxdecomp

* global mesh no longer needed

* kernels.py no longer uses global mesh

* quick fix in distributed

* pm.py no longer uses global mesh

* painting.py no longer uses global mesh

* update demo script

* quick fix in kernels

* quick fix in distributed

* update demo

* merge Hugo's LPT2 code

* format

* Small fix

* format

* remove duplicate get_ode_fn

* update visualizer

* update compensate CIC

* By default check_rep is false for shard_map

* remove experimental distributed code

* update PGDCorrection and neural ode to use new fft3d

* jaxDecomp pfft3d promotes to complex automatically

* remove deprecated stuff

* fix painting issue with read_cic

* use jnp interp instead of jc interp

* delete old slurms

* add notebook examples

* apply formatting

* add distributed zeros

* fix code in LPT2

* jit cic_paint

* update notebooks

* apply formatting

* get local shape and zeros can be used by users

* add a user facing function to create uniform particle grid

* use jax interp instead of jax_cosmo

* use float64 for enmeshing

* Allow applying weights with relative cic paint

* Weights can be traced

* remove script folder

* update example notebooks

* delete outdated design file

* add readme for tutorials

* update readme

* fix small error

* forgot particles in multi host

* clarifying why cic_paint_dx is slower

* clarifying the halo size dependence on the box size

* ability to choose snapshots number with MultiHost script

* Adding animation notebook

* Put plotting in package

* Add finite difference laplace kernel + powerspec functions from Hugo

Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>

* Put plotting utils in package

* By default use absolute painting with

* update code

* update notebooks

* add tests

* Upgrade setup.py to pyproject

* Format

* format tests

* update test dependencies

* add test workflow

* fix deprecated FftType in jaxpm.kernels

* Add aboucaud comments

* JAX version is 0.4.35 until Diffrax new release

* add numpy explicitly as dependency for tests

* fix install order for tests

* add numpy to be installed

* enforce no build isolation for fastpm

* pip install jaxpm test without build isolation

* bump jaxdecomp version

* revert test workflow

* remove outdated tests

---------

Co-authored-by: EiffL <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr>
Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>
Former-commit-id: 8c2e823d4669eac712089bf7f85ffb7912e8232d
Committed by Wassim KABALAN on 2024-12-20 11:44:02 +01:00 (via GitHub)
parent a0a79277e5 · commit df8602b318
26 changed files with 1871 additions and 434 deletions

.github/workflows/formatting.yml (new file, +21 lines)

@@ -0,0 +1,21 @@
name: Code Formatting
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
- name: Install dependencies
run: |
python -m pip install --upgrade pip isort
python -m pip install pre-commit
- name: Run pre-commit
run: python -m pre_commit run --all-files

.github/workflows/tests.yml (new file, +45 lines)

@@ -0,0 +1,45 @@
name: Tests
on:
push:
branches:
- main
pull_request:
branches:
- main
jobs:
run_tests:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10" , "3.11" , "3.12"]
steps:
- name: Checkout Source
uses: actions/checkout@v2.3.1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo apt-get install -y libopenmpi-dev
python -m pip install --upgrade pip
pip install jax==0.4.35
pip install numpy setuptools cython wheel
pip install git+https://github.com/MP-Gadget/pfft-python
pip install git+https://github.com/MP-Gadget/pmesh
pip install git+https://github.com/ASKabalan/fastpm-python --no-build-isolation
pip install .[test]
- name: Run Single Device Tests
run: |
cd tests
pytest -v -m "not distributed"
- name: Run Distributed tests
run: |
pytest -v -m distributed

.gitignore (+5 lines)

@@ -98,6 +98,11 @@ __pypackages__/
 celerybeat-schedule
 celerybeat.pid
+out
+traces
+*.npy
+*.out
+
 # SageMath parsed files
 *.sage.py

README.md

@@ -4,16 +4,15 @@
 <!-- ALL-CONTRIBUTORS-BADGE:END -->
 JAX-powered Cosmological Particle-Mesh N-body Solver
-**This project is currently in an early design phase. All inputs are welcome on the [design document](https://github.com/DifferentiableUniverseInitiative/JaxPM/blob/main/design.md)**

 ## Goals
 Provide a modern infrastructure to support differentiable PM N-body simulations using JAX:
 - Keep implementation simple and readable, in pure NumPy API
-- Transparent distribution using builtin `xmap`
 - Any order forward and backward automatic differentiation
 - Support automated batching using `vmap`
 - Compatibility with external optimizer libraries like `optax`
+- Now fully distributable on **multi-GPU and multi-node** systems using [jaxDecomp](https://github.com/DifferentiableUniverseInitiative/jaxDecomp), working with `JAX v0.4.35`

 ## Open development and use

@@ -23,6 +22,10 @@ Current expectations are:
 - Everyone is welcome to contribute, and can join the JOSS publication (until it is submitted to the journal).
 - Anyone (including main contributors) can use this code as a framework to build and publish their own applications, with no expectation that they *need* to extend authorship to all jaxpm developers.

+## Getting Started
+
+To dive into JaxPM's capabilities, please explore the **notebook section** for detailed tutorials and examples on various setups, from single-device simulations to multi-host configurations. You can find the notebooks' [README here](notebooks/README.md) for a structured guide through each tutorial.
+
 ## Contributors ✨

design.md (deleted)

@@ -1,52 +0,0 @@
# Design Document for JaxPM
This document aims to detail some of the API, implementation choices, and internal mechanisms.
## Objective
Provide a user-friendly framework for distributed Particle-Mesh N-body simulations.
## Related Work
This project would be the latest iteration of a number of past libraries that have provided differentiable N-body models.
- [FlowPM](https://github.com/DifferentiableUniverseInitiative/flowpm): TensorFlow
- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD
- Borg
In addition, a number of fast N-body simulation projects exist out there:
- [FastPM](https://github.com/fastpm/fastpm)
- ...
## Design Overview
### Coding principles
Following recent trends and JAX philosophy, the library should have a functional programming type of interface.
### Illustration of API
Here is a potential illustration of what the user interface could be for the simulation code:
```python
import jaxpm as jpm
import jax_cosmo as jc
# Instantiate differentiable cosmology object
cosmo = jc.Planck()
# Creates initial conditions
initial_conditions = jpm.generate_ic(cosmo, boxsize, nmesh, dtype='float32')
# Create a particular solver
solver = jpm.solvers.fastpm(cosmo, B=1)
# Initialize and run the simulation
state = solver.init(initial_conditions)
state = solver.nbody(state)
# Painting the results
density = jpm.zeros(boxsize, nmesh)
density = jpm.paint(density, state.positions)
```

test_pfft.py (deleted)

@@ -1,96 +0,0 @@
# Can be executed with:
# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py
from functools import partial
import jax
import jax.lax as lax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import Mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit
jax.distributed.initialize()
cube_size = 2048
@partial(xmap,
in_axes=[...],
out_axes=['x', 'y', ...],
axis_sizes={
'x': cube_size,
'y': cube_size
},
axis_resources={
'x': 'nx',
'y': 'ny',
'key_x': 'nx',
'key_y': 'ny'
})
def pnormal(key):
return jax.random.normal(key, shape=[cube_size])
@partial(xmap,
in_axes={
0: 'x',
1: 'y'
},
out_axes=['x', 'y', ...],
axis_resources={
'x': 'nx',
'y': 'ny'
})
@jax.jit
def pfft3d(mesh):
# [x, y, z]
mesh = jnp.fft.fft(mesh) # Transform on z
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x]
mesh = jnp.fft.fft(mesh) # Transform on x
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y]
mesh = jnp.fft.fft(mesh) # Transform on y
# [z, x, y]
return mesh
@partial(xmap,
in_axes={
0: 'x',
1: 'y'
},
out_axes=['x', 'y', ...],
axis_resources={
'x': 'nx',
'y': 'ny'
})
@jax.jit
def pifft3d(mesh):
# [z, x, y]
mesh = jnp.fft.ifft(mesh) # Transform on y
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x]
mesh = jnp.fft.ifft(mesh) # Transform on x
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z]
mesh = jnp.fft.ifft(mesh) # Transform on z
# [x, y, z]
return mesh
key = jax.random.PRNGKey(42)
# keys = jax.random.split(key, 4).reshape((2,2,2))
# We reshape all our devices to the mesh shape we want
devices = np.array(jax.devices()).reshape((2, 4))
with Mesh(devices, ('nx', 'ny')):
mesh = pnormal(key)
kmesh = pfft3d(mesh)
kmesh.block_until_ready()
# jax.profiler.start_trace("tensorboard")
# with Mesh(devices, ('nx', 'ny')):
# mesh = pnormal(key)
# kmesh = pfft3d(mesh)
# kmesh.block_until_ready()
# jax.profiler.stop_trace()
print('Done')

test_script.py (deleted)

@@ -1,68 +0,0 @@
# Start this script with:
# mpirun -np 4 python test_script.py
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
import jax
import jax.lax as lax
import jax.numpy as jnp
import matplotlib.pylab as plt
import numpy as np
import tensorflow_probability as tfp
from jax.experimental.maps import mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit
tfp = tfp.substrates.jax
tfd = tfp.distributions
def cic_paint(mesh, positions):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1,
2))
mesh = lax.scatter_add(
mesh,
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
kernel.reshape([-1, 8]), dnums)
return mesh
# And let's draw some points from some 3D distribution
dist = tfd.MultivariateNormalDiag(loc=[16., 16., 16.],
scale_identity_multiplier=3.)
pos = dist.sample(1e4, seed=jax.random.PRNGKey(0))
f = pjit(lambda x: cic_paint(x, pos),
in_axis_resources=PartitionSpec('x', 'y', 'z'),
out_axis_resources=None)
devices = np.array(jax.devices()).reshape((2, 2, 1))
# Let's create the mesh
m = jnp.zeros([32, 32, 32])
with mesh(devices, ('x', 'y', 'z')):
# Shard the mesh, I'm not sure this is absolutely necessary
m = pjit(lambda x: x,
in_axis_resources=None,
out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)
# Apply the sharded CiC function
res = f(m)
plt.imshow(res.sum(axis=2))
plt.show()

jaxpm/distributed.py (new file, +198 lines)

@@ -0,0 +1,198 @@
from typing import Any, Callable, Hashable
Specs = Any
AxisName = Hashable
from functools import partial
import jax
import jax.numpy as jnp
import jaxdecomp
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import AbstractMesh, Mesh
from jax.sharding import PartitionSpec as P
def autoshmap(
f: Callable,
gpu_mesh: Mesh | AbstractMesh | None,
in_specs: Specs,
out_specs: Specs,
check_rep: bool = False,
auto: frozenset[AxisName] = frozenset()) -> Callable:
"""Helper function to wrap the provided function in a shard map if
the code is being executed in a mesh context."""
if gpu_mesh is None or gpu_mesh.empty:
return f
else:
return shard_map(f, gpu_mesh, in_specs, out_specs, check_rep, auto)
def fft3d(x):
return jaxdecomp.pfft3d(x)
def ifft3d(x):
return jaxdecomp.pifft3d(x).real
def get_halo_size(halo_size, sharding):
gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is None or gpu_mesh.empty:
zero_ext = (0, 0)
zero_tuple = (0, 0)
return (zero_tuple, zero_tuple, zero_tuple), zero_ext
else:
pdims = gpu_mesh.devices.shape
halo_x = (0, 0) if pdims[0] == 1 else (halo_size, halo_size)
halo_y = (0, 0) if pdims[1] == 1 else (halo_size, halo_size)
halo_x_ext = 0 if pdims[0] == 1 else halo_size // 2
halo_y_ext = 0 if pdims[1] == 1 else halo_size // 2
return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext))
def halo_exchange(x, halo_extents, halo_periods=(True, True)):
if (halo_extents[0] > 0 or halo_extents[1] > 0):
return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
else:
return x
def slice_unpad_impl(x, pad_width):
halo_x, _ = pad_width[0]
halo_y, _ = pad_width[1]
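    # Fold the halo regions back into the interior: during painting, each
    # neighboring slab scatter-adds mass into our padded border, so the border
    # contributions are accumulated onto the first/last halo // 2 interior
    # cells before the padding is cropped away below.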
# Apply corrections along x
x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2])
x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:])
# Apply corrections along y
x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])
unpad_slice = [slice(None)] * 3
if halo_x > 0:
unpad_slice[0] = slice(halo_x, -halo_x)
if halo_y > 0:
unpad_slice[1] = slice(halo_y, -halo_y)
return x[tuple(unpad_slice)]
def slice_pad(x, pad_width, sharding):
gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is not None and not (gpu_mesh.empty) and (
pad_width[0][0] > 0 or pad_width[1][0] > 0):
assert sharding is not None
spec = sharding.spec
return shard_map((partial(jnp.pad, pad_width=pad_width)),
mesh=gpu_mesh,
in_specs=spec,
out_specs=spec)(x)
else:
return x
def slice_unpad(x, pad_width, sharding):
mesh = sharding.mesh if sharding is not None else None
if mesh is not None and not (mesh.empty) and (pad_width[0][0] > 0
or pad_width[1][0] > 0):
assert sharding is not None
spec = sharding.spec
return shard_map(partial(slice_unpad_impl, pad_width=pad_width),
mesh=mesh,
in_specs=spec,
out_specs=spec)(x)
else:
return x
def get_local_shape(mesh_shape, sharding=None):
""" Helper function to get the local size of a mesh given the global size.
"""
gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is None or gpu_mesh.empty:
return mesh_shape
else:
pdims = gpu_mesh.devices.shape
return [
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1],
*mesh_shape[2:]
]
def _axis_names(spec):
if len(spec) == 1:
x_axis, = spec
y_axis = None
single_axis = True
elif len(spec) == 2:
x_axis, y_axis = spec
        if y_axis is None:
single_axis = True
        elif x_axis is None:
x_axis = y_axis
single_axis = True
else:
single_axis = False
else:
raise ValueError("Only 1 or 2 axis sharding is supported")
return x_axis, y_axis, single_axis
def uniform_particles(mesh_shape, sharding=None):
gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is not None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(mesh_shape, sharding)
spec = sharding.spec
x_axis, y_axis, single_axis = _axis_names(spec)
def particles():
x_indx = lax.axis_index(x_axis)
y_indx = 0 if single_axis else lax.axis_index(y_axis)
x = jnp.arange(local_mesh_shape[0]) + x_indx * local_mesh_shape[0]
y = jnp.arange(local_mesh_shape[1]) + y_indx * local_mesh_shape[1]
z = jnp.arange(local_mesh_shape[2])
return jnp.stack(jnp.meshgrid(x, y, z, indexing='ij'), axis=-1)
return shard_map(particles, mesh=gpu_mesh, in_specs=(),
out_specs=spec)()
else:
return jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape],
indexing='ij'),
axis=-1)
def normal_field(mesh_shape, seed, sharding=None):
"""Generate a Gaussian random field with the given power spectrum."""
gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is not None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(mesh_shape, sharding)
size = jax.device_count()
# rank = jax.process_index()
# process_index is multi_host only
# to make the code work both in multi host and single controller we can do this trick
keys = jax.random.split(seed, size)
spec = sharding.spec
x_axis, y_axis, single_axis = _axis_names(spec)
def normal(keys, shape, dtype):
idx = lax.axis_index(x_axis)
if not single_axis:
y_index = lax.axis_index(y_axis)
x_size = lax.psum(1, axis_name=x_axis)
idx += y_index * x_size
return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)
return shard_map(
partial(normal, shape=local_mesh_shape, dtype='float32'),
mesh=gpu_mesh,
in_specs=P(None),
out_specs=spec)(keys) # yapf: disable
else:
return jax.random.normal(shape=mesh_shape, key=seed)
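For orientation, here is a minimal usage sketch of these helpers (an editor's illustration, not part of the diff). The `pdims = (2, 4)` layout, the axis names `'x'`/`'y'`, and the 256³ mesh are assumptions, and at least eight visible devices are assumed:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

from jaxpm.distributed import fft3d, normal_field

pdims = (2, 4)  # illustrative 2x4 decomposition
devices = np.array(jax.devices()[:pdims[0] * pdims[1]]).reshape(pdims)
sharding = NamedSharding(Mesh(devices, axis_names=('x', 'y')), P('x', 'y'))

# Each device draws only its local slab, using per-device keys
field = normal_field((256, 256, 256),
                     seed=jax.random.PRNGKey(0),
                     sharding=sharding)
kfield = fft3d(field)  # distributed 3D FFT through jaxDecomp
```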

jaxpm/growth.py

@@ -1,6 +1,6 @@
 import jax.numpy as np
+from jax.numpy import interp
 from jax_cosmo.background import *
-from jax_cosmo.scipy.interpolate import interp
 from jax_cosmo.scipy.ode import odeint
@@ -587,5 +587,6 @@ def dGf2a(cosmo, a):
     cache = cosmo._workspace['background.growth_factor']
     f2p = cache['h2'] / cache['a'] * cache['g2']
     f2p = interp(np.log(a), np.log(cache['a']), f2p)
-    E = E(cosmo, a)
-    return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f)
+    E_a = E(cosmo, a)
+    return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) +
+            3 * a**2 * E_a * D2f)

jaxpm/kernels.py

@@ -1,30 +1,46 @@
 import jax.numpy as jnp
 import numpy as np
+from jax.lax import FftType
+from jax.sharding import PartitionSpec as P
+from jaxdecomp import fftfreq3d, get_output_specs
+
+from jaxpm.distributed import autoshmap


-def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
-    """
-    Return wave-vectors for a given shape
-    """
-    k = []
-    for d in range(len(shape)):
-        kd = np.fft.fftfreq(shape[d])
-        kd *= 2 * np.pi
-        kdshape = np.ones(len(shape), dtype='int')
-        if symmetric and d == len(shape) - 1:
-            kd = kd[:shape[d] // 2 + 1]
-        kdshape[d] = len(kd)
-        kd = kd.reshape(kdshape)
-        k.append(kd.astype(dtype))
-    del kd, kdshape
-    return k
+def fftk(k_array):
+    """
+    Generate Fourier transform wave numbers for a given mesh.
+
+    Args:
+        k_array: Fourier-space array whose shape and sharding the
+            wave numbers are matched to.
+
+    Returns:
+        list: List of wave number arrays for each dimension in
+        the order [kx, ky, kz].
+    """
+    kx, ky, kz = fftfreq3d(k_array)
+    # to the order of dimensions in the transposed FFT
+    return kx, ky, kz
+
+
+def interpolate_power_spectrum(input, k, pk, sharding=None):
+
+    pk_fn = lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape)
+
+    gpu_mesh = sharding.mesh if sharding is not None else None
+    specs = sharding.spec if sharding is not None else P()
+    out_specs = P(*get_output_specs(
+        FftType.FFT, specs, mesh=gpu_mesh)) if gpu_mesh is not None else P()
+
+    return autoshmap(pk_fn,
+                     gpu_mesh=gpu_mesh,
+                     in_specs=out_specs,
+                     out_specs=out_specs)(input)


 def gradient_kernel(kvec, direction, order=1):
     """
     Computes the gradient kernel in the requested direction

     Parameters
     -----------
     kvec: list
@@ -50,23 +66,30 @@ def gradient_kernel(kvec, direction, order=1):
     return wts


-def invlaplace_kernel(kvec):
+def invlaplace_kernel(kvec, fd=False):
     """
-    Compute the inverse Laplace kernel
+    Compute the inverse Laplace kernel.
+
+    cf. [Feng+2016](https://arxiv.org/pdf/1603.00476)

     Parameters
     -----------
     kvec: list
         List of wave-vectors
+    fd: bool
+        Finite difference kernel

     Returns
     --------
     wts: array
         Complex kernel values
     """
-    kk = sum(ki**2 for ki in kvec)
-    kk_nozeros = jnp.where(kk==0, 1, kk)
-    return - jnp.where(kk==0, 0, 1 / kk_nozeros)
+    if fd:
+        kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
+    else:
+        kk = sum(ki**2 for ki in kvec)
+    kk_nozeros = jnp.where(kk == 0, 1, kk)
+    return -jnp.where(kk == 0, 0, 1 / kk_nozeros)


 def longrange_kernel(kvec, r_split):
@@ -79,12 +102,10 @@ def longrange_kernel(kvec, r_split):
         List of wave-vectors
     r_split: float
         Splitting radius

     Returns
     --------
     wts: array
         Complex kernel values

     TODO: @modichirag add documentation
     """
     if r_split != 0:
@@ -105,13 +126,12 @@ def cic_compensation(kvec):
     -----------
     kvec: list
         List of wave-vectors

     Returns:
     --------
     wts: array
         Complex kernel values
     """
-    kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
+    kwts = [jnp.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
     wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
     return wts
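As a quick illustration of the new kernel API (an editor's sketch, not part of the diff): `fftk` now derives frequencies from the Fourier-space array itself, so kernels are built after the forward FFT. Single-device shapes are assumed:

```python
import jax.numpy as jnp

from jaxpm.distributed import fft3d
from jaxpm.kernels import fftk, gradient_kernel, invlaplace_kernel

delta = jnp.zeros((64, 64, 64))
delta_k = fft3d(delta)   # frequencies follow the transposed FFT layout
kvec = fftk(delta_k)     # [kx, ky, kz] with broadcastable shapes
pot_k = delta_k * invlaplace_kernel(kvec, fd=False)
force_x_k = -gradient_kernel(kvec, 0) * pot_k  # k-space force component
```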

jaxpm/painting.py

@@ -1,15 +1,24 @@
+from functools import partial
+
 import jax
 import jax.lax as lax
 import jax.numpy as jnp
+from jax.sharding import NamedSharding
+from jax.sharding import PartitionSpec as P

+from jaxpm.distributed import (autoshmap, fft3d, get_halo_size, halo_exchange,
+                               ifft3d, slice_pad, slice_unpad)
 from jaxpm.kernels import cic_compensation, fftk
+from jaxpm.painting_utils import gather, scatter


-def cic_paint(mesh, positions, weight=None):
+def _cic_paint_impl(grid_mesh, positions, weight=None):
     """ Paints positions onto mesh
     mesh: [nx, ny, nz]
-    positions: [npart, 3]
+    displacement field: [nx, ny, nz, 3]
     """

+    positions = positions.reshape([-1, 3])
     positions = jnp.expand_dims(positions, 1)
     floor = jnp.floor(positions)
     connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
@@ -19,40 +28,98 @@ def cic_paint(mesh, positions, weight=None):
     kernel = 1. - jnp.abs(positions - neighboor_coords)
     kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
     if weight is not None:
-        kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
+        if jnp.isscalar(weight):
+            kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
+        else:
+            kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
+                                  kernel)

     neighboor_coords = jnp.mod(
         neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
-        jnp.array(mesh.shape))
+        jnp.array(grid_mesh.shape))

     dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
                                             inserted_window_dims=(0, 1, 2),
                                             scatter_dims_to_operand_dims=(0, 1,
                                                                           2))
-    mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 8]),
-                           dnums)
+    mesh = lax.scatter_add(grid_mesh, neighboor_coords,
                           kernel.reshape([-1, 8]), dnums)
     return mesh


-def cic_read(mesh, positions):
+@partial(jax.jit, static_argnums=(3, 4))
+def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
+
+    positions = positions.reshape((*grid_mesh.shape, 3))
+
+    halo_size, halo_extents = get_halo_size(halo_size, sharding)
+    grid_mesh = slice_pad(grid_mesh, halo_size, sharding)
+
+    gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
+    spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
+    grid_mesh = autoshmap(_cic_paint_impl,
+                          gpu_mesh=gpu_mesh,
+                          in_specs=(spec, spec, P()),
+                          out_specs=spec)(grid_mesh, positions, weight)
+    grid_mesh = halo_exchange(grid_mesh,
+                              halo_extents=halo_extents,
+                              halo_periods=(True, True))
+    grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)
+    return grid_mesh
+
+
+def _cic_read_impl(grid_mesh, positions):
     """ Paints positions onto mesh
     mesh: [nx, ny, nz]
-    positions: [npart, 3]
+    positions: [nx, ny, nz, 3]
     """
+    # Save original shape for reshaping output later
+    original_shape = positions.shape
+    # Reshape positions to a flat list of 3D coordinates
+    positions = positions.reshape([-1, 3])
+    # Expand dimensions to calculate neighbor coordinates
     positions = jnp.expand_dims(positions, 1)
+    # Floor the positions to get the base grid cell for each particle
     floor = jnp.floor(positions)
+    # Define connections to calculate all neighbor coordinates
     connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
                              [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
+    # Calculate the 8 neighboring coordinates
     neighboor_coords = floor + connection
+    # Calculate kernel weights based on distance from each neighboring coordinate
     kernel = 1. - jnp.abs(positions - neighboor_coords)
     kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
+    # Modulo operation to wrap around edges if necessary
     neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
-                               jnp.array(mesh.shape))
+                               jnp.array(grid_mesh.shape))
+    # Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel
+    return (grid_mesh[neighboor_coords[..., 0],
+                      neighboor_coords[..., 1],
+                      neighboor_coords[..., 2]] * kernel).sum(axis=-1).reshape(original_shape[:-1])  # yapf: disable

-    return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
-                 neighboor_coords[..., 3]] * kernel).sum(axis=-1)
+
+@partial(jax.jit, static_argnums=(2, 3))
+def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
+
+    original_shape = positions.shape
+    positions = positions.reshape((*grid_mesh.shape, 3))
+
+    halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
+    grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
+    grid_mesh = halo_exchange(grid_mesh,
+                              halo_extents=halo_extents,
+                              halo_periods=(True, True))
+    gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
+    spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
+
+    displacement = autoshmap(_cic_read_impl,
+                             gpu_mesh=gpu_mesh,
+                             in_specs=(spec, spec),
+                             out_specs=spec)(grid_mesh, positions)
+
+    return displacement.reshape(original_shape[:-1])


 def cic_paint_2d(mesh, positions, weight):
@@ -84,6 +151,99 @@ def cic_paint_2d(mesh, positions, weight):
     return mesh


+def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
+
+    halo_x, _ = halo_size[0]
+    halo_y, _ = halo_size[1]
+
+    original_shape = displacements.shape
+    particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
+    if not jnp.isscalar(weight):
+        if weight.shape != original_shape[:-1]:
+            raise ValueError("Weight shape must match particle shape")
+        else:
+            weight = weight.flatten()
+    # Padding is forced to be zero in a single gpu run
+
+    a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
+                           jnp.arange(particle_mesh.shape[1]),
+                           jnp.arange(particle_mesh.shape[2]),
+                           indexing='ij')
+
+    particle_mesh = jnp.pad(particle_mesh, halo_size)
+    pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
+    return scatter(pmid.reshape([-1, 3]),
+                   displacements.reshape([-1, 3]),
+                   particle_mesh,
+                   chunk_size=2**24,
+                   val=weight)
+
+
+@partial(jax.jit, static_argnums=(1, 2, 4))
+def cic_paint_dx(displacements,
+                 halo_size=0,
+                 sharding=None,
+                 weight=1.0,
+                 chunk_size=2**24):
+
+    halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
+
+    gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
+    spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
+    grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
+                                  halo_size=halo_size,
+                                  weight=weight,
+                                  chunk_size=chunk_size),
+                          gpu_mesh=gpu_mesh,
+                          in_specs=spec,
+                          out_specs=spec)(displacements)
+
+    grid_mesh = halo_exchange(grid_mesh,
+                              halo_extents=halo_extents,
+                              halo_periods=(True, True))
+    grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)
+    return grid_mesh
+
+
+def _cic_read_dx_impl(grid_mesh, disp, halo_size):
+
+    halo_x, _ = halo_size[0]
+    halo_y, _ = halo_size[1]
+
+    original_shape = [
+        dim - 2 * halo[0] for dim, halo in zip(grid_mesh.shape, halo_size)
+    ]
+    a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]),
+                           jnp.arange(original_shape[1]),
+                           jnp.arange(original_shape[2]),
+                           indexing='ij')
+
+    pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
+
+    pmid = pmid.reshape([-1, 3])
+    disp = disp.reshape([-1, 3])
+
+    return gather(pmid, disp, grid_mesh).reshape(original_shape)
+
+
+@partial(jax.jit, static_argnums=(2, 3))
+def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
+
+    halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
+    grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
+    grid_mesh = halo_exchange(grid_mesh,
+                              halo_extents=halo_extents,
+                              halo_periods=(True, True))
+    gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
+    spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
+    displacements = autoshmap(partial(_cic_read_dx_impl, halo_size=halo_size),
+                              gpu_mesh=gpu_mesh,
+                              in_specs=(spec),
+                              out_specs=spec)(grid_mesh, disp)
+
+    return displacements
+
+
 def compensate_cic(field):
     """
     Compensate for CiC painting
@@ -92,9 +252,8 @@ def compensate_cic(field):
     Returns:
       compensated_field
     """
-    nc = field.shape
-    kvec = fftk(nc)
-
-    delta_k = jnp.fft.rfftn(field)
+    delta_k = fft3d(field)
+    kvec = fftk(delta_k)
     delta_k = cic_compensation(kvec) * delta_k
-    return jnp.fft.irfftn(delta_k)
+    return ifft3d(delta_k)
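A single-device sketch of the reworked painting API (an editor's illustration, not part of the diff): particle positions are now grid-shaped, matching the output of `uniform_particles` from `jaxpm.distributed`:

```python
import jax.numpy as jnp

from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_read

mesh_shape = (32, 32, 32)
# One particle per cell; shape is (*mesh_shape, 3), not (npart, 3)
positions = uniform_particles(mesh_shape).astype('float32')

density = cic_paint(jnp.zeros(mesh_shape), positions,
                    halo_size=0, sharding=None)
values = cic_read(density, positions, halo_size=0, sharding=None)
```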

jaxpm/painting_utils.py (new file, +190 lines)

@@ -0,0 +1,190 @@
import jax
import jax.numpy as jnp
from jax.lax import scan
def _chunk_split(ptcl_num, chunk_size, *arrays):
"""Split and reshape particle arrays into chunks and remainders, with the remainders
preceding the chunks. 0D ones are duplicated as full arrays in the chunks."""
chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
remainder_size = ptcl_num % chunk_size
chunk_num = ptcl_num // chunk_size
remainder = None
chunks = arrays
if remainder_size:
remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays]
chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays]
# `scan` triggers errors in scatter and gather without the `full`
chunks = [
x.reshape(chunk_num, chunk_size, *x.shape[1:])
if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks
]
return remainder, chunks
def enmesh(base_indices, displacements, cell_size, base_shape, offset,
new_cell_size, new_shape):
"""Multilinear enmeshing."""
base_indices = jnp.asarray(base_indices)
displacements = jnp.asarray(displacements)
with jax.experimental.enable_x64():
cell_size = jnp.float64(
cell_size) if new_cell_size is not None else jnp.array(
cell_size, dtype=displacements.dtype)
if base_shape is not None:
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
offset = jnp.float64(offset)
if new_cell_size is not None:
new_cell_size = jnp.float64(new_cell_size)
if new_shape is not None:
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
spatial_dim = base_indices.shape[1]
neighbor_offsets = (
jnp.arange(2**spatial_dim, dtype=base_indices.dtype)[:, jnp.newaxis] >>
jnp.arange(spatial_dim, dtype=base_indices.dtype)) & 1
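    # Each row of neighbor_offsets is one corner of the enclosing cell: the
    # bit pattern of 0 .. 2**spatial_dim - 1 selects an offset of 0 or 1 along
    # each spatial axis, enumerating all 2**dim multilinear neighbors.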
if new_cell_size is not None:
particle_positions = base_indices * cell_size + displacements - offset
particle_positions = particle_positions[:, jnp.
newaxis] # insert neighbor axis
new_indices = particle_positions + neighbor_offsets * new_cell_size # multilinear
if base_shape is not None:
grid_length = base_shape * cell_size
new_indices %= grid_length
new_indices //= new_cell_size
new_displacements = particle_positions - new_indices * new_cell_size
if base_shape is not None:
new_displacements -= jnp.rint(
new_displacements / grid_length
) * grid_length # also abs(new_displacements) < new_cell_size is expected
new_indices = new_indices.astype(base_indices.dtype)
new_displacements = new_displacements.astype(displacements.dtype)
new_cell_size = new_cell_size.astype(displacements.dtype)
new_displacements /= new_cell_size
else:
offset_indices, offset_displacements = jnp.divmod(offset, cell_size)
base_indices -= offset_indices.astype(base_indices.dtype)
displacements -= offset_displacements.astype(displacements.dtype)
# insert neighbor axis
base_indices = base_indices[:, jnp.newaxis]
displacements = displacements[:, jnp.newaxis]
# multilinear
displacements /= cell_size
new_indices = jnp.floor(displacements).astype(base_indices.dtype)
new_indices += neighbor_offsets
new_displacements = displacements - new_indices
new_indices += base_indices
if base_shape is not None:
new_indices %= base_shape
weights = 1 - jnp.abs(new_displacements)
if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None
new_indices = jnp.where(new_indices < 0, new_shape, new_indices)
weights = weights.prod(axis=-1)
return new_indices, weights
def _scatter_chunk(carry, chunk):
mesh, offset, cell_size, mesh_shape = carry
pmid, disp, val = chunk
spatial_ndim = pmid.shape[1]
spatial_shape = mesh.shape
# multilinear mesh indices and fractions
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
spatial_shape)
# scatter
ind = tuple(ind[..., i] for i in range(spatial_ndim))
mesh = mesh.at[ind].add(jnp.multiply(jnp.expand_dims(val, axis=-1), frac))
carry = mesh, offset, cell_size, mesh_shape
return carry, None
def scatter(pmid,
disp,
mesh,
chunk_size=2**24,
val=1.,
offset=0,
cell_size=1.):
ptcl_num, spatial_ndim = pmid.shape
val = jnp.asarray(val)
mesh = jnp.asarray(mesh)
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
carry = mesh, offset, cell_size, mesh.shape
if remainder is not None:
carry = _scatter_chunk(carry, remainder)[0]
carry = scan(_scatter_chunk, carry, chunks)[0]
mesh = carry[0]
return mesh
def _chunk_cat(remainder_array, chunked_array):
"""Reshape and concatenate one remainder and one chunked particle arrays."""
array = chunked_array.reshape(-1, *chunked_array.shape[2:])
if remainder_array is not None:
array = jnp.concatenate((remainder_array, array), axis=0)
return array
def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
ptcl_num, spatial_ndim = pmid.shape
mesh = jnp.asarray(mesh)
val = jnp.asarray(val)
if mesh.shape[spatial_ndim:] != val.shape[1:]:
raise ValueError('channel shape mismatch: '
f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}')
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
carry = mesh, offset, cell_size, mesh.shape
val_0 = None
if remainder is not None:
val_0 = _gather_chunk(carry, remainder)[1]
val = scan(_gather_chunk, carry, chunks)[1]
val = _chunk_cat(val_0, val)
return val
def _gather_chunk(carry, chunk):
mesh, offset, cell_size, mesh_shape = carry
pmid, disp, val = chunk
spatial_ndim = pmid.shape[1]
spatial_shape = mesh.shape[:spatial_ndim]
chan_ndim = mesh.ndim - spatial_ndim
chan_axis = tuple(range(-chan_ndim, 0))
# multilinear mesh indices and fractions
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
spatial_shape)
# gather
ind = tuple(ind[..., i] for i in range(spatial_ndim))
frac = jnp.expand_dims(frac, chan_axis)
val += (mesh.at[ind].get(mode='drop', fill_value=0) * frac).sum(axis=1)
return carry, val
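To see how the chunked scatter/gather pair is meant to be driven, a small illustrative sketch (the 16³ mesh and `chunk_size=2**10` are assumptions): particles are identified by an integer home cell `pmid` plus a fractional displacement `disp`, and `chunk_size` bounds the memory used by each `scan` step:

```python
import jax.numpy as jnp

from jaxpm.painting_utils import gather, scatter

nc = 16
pmid = jnp.stack(jnp.meshgrid(*[jnp.arange(nc)] * 3, indexing='ij'),
                 axis=-1).reshape(-1, 3)       # integer home cells
disp = jnp.zeros((nc**3, 3), dtype='float32')  # fractional displacements

mesh = scatter(pmid, disp, jnp.zeros((nc, nc, nc)), chunk_size=2**10)
readout = gather(pmid, disp, mesh, chunk_size=2**10)
```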

jaxpm/plotting.py (new file, +129 lines)

@@ -0,0 +1,129 @@
import matplotlib.pyplot as plt
import numpy as np
def plot_fields(fields_dict, sum_over=None):
"""
Plots sum projections of 3D fields along different axes,
slicing only the first `sum_over` elements along each axis.
Args:
    - fields_dict: dictionary mapping field names (used in titles) to 3D arrays
    - sum_over: number of slices to sum along each axis (default: first field's shape[0] // 8)
"""
sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
nb_rows = len(fields_dict)
nb_cols = 3
fig, axes = plt.subplots(nb_rows, nb_cols, figsize=(15, 5 * nb_rows))
def plot_subplots(proj_axis, field, row, title):
slicing = [slice(None)] * field.ndim
slicing[proj_axis] = slice(None, sum_over)
slicing = tuple(slicing)
# Sum projection over the specified axis and plot
axes[row, proj_axis].imshow(
field[slicing].sum(axis=proj_axis) + 1,
cmap='magma',
extent=[0, field.shape[proj_axis], 0, field.shape[proj_axis]])
axes[row, proj_axis].set_xlabel('Mpc/h')
axes[row, proj_axis].set_ylabel('Mpc/h')
axes[row, proj_axis].set_title(title)
# Plot each field across the three axes
for i, (name, field) in enumerate(fields_dict.items()):
for proj_axis in range(3):
plot_subplots(proj_axis, field, i,
f"{name} projection {proj_axis}")
plt.tight_layout()
plt.show()
def plot_fields_single_projection(fields_dict,
sum_over=None,
project_axis=0,
vmin=None,
vmax=None,
colorbar=False):
"""
Plots a single projection (along axis 0) of 3D fields in a grid,
summing over the first `sum_over` elements along the 0-axis, with 4 images per row.
Args:
- fields_dict: dictionary where keys are field names and values are 3D arrays
- sum_over: number of slices to sum along the projection axis (default: fields[0].shape[0] // 8)
"""
sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
nb_fields = len(fields_dict)
nb_cols = 4 # Set number of images per row
nb_rows = (nb_fields + nb_cols - 1) // nb_cols # Calculate required rows
fig, axes = plt.subplots(nb_rows,
nb_cols,
figsize=(5 * nb_cols, 5 * nb_rows))
axes = np.atleast_2d(axes) # Ensure axes is always a 2D array
for i, (name, field) in enumerate(fields_dict.items()):
row, col = divmod(i, nb_cols)
# Define the slice for the 0-axis projection
slicing = [slice(None)] * field.ndim
slicing[project_axis] = slice(None, sum_over)
slicing = tuple(slicing)
# Sum projection over axis 0 and plot
a = axes[row,
col].imshow(field[slicing].sum(axis=project_axis) + 1,
cmap='magma',
extent=[0, field.shape[1], 0, field.shape[2]],
vmin=vmin,
vmax=vmax)
axes[row, col].set_xlabel('Mpc/h')
axes[row, col].set_ylabel('Mpc/h')
axes[row, col].set_title(f"{name} projection 0")
if colorbar:
fig.colorbar(a, ax=axes[row, col], shrink=0.7)
# Remove any empty subplots
for j in range(i + 1, nb_rows * nb_cols):
fig.delaxes(axes.flatten()[j])
plt.tight_layout()
plt.show()
def stack_slices(array):
"""
    Stacks the addressable 2D slices of a sharded array into a single global
    array, using the partition dimensions of the array's device mesh.

    Args:
    - array: a sharded 3D array whose `sharding.mesh.devices.shape` gives
      the partition grid (pdims).

    Returns:
    - A single numpy array constructed by stacking the per-device slices.
    """
# Initialize an empty list to store the vertically stacked rows
pdims = array.sharding.mesh.devices.shape
field_slices = []
# Iterate over rows in pdims[0]
for i in range(pdims[0]):
row_slices = []
# Iterate over columns in pdims[1]
for j in range(pdims[1]):
            slice_index = i * pdims[1] + j
            row_slices.append(array.addressable_data(slice_index))
        # Stack the current row of slices horizontally
        stacked_row = np.hstack(row_slices)
        field_slices.append(stacked_row)
    # Stack all rows vertically to form the full array
    full_array = np.vstack(field_slices)
return full_array
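A small usage sketch for these helpers (illustrative, not part of the diff):

```python
import jax.numpy as jnp

from jaxpm.plotting import plot_fields_single_projection

fields = {
    "Initial conditions": jnp.ones((64, 64, 64)),
    "Final field": jnp.ones((64, 64, 64)),
}
plot_fields_single_projection(fields, sum_over=8, colorbar=True)
```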

jaxpm/pm.py

@@ -1,50 +1,92 @@
+import jax
 import jax.numpy as jnp
 import jax_cosmo as jc
-from jax_cosmo import Cosmology

-from jaxpm.growth import growth_factor, growth_rate, dGfa, growth_factor_second, growth_rate_second, dGf2a
-from jaxpm.kernels import PGD_kernel, fftk, gradient_kernel, invlaplace_kernel, longrange_kernel
-from jaxpm.painting import cic_paint, cic_read
+from jaxpm.distributed import fft3d, ifft3d, normal_field
+from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
+                          growth_rate, growth_rate_second)
+from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
+                           invlaplace_kernel, longrange_kernel)
+from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx


-def pm_forces(positions, mesh_shape, delta=None, r_split=0):
+def pm_forces(positions,
+              mesh_shape=None,
+              delta=None,
+              r_split=0,
+              paint_absolute_pos=True,
+              halo_size=0,
+              sharding=None):
     """
     Computes gravitational forces on particles using a PM scheme
     """
+    if mesh_shape is None:
+        assert (delta is not None),\
+            "If mesh_shape is not provided, delta should be provided"
+        mesh_shape = delta.shape
+
+    if paint_absolute_pos:
+        paint_fn = lambda pos: cic_paint(jnp.zeros(shape=mesh_shape,
+                                                   device=sharding),
+                                         pos,
+                                         halo_size=halo_size,
+                                         sharding=sharding)
+        read_fn = lambda grid_mesh, pos: cic_read(
+            grid_mesh, pos, halo_size=halo_size, sharding=sharding)
+    else:
+        paint_fn = lambda disp: cic_paint_dx(
+            disp, halo_size=halo_size, sharding=sharding)
+        read_fn = lambda grid_mesh, disp: cic_read_dx(
+            grid_mesh, disp, halo_size=halo_size, sharding=sharding)
+
     if delta is None:
-        delta_k = jnp.fft.rfftn(cic_paint(jnp.zeros(mesh_shape), positions))
+        field = paint_fn(positions)
+        delta_k = fft3d(field)
     elif jnp.isrealobj(delta):
-        delta_k = jnp.fft.rfftn(delta)
+        delta_k = fft3d(delta)
     else:
         delta_k = delta

+    kvec = fftk(delta_k)
     # Computes gravitational potential
-    kvec = fftk(mesh_shape)
-    pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split)
+    pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(
+        kvec, r_split=r_split)
     # Computes gravitational forces
-    return jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i) * pot_k), positions)
-                      for i in range(3)], axis=-1)
+    forces = jnp.stack([
+        read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k), positions)
+        for i in range(3)], axis=-1)  # yapf: disable
+
+    return forces


-def lpt(cosmo:Cosmology, init_mesh, positions, a, order=1):
+def lpt(cosmo,
+        initial_conditions,
+        particles=None,
+        a=0.1,
+        halo_size=0,
+        sharding=None,
+        order=1):
     """
     Computes first and second order LPT displacement and momentum,
     e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
     """
+    paint_absolute_pos = particles is not None
+    if particles is None:
+        particles = jnp.zeros_like(initial_conditions,
+                                   shape=(*initial_conditions.shape, 3))
+
     a = jnp.atleast_1d(a)
     E = jnp.sqrt(jc.background.Esqr(cosmo, a))
-    delta_k = jnp.fft.rfftn(init_mesh)  # TODO: pass the modes directly to save one or two fft?
-    mesh_shape = init_mesh.shape
-
-    init_force = pm_forces(positions, mesh_shape, delta=delta_k)
-    dx = growth_factor(cosmo, a) * init_force
+    delta_k = fft3d(initial_conditions)
+    initial_force = pm_forces(particles,
+                              delta=delta_k,
+                              paint_absolute_pos=paint_absolute_pos,
+                              halo_size=halo_size,
+                              sharding=sharding)
+    dx = growth_factor(cosmo, a) * initial_force
     p = a**2 * growth_rate(cosmo, a) * E * dx
-    f = a**2 * E * dGfa(cosmo, a) * init_force
+    f = a**2 * E * dGfa(cosmo, a) * initial_force

     if order == 2:
-        kvec = fftk(mesh_shape)
+        kvec = fftk(delta_k)
         pot_k = delta_k * invlaplace_kernel(kvec)

         delta2 = 0
@@ -54,20 +96,26 @@ def lpt(cosmo:Cosmology, init_mesh, positions, a, order=1):
         # Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
         # shear_ii = jnp.fft.irfftn(- ki**2 * pot_k)
         nabla_i_nabla_i = gradient_kernel(kvec, i)**2
-        shear_ii = jnp.fft.irfftn(nabla_i_nabla_i * pot_k)
+        shear_ii = ifft3d(nabla_i_nabla_i * pot_k)
         delta2 += shear_ii * shear_acc
         shear_acc += shear_ii

         # for kj in kvec[i+1:]:
-        for j in range(i+1, 3):
+        for j in range(i + 1, 3):
             # Substract squared strict-up-triangle terms
             # delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
-            nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(kvec, j)
-            delta2 -= jnp.fft.irfftn(nabla_i_nabla_j * pot_k)**2
+            nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(
+                kvec, j)
+            delta2 -= ifft3d(nabla_i_nabla_j * pot_k)**2

-        init_force2 = pm_forces(positions, mesh_shape, delta=jnp.fft.rfftn(delta2))
+        delta_k2 = fft3d(delta2)
+        init_force2 = pm_forces(particles,
+                                delta=delta_k2,
+                                paint_absolute_pos=paint_absolute_pos,
+                                halo_size=halo_size,
+                                sharding=sharding)
         # NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
-        dx2 = 3/7 * growth_factor_second(cosmo, a) * init_force2
+        dx2 = 3 / 7 * growth_factor_second(cosmo, a) * init_force2
         p2 = a**2 * growth_rate_second(cosmo, a) * E * dx2
         f2 = a**2 * E * dGf2a(cosmo, a) * init_force2

@@ -78,23 +126,28 @@ def lpt(cosmo:Cosmology, init_mesh, positions, a, order=1):
     return dx, p, f


-def linear_field(mesh_shape, box_size, pk, seed):
+def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
     """
     Generate initial conditions.
     """
-    kvec = fftk(mesh_shape)
+    # Initialize a random field with one slice on each gpu
+    field = normal_field(mesh_shape, seed=seed, sharding=sharding)
+    field = fft3d(field)
+    kvec = fftk(field)
     kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
                 for i, kk in enumerate(kvec))**0.5
     pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
         box_size[0] * box_size[1] * box_size[2])

-    field = jax.random.normal(seed, mesh_shape)
-    field = jnp.fft.rfftn(field) * pkmesh**0.5
-    field = jnp.fft.irfftn(field)
+    field = field * (pkmesh)**0.5
+    field = ifft3d(field)
     return field


-def make_ode_fn(mesh_shape):
+def make_ode_fn(mesh_shape,
+                paint_absolute_pos=True,
+                halo_size=0,
+                sharding=None):

     def nbody_ode(state, a, cosmo):
         """
@@ -102,7 +155,11 @@ def make_ode_fn(mesh_shape):
         """
         pos, vel = state

-        forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m
+        forces = pm_forces(pos,
+                           mesh_shape=mesh_shape,
+                           paint_absolute_pos=paint_absolute_pos,
+                           halo_size=halo_size,
+                           sharding=sharding) * 1.5 * cosmo.Omega_m

         # Computes the update of position (drift)
         dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
@@ -114,16 +171,24 @@ def make_ode_fn(mesh_shape):
     return nbody_ode


-def get_ode_fn(cosmo:Cosmology, mesh_shape):
+def make_diffrax_ode(cosmo,
+                     mesh_shape,
+                     paint_absolute_pos=True,
+                     halo_size=0,
+                     sharding=None):

     def nbody_ode(a, state, args):
         """
-        State is an array [position, velocities]
-
-        Compatible with [Diffrax API](https://docs.kidger.site/diffrax/)
+        state is a tuple (position, velocities)
         """
         pos, vel = state
-        forces = pm_forces(pos, mesh_shape) * 1.5 * cosmo.Omega_m
+
+        forces = pm_forces(pos,
+                           mesh_shape=mesh_shape,
+                           paint_absolute_pos=paint_absolute_pos,
+                           halo_size=halo_size,
+                           sharding=sharding) * 1.5 * cosmo.Omega_m

         # Computes the update of position (drift)
         dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
@@ -145,44 +210,50 @@ def pgd_correction(pos, mesh_shape, params):
     pos: particle positions [npart, 3]
     params: [alpha, kl, ks] pgd parameters
     """
-    kvec = fftk(mesh_shape)
-
     delta = cic_paint(jnp.zeros(mesh_shape), pos)
+    delta_k = fft3d(delta)
+    kvec = fftk(delta_k)
     alpha, kl, ks = params
-    delta_k = jnp.fft.rfftn(delta)
-    PGD_range=PGD_kernel(kvec, kl, ks)
+    PGD_range = PGD_kernel(kvec, kl, ks)

-    pot_k_pgd=(delta_k * invlaplace_kernel(kvec))*PGD_range
+    pot_k_pgd = (delta_k * invlaplace_kernel(kvec)) * PGD_range

-    forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i)*pot_k_pgd), pos)
-                      for i in range(3)],axis=-1)
+    forces_pgd = jnp.stack([
+        cic_read(ifft3d(-gradient_kernel(kvec, i) * pot_k_pgd), pos)
+        for i in range(3)
+    ],
+                           axis=-1)

-    dpos_pgd = forces_pgd*alpha
+    dpos_pgd = forces_pgd * alpha

     return dpos_pgd


 def make_neural_ode_fn(model, mesh_shape):
-    def neural_nbody_ode(state, a, cosmo:Cosmology, params):
+
+    def neural_nbody_ode(state, a, cosmo: Cosmology, params):
         """
         state is a tuple (position, velocities)
         """
         pos, vel = state
-        kvec = fftk(mesh_shape)

         delta = cic_paint(jnp.zeros(mesh_shape), pos)
+        delta_k = fft3d(delta)
+        kvec = fftk(delta_k)

-        delta_k = jnp.fft.rfftn(delta)
-
         # Computes gravitational potential
-        pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, r_split=0)
+        pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec,
+                                                                     r_split=0)

         # Apply a correction filter
-        kk = jnp.sqrt(sum((ki/jnp.pi)**2 for ki in kvec))
-        pot_k = pot_k *(1. + model.apply(params, kk, jnp.atleast_1d(a)))
+        kk = jnp.sqrt(sum((ki / jnp.pi)**2 for ki in kvec))
+        pot_k = pot_k * (1. + model.apply(params, kk, jnp.atleast_1d(a)))

         # Computes gravitational forces
-        forces = jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i)*pot_k), pos)
-                      for i in range(3)],axis=-1)
+        forces = jnp.stack([
+            cic_read(ifft3d(-gradient_kernel(kvec, i) * pot_k), pos)
+            for i in range(3)
+        ],
+                           axis=-1)

         forces = forces * 1.5 * cosmo.Omega_m

@@ -193,4 +264,5 @@ def make_neural_ode_fn(model, mesh_shape):
         dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces

         return dpos, dvel
+
     return neural_nbody_ode
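Putting the refactored pieces together, a hypothetical single-device end-to-end sketch (the cosmology and box values are illustrative; `particles=None` selects the relative `cic_paint_dx` path):

```python
import jax
import jax_cosmo as jc

from jaxpm.pm import linear_field, lpt, make_diffrax_ode

mesh_shape = (64, 64, 64)
box_size = (64., 64., 64.)  # Mpc/h
cosmo = jc.Planck15()
pk_fn = lambda k: jc.power.linear_matter_power(cosmo, k)

ic = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))
dx, p, f = lpt(cosmo, ic, particles=None, a=0.1, order=2)
ode_fn = make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)
```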

View file

@ -1,89 +1,161 @@
from functools import partial
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
from jax.scipy.stats import norm from jax.scipy.stats import norm
from scipy.special import legendre
__all__ = ['power_spectrum'] __all__ = [
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
'cross_correlation_coefficients', 'gaussian_smoothing'
]
def _initialize_pk(shape, boxsize, kmin, dk): def _initialize_pk(mesh_shape, box_shape, kedges, los):
""" """
Helper function to initialize various (fixed) values for powerspectra... not differentiable! Parameters
----------
mesh_shape : tuple of int
Shape of the mesh grid.
box_shape : tuple of float
Physical dimensions of the box.
kedges : None, int, float, or list
If None, set dk to twice the minimum.
If int, specifies number of edges.
If float, specifies dk.
los : array_like
Line-of-sight vector.
Returns
-------
dig : ndarray
Indices of the bins to which each value in input array belongs.
kcount : ndarray
Count of values in each bin.
kedges : ndarray
Edges of the bins.
mumesh : ndarray
Mu values for the mesh grid.
""" """
I = np.eye(len(shape), dtype='int') * -2 + 1 kmax = np.pi * np.min(mesh_shape / box_shape) # = knyquist
W = np.empty(shape, dtype='f4') if isinstance(kedges, None | int | float):
W[...] = 2.0 if kedges is None:
W[..., 0] = 1.0 dk = 2 * np.pi / np.min(
W[..., -1] = 1.0 box_shape) * 2 # twice the minimum wavenumber
if isinstance(kedges, int):
dk = kmax / (kedges + 1) # final number of bins will be kedges-1
elif isinstance(kedges, float):
dk = kedges
kedges = np.arange(dk, kmax, dk) + dk / 2 # from dk/2 to kmax-dk/2
kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2 kshapes = np.eye(len(mesh_shape), dtype=np.int32) * -2 + 1
kedges = np.arange(kmin, kmax, dk) kvec = [(2 * np.pi * m / l) * np.fft.fftfreq(m).reshape(kshape)
for m, l, kshape in zip(mesh_shape, box_shape, kshapes)]
kmesh = sum(ki**2 for ki in kvec)**0.5
k = [ dig = np.digitize(kmesh.reshape(-1), kedges)
np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape) kcount = np.bincount(dig, minlength=len(kedges) + 1)
for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
]
kmag = sum(ki**2 for ki in k)**0.5
xsum = np.zeros(len(kedges) + 1) # Central value of each bin
Nsum = np.zeros(len(kedges) + 1) # kavg = (kedges[1:] + kedges[:-1]) / 2
kavg = np.bincount(
dig, weights=kmesh.reshape(-1), minlength=len(kedges) + 1) / kcount
kavg = kavg[1:-1]
dig = np.digitize(kmag.flat, kedges) if los is None:
mumesh = 1.
else:
mumesh = sum(ki * losi for ki, losi in zip(kvec, los))
kmesh_nozeros = np.where(kmesh == 0, 1, kmesh)
mumesh = np.where(kmesh == 0, 0, mumesh / kmesh_nozeros)
xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size) return dig, kcount, kavg, mumesh
Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
return dig, Nsum, xsum, W, k, kedges
def power_spectrum(mesh,
                   mesh2=None,
                   box_shape=None,
                   kedges: int | float | list = None,
                   multipoles=0,
                   los=[0., 0., 1.]):
    """
    Compute the auto and cross spectrum of 3D fields, with multipoles.
    """
    # Initialize
    mesh_shape = np.array(mesh.shape)
    if box_shape is None:
        box_shape = mesh_shape
    else:
        box_shape = np.asarray(box_shape)

    if multipoles == 0:
        los = None
    else:
        los = np.asarray(los)
        los = los / np.linalg.norm(los)
    poles = np.atleast_1d(multipoles)
    dig, kcount, kavg, mumesh = _initialize_pk(mesh_shape, box_shape, kedges,
                                               los)
    n_bins = len(kavg) + 2

    # FFTs
    meshk = jnp.fft.fftn(mesh, norm='ortho')
    if mesh2 is None:
        mmk = meshk.real**2 + meshk.imag**2
    else:
        mmk = meshk * jnp.fft.fftn(mesh2, norm='ortho').conj()

    # Sum powers
    pk = jnp.empty((len(poles), n_bins))
    for i_ell, ell in enumerate(poles):
        weights = (mmk * (2 * ell + 1) * legendre(ell)(mumesh)).reshape(-1)
        if mesh2 is None:
            psum = jnp.bincount(dig, weights=weights, length=n_bins)
        else:  # XXX: bincount is really slow with complex numbers
            psum_real = jnp.bincount(dig, weights=weights.real, length=n_bins)
            psum_imag = jnp.bincount(dig, weights=weights.imag, length=n_bins)
            psum = (psum_real**2 + psum_imag**2)**.5
        pk = pk.at[i_ell].set(psum)

    # Normalization and conversion from cell units to [Mpc/h]^3
    pk = (pk / kcount)[:, 1:-1] * (box_shape / mesh_shape).prod()

    # pk = jnp.concatenate([kavg[None], pk])
    if np.ndim(multipoles) == 0:
        return kavg, pk[0]
    else:
        return kavg, pk
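A minimal usage sketch of the new interface (the field `delta`, mesh and box sizes are illustrative; `power_spectrum` is imported from `jaxpm.utils`, as in the tests further down):

import jax
from jaxpm.utils import power_spectrum

# Random Gaussian field on a 64^3 mesh in a (256 Mpc/h)^3 box
delta = jax.random.normal(jax.random.PRNGKey(0), (64, 64, 64))

# Auto-spectrum: k holds the bin-averaged wavenumbers, pk is in [Mpc/h]^3
k, pk = power_spectrum(delta, box_shape=(256., 256., 256.))

# Monopole and quadrupole along the default z line of sight
k, (pk0, pk2) = power_spectrum(delta,
                               box_shape=(256., 256., 256.),
                               multipoles=(0, 2))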
def transfer(mesh0, mesh1, box_shape, kedges: int | float | list = None):
    pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
    ks, pk0 = pk_fn(mesh0)
    ks, pk1 = pk_fn(mesh1)
    return ks, (pk1 / pk0)**.5


def coherence(mesh0, mesh1, box_shape, kedges: int | float | list = None):
    pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
    ks, pk01 = pk_fn(mesh0, mesh1)
    ks, pk0 = pk_fn(mesh0)
    ks, pk1 = pk_fn(mesh1)
    return ks, pk01 / (pk0 * pk1)**.5


def pktranscoh(mesh0, mesh1, box_shape, kedges: int | float | list = None):
    pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
    ks, pk01 = pk_fn(mesh0, mesh1)
    ks, pk0 = pk_fn(mesh0)
    ks, pk1 = pk_fn(mesh1)
    return ks, pk0, pk1, (pk1 / pk0)**.5, pk01 / (pk0 * pk1)**.5
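A sketch of how these comparison helpers fit together (assuming they ship in `jaxpm.utils` alongside `power_spectrum`; `lpt_field` and `nbody_field` are illustrative placeholders for two density fields on the same mesh):

from jaxpm.utils import coherence, transfer

box_shape = (256., 256., 256.)
# Transfer function sqrt(P_1 / P_0): relative amplitude per k-bin
ks, t = transfer(lpt_field, nbody_field, box_shape)
# Coherence P_01 / sqrt(P_0 P_1): phase agreement per k-bin (1 means identical)
ks, r = coherence(lpt_field, nbody_field, box_shape)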
def cross_correlation_coefficients(field_a,
                                   field_b,
                                   kmin=5,
                                   dk=0.5,
                                   boxsize=False):
    """
    Calculate the cross correlation coefficients given two real space fields

@ -118,7 +190,8 @@ def cross_correlation_coefficients(field_a, field_b, kmin=5, dk=0.5, boxsize=False):

    real = jnp.real(pk).reshape([-1])
    imag = jnp.imag(pk).reshape([-1])
    Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
                        length=xsum.size) * 1j
    Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
    P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')


@ -0,0 +1,179 @@
import os
os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
import jax
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()
if rank == 0:
print(f"SIZE is {jax.device_count()}")
import argparse
from functools import partial
import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
PIDController, SaveAt, diffeqsolve)
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
all_gather = partial(process_allgather, tiled=True)
def parse_arguments():
parser = argparse.ArgumentParser(
description="Run a cosmological simulation with JAX.")
parser.add_argument(
"-p",
"--pdims",
type=int,
nargs=2,
        default=[1, jax.device_count()],  # total device count, not the device list
help="Processor grid dimensions as two integers (e.g., 2 4).")
parser.add_argument(
"-m",
"--mesh_shape",
type=int,
nargs=3,
default=[512, 512, 512],
help="Shape of the simulation mesh as three values (e.g., 512 512 512)."
)
parser.add_argument(
"-b",
"--box_size",
type=float,
nargs=3,
default=[500.0, 500.0, 500.0],
help=
"Box size of the simulation as three values (e.g., 500.0 500.0 1000.0)."
)
parser.add_argument(
"-st",
"--snapshots",
type=int,
default=2,
help="Number of snapshots to save during the simulation.")
parser.add_argument("-H",
"--halo_size",
type=int,
default=64,
help="Halo size for the simulation.")
parser.add_argument("-s",
"--solver",
type=str,
choices=['leapfrog', 'dopri8'],
default='leapfrog',
help="ODE solver choice: 'leapfrog' or 'dopri8'.")
return parser.parse_args()
def create_mesh_and_sharding(pdims):
devices = create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
return mesh, sharding
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6, 7))  # shapes, halo, solver, snapshot count and sharding are static
def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
solver_choice, nb_snapshots, sharding):
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
initial_conditions = linear_field(mesh_shape,
box_size,
pk_fn,
seed=jax.random.PRNGKey(0),
sharding=sharding)
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
dx, p, _ = lpt(cosmo,
initial_conditions,
a=0.1,
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
# Choose solver
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
stepsize_controller = ConstantStepSize(
) if solver_choice == "leapfrog" else PIDController(rtol=1e-5, atol=1e-5)
res = diffeqsolve(ode_fn,
solver,
t0=0.1,
t1=1.,
dt0=0.01,
y0=jnp.stack([dx, p], axis=0),
args=cosmo,
saveat=SaveAt(ts=jnp.linspace(0.2, 1., nb_snapshots)),
stepsize_controller=stepsize_controller)
ode_fields = [
cic_paint_dx(sol[0], halo_size=halo_size, sharding=sharding)
for sol in res.ys
]
lpt_field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding)
return initial_conditions, lpt_field, ode_fields, res.stats
def main():
args = parse_arguments()
mesh_shape = args.mesh_shape
box_size = args.box_size
halo_size = args.halo_size
solver_choice = args.solver
nb_snapshots = args.snapshots
    mesh, sharding = create_mesh_and_sharding(args.pdims)
initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(
0.25, 0.8, tuple(mesh_shape), tuple(box_size), halo_size,
solver_choice, nb_snapshots, sharding)
if rank == 0:
os.makedirs("fields", exist_ok=True)
print(f"[{rank}] Simulation done")
print(f"Solver stats: {solver_stats}")
# Save initial conditions
initial_conditions_g = all_gather(initial_conditions)
if rank == 0:
print(f"[{rank}] Saving initial_conditions")
np.save("fields/initial_conditions.npy", initial_conditions_g)
print(f"[{rank}] initial_conditions saved")
del initial_conditions_g, initial_conditions
# Save LPT displacements
lpt_displacements_g = all_gather(lpt_displacements)
if rank == 0:
print(f"[{rank}] Saving lpt_displacements")
np.save("fields/lpt_displacements.npy", lpt_displacements_g)
print(f"[{rank}] lpt_displacements saved")
del lpt_displacements_g, lpt_displacements
# Save each ODE solution separately
for i, sol in enumerate(ode_solutions):
sol_g = all_gather(sol)
if rank == 0:
print(f"[{rank}] Saving ode_solution_{i}")
np.save(f"fields/ode_solution_{i}.npy", sol_g)
print(f"[{rank}] ode_solution_{i} saved")
del sol_g
if __name__ == "__main__":
main()
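A note on launching this script: `jax.distributed.initialize()` expects the scheduler to start one process per GPU (or per node), so on a SLURM cluster the invocation would look something like `srun python distributed_pm.py -m 512 512 512 -b 500. 500. 500. -p 2 4` (script name and flag values are illustrative; the product of `--pdims` must match the total number of devices).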


@ -0,0 +1 @@
c4a44973e4f11841a8c14f4d200e7e87887419aa

notebooks/README.md

@ -0,0 +1,39 @@
# Particle Mesh Simulation with JAXPM on Multi-GPU and Multi-Host Systems
This collection of notebooks demonstrates how to perform Particle Mesh (PM) simulations using **JAXPM**, leveraging JAX for efficient computation on multi-GPU and multi-host systems. Each notebook progressively covers different setups, from single-GPU simulations to advanced, distributed, multi-host simulations across multiple nodes.
## Table of Contents
1. **[Single-GPU Particle Mesh Simulation](01-Introduction.ipynb)**
   - Introduction to basic PM simulations on a single GPU.
   - Uses JAXPM to run simulations with absolute particle positions and Cloud-in-Cell (CIC) painting.

2. **[Advanced Particle Mesh Simulation on a Single GPU](02-Advanced_usage.ipynb)**
   - Explores using Diffrax solvers in the ODE step.
   - Explores second-order Lagrangian Perturbation Theory (LPT) simulations.
   - Introduces weighted density field projections.

3. **[Multi-GPU Particle Mesh Simulation with Halo Exchange](03-MultiGPU_PM_Halo.ipynb)**
   - Extends PM simulation to multi-GPU setups with halo exchange.
   - Uses sharding and device mesh configurations to manage distributed data across GPUs.

4. **[Multi-GPU Particle Mesh Simulation with Advanced Solvers](04-MultiGPU_PM_Solvers.ipynb)**
   - Compares different ODE solvers (Leapfrog and Dopri5) in multi-GPU simulations.
   - Highlights performance, memory considerations, and solver impact on simulation quality.

5. **[Multi-Host Particle Mesh Simulation](05-MultiHost_PM.ipynb)**
   - Extends PM simulations to multi-host, multi-GPU setups for large-scale simulations.
   - Walks through job submission, device initialization, and retrieving results across nodes.
## Getting Started
Each notebook includes installation instructions and guidelines for configuring JAXPM and required dependencies. Follow the setup instructions in each notebook to ensure an optimal environment.
## Requirements
- **JAXPM** (included in the installation commands within notebooks)
- **Diffrax** for ODE solvers
- **JAX** with CUDA support for multi-GPU or TPU setups
- **SLURM** for job scheduling on clusters (if running multi-host setups)
> **Note**: These notebooks are tested on the **Jean Zay** supercomputer and may require configuration changes for different HPC clusters.

pyproject.toml

@ -0,0 +1,30 @@
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "JaxPM"
version = "0.0.1"
description = "A dead simple FastPM implementation in JAX"
authors = [{ name = "JaxPM developers" }]
readme = "README.md"
requires-python = ">=3.9"
license = { file = "LICENSE" }
urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/JaxPM" }
dependencies = ["jax_cosmo", "jax>=0.4.30", "jaxdecomp>=0.2.2"]
[project.optional-dependencies]
test = [
"jax>=0.4.30",
"numpy",
"jax_cosmo",
"jaxdecomp>=0.2.2",
"pytest>=8.0.0",
"pfft-python @ git+https://github.com/MP-Gadget/pfft-python",
"pmesh @ git+https://github.com/MP-Gadget/pmesh",
"fastpm @ git+https://github.com/ASKabalan/fastpm-python",
"diffrax"
]
[tool.setuptools]
packages = ["jaxpm"]

pytest.ini

@ -0,0 +1,4 @@
[pytest]
markers =
distributed: mark a test as distributed
single_device: mark a test as single_device
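The markers let subsets of the suite be selected at the command line, e.g. `pytest -m single_device` for the single-device tests or `pytest -m distributed` for the multi-device ones.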

setup.py

@ -1,11 +0,0 @@
from setuptools import find_packages, setup
setup(
name='JaxPM',
version='0.0.1',
url='https://github.com/DifferentiableUniverseInitiative/JaxPM',
author='JaxPM developers',
description='A dead simple FastPM implementation in JAX',
packages=find_packages(),
install_requires=['jax', 'jax_cosmo'],
)

tests/conftest.py

@ -0,0 +1,175 @@
# Parameterized fixture for mesh_shape
import os
import pytest
os.environ["EQX_ON_ERROR"] = "nan"
setup_done = False
on_cluster = False
def is_on_cluster():
global on_cluster
return on_cluster
def initialize_distributed():
global setup_done
global on_cluster
if not setup_done:
if "SLURM_JOB_ID" in os.environ:
on_cluster = True
print("Running on cluster")
import jax
jax.distributed.initialize()
setup_done = True
on_cluster = True
else:
print("Running locally")
setup_done = True
on_cluster = False
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ[
"XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
import jax
@pytest.fixture(
scope="session",
params=[
((32, 32, 32), (256., 256., 256.)), # BOX
((32, 32, 64), (256., 256., 512.)), # RECTANGULAR
])
def simulation_config(request):
return request.param
@pytest.fixture(scope="session", params=[0.1, 0.5, 0.8])
def lpt_scale_factor(request):
return request.param
@pytest.fixture(scope="session")
def cosmo():
from functools import partial
from jax_cosmo import Cosmology
Planck18 = partial(
Cosmology,
# Omega_m = 0.3111
Omega_c=0.2607,
Omega_b=0.0490,
Omega_k=0.0,
h=0.6766,
n_s=0.9665,
sigma8=0.8102,
w0=-1.0,
wa=0.0,
)
return Planck18()
@pytest.fixture(scope="session")
def particle_mesh(simulation_config):
from pmesh.pm import ParticleMesh
mesh_shape, box_shape = simulation_config
return ParticleMesh(BoxSize=box_shape, Nmesh=mesh_shape, dtype='f4')
@pytest.fixture(scope="session")
def fpm_initial_conditions(cosmo, particle_mesh):
import jax_cosmo as jc
import numpy as np
from jax import numpy as jnp
# Generate initial particle positions
grid = particle_mesh.generate_uniform_particle_grid(shift=0).astype(
np.float32)
# Interpolate with linear_matter spectrum to get initial density field
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(cosmo, k)
def pk_fn(x):
return jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)
whitec = particle_mesh.generate_whitenoise(42,
type='complex',
unitary=False)
lineark = whitec.apply(lambda k, v: pk_fn(sum(ki**2 for ki in k)**0.5)**0.5
* v * (1 / v.BoxSize).prod()**0.5)
init_mesh = lineark.c2r().value # XXX
return lineark, grid, init_mesh
@pytest.fixture(scope="session")
def initial_conditions(fpm_initial_conditions):
_, _, init_mesh = fpm_initial_conditions
return init_mesh
@pytest.fixture(scope="session")
def solver(cosmo, particle_mesh):
from fastpm.core import Cosmology as FastPMCosmology
from fastpm.core import Solver
ref_cosmo = FastPMCosmology(cosmo)
return Solver(particle_mesh, ref_cosmo, B=1)
@pytest.fixture(scope="session")
def fpm_lpt1(solver, fpm_initial_conditions, lpt_scale_factor):
lineark, grid, _ = fpm_initial_conditions
statelpt = solver.lpt(lineark, grid, lpt_scale_factor, order=1)
return statelpt
@pytest.fixture(scope="session")
def fpm_lpt1_field(fpm_lpt1, particle_mesh):
return particle_mesh.paint(fpm_lpt1.X).value
@pytest.fixture(scope="session")
def fpm_lpt2(solver, fpm_initial_conditions, lpt_scale_factor):
lineark, grid, _ = fpm_initial_conditions
statelpt = solver.lpt(lineark, grid, lpt_scale_factor, order=2)
return statelpt
@pytest.fixture(scope="session")
def fpm_lpt2_field(fpm_lpt2, particle_mesh):
return particle_mesh.paint(fpm_lpt2.X).value
@pytest.fixture(scope="session")
def nbody_from_lpt1(solver, fpm_lpt1, particle_mesh, lpt_scale_factor):
import numpy as np
from fastpm.core import leapfrog
if lpt_scale_factor == 0.8:
pytest.skip("Do not run nbody simulation from scale factor 0.8")
stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)
finalstate = solver.nbody(fpm_lpt1, leapfrog(stages))
fpm_mesh = particle_mesh.paint(finalstate.X).value
return fpm_mesh
@pytest.fixture(scope="session")
def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
import numpy as np
from fastpm.core import leapfrog
if lpt_scale_factor == 0.8:
pytest.skip("Do not run nbody simulation from scale factor 0.8")
stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)
finalstate = solver.nbody(fpm_lpt2, leapfrog(stages))
fpm_mesh = particle_mesh.paint(finalstate.X).value
return fpm_mesh

tests/helpers.py

@ -0,0 +1,13 @@
import jax.numpy as jnp
def MSE(x, y):
return jnp.mean((x - y)**2)
def MSE_3D(x, y):
return ((x - y)**2).mean(axis=0)
def MSRE(x, y):
return jnp.mean(((x - y) / y)**2)
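A quick sanity check of these metrics, with values worked out by hand:

import jax.numpy as jnp
from helpers import MSE, MSRE

x = jnp.array([1., 2.])
y = jnp.array([1., 4.])
print(MSE(x, y))   # mean((0, -2)**2)   = 2.0
print(MSRE(x, y))  # mean((0, -1/2)**2) = 0.125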

tests/test_against_fpm.py

@ -0,0 +1,155 @@
import pytest
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
from helpers import MSE, MSRE
from jax import numpy as jnp
from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_dx
from jaxpm.pm import lpt, make_diffrax_ode
from jaxpm.utils import power_spectrum
_TOLERANCE = 1e-4
_PM_TOLERANCE = 1e-3
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
fpm_lpt1_field, fpm_lpt2_field, cosmo, order):
mesh_shape, box_shape = simulation_config
cosmo._workspace = {}
particles = uniform_particles(mesh_shape)
# Initial displacement
dx, _, _ = lpt(cosmo,
initial_conditions,
particles,
a=lpt_scale_factor,
order=order)
fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field
lpt_field = cic_paint(jnp.zeros(mesh_shape), particles + dx)
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
fpm_lpt1_field, fpm_lpt2_field, cosmo, order):
mesh_shape, box_shape = simulation_config
cosmo._workspace = {}
# Initial displacement
dx, _, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
lpt_field = cic_paint_dx(dx)
fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
def test_nbody_absolute(simulation_config, initial_conditions,
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
cosmo, order):
mesh_shape, box_shape = simulation_config
cosmo._workspace = {}
particles = uniform_particles(mesh_shape)
# Initial displacement
dx, p, _ = lpt(cosmo,
initial_conditions,
particles,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
solver = Dopri5()
controller = PIDController(rtol=1e-8,
atol=1e-8,
pcoeff=0.4,
icoeff=1,
dcoeff=0)
saveat = SaveAt(t1=True)
y0 = jnp.stack([particles + dx, p])
solutions = diffeqsolve(ode_fn,
solver,
t0=lpt_scale_factor,
t1=1.0,
dt0=None,
y0=y0,
stepsize_controller=controller,
saveat=saveat)
final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0])
fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
def test_nbody_relative(simulation_config, initial_conditions,
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
cosmo, order):
mesh_shape, box_shape = simulation_config
cosmo._workspace = {}
# Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
solver = Dopri5()
controller = PIDController(rtol=1e-9,
atol=1e-9,
pcoeff=0.4,
icoeff=1,
dcoeff=0)
saveat = SaveAt(t1=True)
y0 = jnp.stack([dx, p])
solutions = diffeqsolve(ode_fn,
solver,
t0=lpt_scale_factor,
t1=1.0,
dt0=None,
y0=y0,
stepsize_controller=controller,
saveat=saveat)
final_field = cic_paint_dx(solutions.ys[-1, 0])
fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE


@ -0,0 +1,152 @@
from conftest import initialize_distributed
initialize_distributed()  # noqa: E402
import jax # noqa : E402
import jax.numpy as jnp # noqa : E402
import pytest # noqa : E402
from diffrax import SaveAt # noqa : E402
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
from helpers import MSE # noqa : E402
from jax import lax # noqa : E402
from jax.experimental.multihost_utils import process_allgather # noqa : E402
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P # noqa : E402
from jaxpm.distributed import uniform_particles # noqa : E402
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
_TOLERANCE = 3.0 # 🙃🙃
@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("absolute_painting", [True, False])
def test_distributed_pm(simulation_config, initial_conditions, cosmo, order,
                        absolute_painting):
mesh_shape, box_shape = simulation_config
# SINGLE DEVICE RUN
cosmo._workspace = {}
if absolute_painting:
particles = uniform_particles(mesh_shape)
# Initial displacement
dx, p, _ = lpt(cosmo,
initial_conditions,
particles,
a=0.1,
order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
y0 = jnp.stack([particles + dx, p])
else:
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
y0 = jnp.stack([dx, p])
solver = Dopri5()
controller = PIDController(rtol=1e-8,
atol=1e-8,
pcoeff=0.4,
icoeff=1,
dcoeff=0)
saveat = SaveAt(t1=True)
solutions = diffeqsolve(ode_fn,
solver,
t0=0.1,
t1=1.0,
dt0=None,
y0=y0,
stepsize_controller=controller,
saveat=saveat)
if absolute_painting:
single_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
solutions.ys[-1, 0])
else:
single_device_final_field = cic_paint_dx(solutions.ys[-1, 0])
print("Done with single device run")
# MULTI DEVICE RUN
mesh = jax.make_mesh((1, 8), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
initial_conditions = lax.with_sharding_constraint(initial_conditions,
sharding)
print(f"sharded initial conditions {initial_conditions.sharding}")
cosmo._workspace = {}
if absolute_painting:
particles = uniform_particles(mesh_shape, sharding=sharding)
# Initial displacement
dx, p, _ = lpt(cosmo,
initial_conditions,
particles,
a=0.1,
order=order,
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(cosmo,
mesh_shape,
halo_size=halo_size,
sharding=sharding))
y0 = jnp.stack([particles + dx, p])
else:
dx, p, _ = lpt(cosmo,
initial_conditions,
a=0.1,
order=order,
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(cosmo,
mesh_shape,
paint_absolute_pos=False,
halo_size=halo_size,
sharding=sharding))
y0 = jnp.stack([dx, p])
solver = Dopri5()
controller = PIDController(rtol=1e-8,
atol=1e-8,
pcoeff=0.4,
icoeff=1,
dcoeff=0)
saveat = SaveAt(t1=True)
solutions = diffeqsolve(ode_fn,
solver,
t0=0.1,
t1=1.0,
dt0=None,
y0=y0,
stepsize_controller=controller,
saveat=saveat)
if absolute_painting:
multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
solutions.ys[-1, 0],
halo_size=halo_size,
sharding=sharding)
else:
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
halo_size=halo_size,
sharding=sharding)
multi_device_final_field = process_allgather(multi_device_final_field,
tiled=True)
mse = MSE(single_device_final_field, multi_device_final_field)
print(f"MSE is {mse}")
assert mse < _TOLERANCE