mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-22 09:37:11 +00:00
jaxdecomp proto (#21)
* adding example of distributed solution * put back old functgion * update formatting * add halo exchange and slice pad * apply formatting * implement distributed optimized cic_paint * Use new cic_paint with halo * Fix seed for distributed normal * Wrap interpolation function to avoid all gather * Return normal order frequencies for single GPU * add example * format * add optimised bench script * times in ms * add lpt2 * update benchmark and add slurm * Visualize only final field * Update scripts/distributed_pm.py Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com> * Adjust pencil type for frequencies * fix painting issue with slabs * Shared operation in fourrier space now take inverted sharding axis for slabs * add assert to make pyright happy * adjust test for hpc-plotter * add PMWD test * bench * format * added github workflow * fix formatting from main * Update for jaxDecomp pure JAX * revert single halo extent change * update for latest jaxDecomp * remove fourrier_space in autoshmap * make normal_field work with single controller * format * make distributed pm work in single controller * merge bench_pm * update to leapfrog * add a strict dependency on jaxdecomp * global mesh no longer needed * kernels.py no longer uses global mesh * quick fix in distributed * pm.py no longer uses global mesh * painting.py no longer uses global mesh * update demo script * quick fix in kernels * quick fix in distributed * update demo * merge hugos LPT2 code * format * Small fix * format * remove duplicate get_ode_fn * update visualizer * update compensate CIC * By default check_rep is false for shard_map * remove experimental distributed code * update PGDCorrection and neural ode to use new fft3d * jaxDecomp pfft3d promotes to complex automatically * remove deprecated stuff * fix painting issue with read_cic * use jnp interp instead of jc interp * delete old slurms * add notebook examples * apply formatting * add distributed zeros * fix code in LPT2 * jit cic_paint * update notebooks * apply formating * get local shape and zeros can be used by users * add a user facing function to create uniform particle grid * use jax interp instead of jax_cosmo * use float64 for enmeshing * Allow applying weights with relative cic paint * Weights can be traced * remove script folder * update example notebooks * delete outdated design file * add readme for tutorials * update readme * fix small error * forgot particles in multi host * clarifying why cic_paint_dx is slower * clarifying the halo size dependence on the box size * ability to choose snapshots number with MultiHost script * Adding animation notebook * Put plotting in package * Add finite difference laplace kernel + powerspec functions from Hugo Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com> * Put plotting utils in package * By default use absoulute painting with * update code * update notebooks * add tests * Upgrade setup.py to pyproject * Format * format tests * update test dependencies * add test workflow * fix deprecated FftType in jaxpm.kernels * Add aboucaud comments * JAX version is 0.4.35 until Diffrax new release * add numpy explicitly as dependency for tests * fix install order for tests * add numpy to be installed * enforce no build isolation for fastpm * pip install jaxpm test without build isolation * bump jaxdecomp version * revert test workflow * remove outdated tests --------- Co-authored-by: EiffL <fr.eiffel@gmail.com> Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com> Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr> Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com> Former-commit-id: 8c2e823d4669eac712089bf7f85ffb7912e8232d
This commit is contained in:
parent
a0a79277e5
commit
df8602b318
26 changed files with 1871 additions and 434 deletions
21
.github/workflows/formatting.yml
vendored
Normal file
21
.github/workflows/formatting.yml
vendored
Normal file
|
@ -0,0 +1,21 @@
|
|||
name: Code Formatting
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v3
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip isort
|
||||
python -m pip install pre-commit
|
||||
- name: Run pre-commit
|
||||
run: python -m pre_commit run --all-files
|
45
.github/workflows/tests.yml
vendored
Normal file
45
.github/workflows/tests.yml
vendored
Normal file
|
@ -0,0 +1,45 @@
|
|||
name: Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
run_tests:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.10" , "3.11" , "3.12"]
|
||||
|
||||
steps:
|
||||
- name: Checkout Source
|
||||
uses: actions/checkout@v2.3.1
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt-get install -y libopenmpi-dev
|
||||
python -m pip install --upgrade pip
|
||||
pip install jax==0.4.35
|
||||
pip install numpy setuptools cython wheel
|
||||
pip install git+https://github.com/MP-Gadget/pfft-python
|
||||
pip install git+https://github.com/MP-Gadget/pmesh
|
||||
pip install git+https://github.com/ASKabalan/fastpm-python --no-build-isolation
|
||||
pip install .[test]
|
||||
|
||||
- name: Run Single Device Tests
|
||||
run: |
|
||||
cd tests
|
||||
pytest -v -m "not distributed"
|
||||
- name: Run Distributed tests
|
||||
run: |
|
||||
pytest -v -m distributed
|
5
.gitignore
vendored
5
.gitignore
vendored
|
@ -98,6 +98,11 @@ __pypackages__/
|
|||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
|
||||
out
|
||||
traces
|
||||
*.npy
|
||||
*.out
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
|
|
|
@ -4,16 +4,15 @@
|
|||
<!-- ALL-CONTRIBUTORS-BADGE:END -->
|
||||
JAX-powered Cosmological Particle-Mesh N-body Solver
|
||||
|
||||
**This project is currently in an early design phase. All inputs are welcome on the [design document](https://github.com/DifferentiableUniverseInitiative/JaxPM/blob/main/design.md)**
|
||||
|
||||
## Goals
|
||||
|
||||
Provide a modern infrastructure to support differentiable PM N-body simulations using JAX:
|
||||
- Keep implementation simple and readable, in pure NumPy API
|
||||
- Transparent distribution using builtin `xmap`
|
||||
- Any order forward and backward automatic differentiation
|
||||
- Support automated batching using `vmap`
|
||||
- Compatibility with external optimizer libraries like `optax`
|
||||
- Now fully distributable on **multi-GPU and multi-node** systems using [jaxDecomp](https://github.com/DifferentiableUniverseInitiative/jaxDecomp) working with`JAX v0.4.35`
|
||||
|
||||
|
||||
## Open development and use
|
||||
|
||||
|
@ -23,6 +22,10 @@ Current expectations are:
|
|||
- Everyone is welcome to contribute, and can join the JOSS publication (until it is submitted to the journal).
|
||||
- Anyone (including main contributors) can use this code as a framework to build and publish their own applications, with no expectation that they *need* to extend authorship to all jaxpm developers.
|
||||
|
||||
## Getting Started
|
||||
|
||||
To dive into JaxPM’s capabilities, please explore the **notebook section** for detailed tutorials and examples on various setups, from single-device simulations to multi-host configurations. You can find the notebooks' [README here](notebooks/README.md) for a structured guide through each tutorial.
|
||||
|
||||
|
||||
## Contributors ✨
|
||||
|
||||
|
|
52
design.md
52
design.md
|
@ -1,52 +0,0 @@
|
|||
# Design Document for JaxPM
|
||||
|
||||
This document aims to detail some of the API, implementation choices, and internal mechanism.
|
||||
|
||||
## Objective
|
||||
|
||||
Provide a user-friendly framework for distributed Particle-Mesh N-body simulations.
|
||||
|
||||
## Related Work
|
||||
|
||||
This project would be the latest iteration of a number of past libraries that have provided differentiable N-body models.
|
||||
|
||||
- [FlowPM](https://github.com/DifferentiableUniverseInitiative/flowpm): TensorFlow
|
||||
- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD
|
||||
- Borg
|
||||
|
||||
|
||||
In addition, a number of fast N-body simulation projets exist out there:
|
||||
- [FastPM](https://github.com/fastpm/fastpm)
|
||||
- ...
|
||||
|
||||
## Design Overview
|
||||
|
||||
### Coding principles
|
||||
|
||||
Following recent trends and JAX philosophy, the library should have a functional programming type of interface.
|
||||
|
||||
|
||||
### Illustration of API
|
||||
|
||||
Here is a potential illustration of what the user interface could be for the simulation code:
|
||||
```python
|
||||
import jaxpm as jpm
|
||||
import jax_cosmo as jc
|
||||
|
||||
# Instantiate differentiable cosmology object
|
||||
cosmo = jc.Planck()
|
||||
|
||||
# Creates initial conditions
|
||||
inital_conditions = jpm.generate_ic(cosmo, boxsize, nmesh, dtype='float32')
|
||||
|
||||
# Create a particular solver
|
||||
solver = jpm.solvers.fastpm(cosmo, B=1)
|
||||
|
||||
# Initialize and run the simulation
|
||||
state = solver.init(initial_conditions)
|
||||
state = solver.nbody(state)
|
||||
|
||||
# Painting the results
|
||||
density = jpm.zeros(boxsize, nmesh)
|
||||
density = jpm.paint(density, state.positions)
|
||||
```
|
|
@ -1,96 +0,0 @@
|
|||
# Can be executed with:
|
||||
# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
import jax.lax as lax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax.experimental.maps import Mesh, xmap
|
||||
from jax.experimental.pjit import PartitionSpec, pjit
|
||||
|
||||
jax.distributed.initialize()
|
||||
|
||||
cube_size = 2048
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=[...],
|
||||
out_axes=['x', 'y', ...],
|
||||
axis_sizes={
|
||||
'x': cube_size,
|
||||
'y': cube_size
|
||||
},
|
||||
axis_resources={
|
||||
'x': 'nx',
|
||||
'y': 'ny',
|
||||
'key_x': 'nx',
|
||||
'key_y': 'ny'
|
||||
})
|
||||
def pnormal(key):
|
||||
return jax.random.normal(key, shape=[cube_size])
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes={
|
||||
0: 'x',
|
||||
1: 'y'
|
||||
},
|
||||
out_axes=['x', 'y', ...],
|
||||
axis_resources={
|
||||
'x': 'nx',
|
||||
'y': 'ny'
|
||||
})
|
||||
@jax.jit
|
||||
def pfft3d(mesh):
|
||||
# [x, y, z]
|
||||
mesh = jnp.fft.fft(mesh) # Transform on z
|
||||
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x]
|
||||
mesh = jnp.fft.fft(mesh) # Transform on x
|
||||
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y]
|
||||
mesh = jnp.fft.fft(mesh) # Transform on y
|
||||
# [z, x, y]
|
||||
return mesh
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes={
|
||||
0: 'x',
|
||||
1: 'y'
|
||||
},
|
||||
out_axes=['x', 'y', ...],
|
||||
axis_resources={
|
||||
'x': 'nx',
|
||||
'y': 'ny'
|
||||
})
|
||||
@jax.jit
|
||||
def pifft3d(mesh):
|
||||
# [z, x, y]
|
||||
mesh = jnp.fft.ifft(mesh) # Transform on y
|
||||
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x]
|
||||
mesh = jnp.fft.ifft(mesh) # Transform on x
|
||||
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z]
|
||||
mesh = jnp.fft.ifft(mesh) # Transform on z
|
||||
# [x, y, z]
|
||||
return mesh
|
||||
|
||||
|
||||
key = jax.random.PRNGKey(42)
|
||||
# keys = jax.random.split(key, 4).reshape((2,2,2))
|
||||
|
||||
# We reshape all our devices to the mesh shape we want
|
||||
devices = np.array(jax.devices()).reshape((2, 4))
|
||||
|
||||
with Mesh(devices, ('nx', 'ny')):
|
||||
mesh = pnormal(key)
|
||||
kmesh = pfft3d(mesh)
|
||||
kmesh.block_until_ready()
|
||||
|
||||
# jax.profiler.start_trace("tensorboard")
|
||||
# with Mesh(devices, ('nx', 'ny')):
|
||||
# mesh = pnormal(key)
|
||||
# kmesh = pfft3d(mesh)
|
||||
# kmesh.block_until_ready()
|
||||
# jax.profiler.stop_trace()
|
||||
|
||||
print('Done')
|
|
@ -1,68 +0,0 @@
|
|||
# Start this script with:
|
||||
# mpirun -np 4 python test_script.py
|
||||
import os
|
||||
|
||||
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
|
||||
import jax
|
||||
import jax.lax as lax
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pylab as plt
|
||||
import numpy as np
|
||||
import tensorflow_probability as tfp
|
||||
from jax.experimental.maps import mesh, xmap
|
||||
from jax.experimental.pjit import PartitionSpec, pjit
|
||||
|
||||
tfp = tfp.substrates.jax
|
||||
tfd = tfp.distributions
|
||||
|
||||
|
||||
def cic_paint(mesh, positions):
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
positions: [npart, 3]
|
||||
"""
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
||||
|
||||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
|
||||
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
|
||||
inserted_window_dims=(0, 1, 2),
|
||||
scatter_dims_to_operand_dims=(0, 1,
|
||||
2))
|
||||
mesh = lax.scatter_add(
|
||||
mesh,
|
||||
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
||||
kernel.reshape([-1, 8]), dnums)
|
||||
return mesh
|
||||
|
||||
|
||||
# And let's draw some points from some 3D distribution
|
||||
dist = tfd.MultivariateNormalDiag(loc=[16., 16., 16.],
|
||||
scale_identity_multiplier=3.)
|
||||
pos = dist.sample(1e4, seed=jax.random.PRNGKey(0))
|
||||
|
||||
f = pjit(lambda x: cic_paint(x, pos),
|
||||
in_axis_resources=PartitionSpec('x', 'y', 'z'),
|
||||
out_axis_resources=None)
|
||||
|
||||
devices = np.array(jax.devices()).reshape((2, 2, 1))
|
||||
|
||||
# Let's import the mesh
|
||||
m = jnp.zeros([32, 32, 32])
|
||||
|
||||
with mesh(devices, ('x', 'y', 'z')):
|
||||
# Shard the mesh, I'm not sure this is absolutely necessary
|
||||
m = pjit(lambda x: x,
|
||||
in_axis_resources=None,
|
||||
out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)
|
||||
|
||||
# Apply the sharded CiC function
|
||||
res = f(m)
|
||||
|
||||
plt.imshow(res.sum(axis=2))
|
||||
plt.show()
|
198
jaxpm/distributed.py
Normal file
198
jaxpm/distributed.py
Normal file
|
@ -0,0 +1,198 @@
|
|||
from typing import Any, Callable, Hashable
|
||||
|
||||
Specs = Any
|
||||
AxisName = Hashable
|
||||
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jaxdecomp
|
||||
from jax import lax
|
||||
from jax.experimental.shard_map import shard_map
|
||||
from jax.sharding import AbstractMesh, Mesh
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
|
||||
def autoshmap(
|
||||
f: Callable,
|
||||
gpu_mesh: Mesh | AbstractMesh | None,
|
||||
in_specs: Specs,
|
||||
out_specs: Specs,
|
||||
check_rep: bool = False,
|
||||
auto: frozenset[AxisName] = frozenset()) -> Callable:
|
||||
"""Helper function to wrap the provided function in a shard map if
|
||||
the code is being executed in a mesh context."""
|
||||
if gpu_mesh is None or gpu_mesh.empty:
|
||||
return f
|
||||
else:
|
||||
return shard_map(f, gpu_mesh, in_specs, out_specs, check_rep, auto)
|
||||
|
||||
|
||||
def fft3d(x):
|
||||
return jaxdecomp.pfft3d(x)
|
||||
|
||||
|
||||
def ifft3d(x):
|
||||
return jaxdecomp.pifft3d(x).real
|
||||
|
||||
|
||||
def get_halo_size(halo_size, sharding):
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
if gpu_mesh is None or gpu_mesh.empty:
|
||||
zero_ext = (0, 0)
|
||||
zero_tuple = (0, 0)
|
||||
return (zero_tuple, zero_tuple, zero_tuple), zero_ext
|
||||
else:
|
||||
pdims = gpu_mesh.devices.shape
|
||||
halo_x = (0, 0) if pdims[0] == 1 else (halo_size, halo_size)
|
||||
halo_y = (0, 0) if pdims[1] == 1 else (halo_size, halo_size)
|
||||
|
||||
halo_x_ext = 0 if pdims[0] == 1 else halo_size // 2
|
||||
halo_y_ext = 0 if pdims[1] == 1 else halo_size // 2
|
||||
return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext))
|
||||
|
||||
|
||||
def halo_exchange(x, halo_extents, halo_periods=(True, True)):
|
||||
if (halo_extents[0] > 0 or halo_extents[1] > 0):
|
||||
return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
def slice_unpad_impl(x, pad_width):
|
||||
|
||||
halo_x, _ = pad_width[0]
|
||||
halo_y, _ = pad_width[1]
|
||||
# Apply corrections along x
|
||||
x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2])
|
||||
x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:])
|
||||
# Apply corrections along y
|
||||
x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
|
||||
x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])
|
||||
|
||||
unpad_slice = [slice(None)] * 3
|
||||
if halo_x > 0:
|
||||
unpad_slice[0] = slice(halo_x, -halo_x)
|
||||
if halo_y > 0:
|
||||
unpad_slice[1] = slice(halo_y, -halo_y)
|
||||
|
||||
return x[tuple(unpad_slice)]
|
||||
|
||||
|
||||
def slice_pad(x, pad_width, sharding):
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
if gpu_mesh is not None and not (gpu_mesh.empty) and (
|
||||
pad_width[0][0] > 0 or pad_width[1][0] > 0):
|
||||
assert sharding is not None
|
||||
spec = sharding.spec
|
||||
return shard_map((partial(jnp.pad, pad_width=pad_width)),
|
||||
mesh=gpu_mesh,
|
||||
in_specs=spec,
|
||||
out_specs=spec)(x)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
def slice_unpad(x, pad_width, sharding):
|
||||
mesh = sharding.mesh if sharding is not None else None
|
||||
if mesh is not None and not (mesh.empty) and (pad_width[0][0] > 0
|
||||
or pad_width[1][0] > 0):
|
||||
assert sharding is not None
|
||||
spec = sharding.spec
|
||||
return shard_map(partial(slice_unpad_impl, pad_width=pad_width),
|
||||
mesh=mesh,
|
||||
in_specs=spec,
|
||||
out_specs=spec)(x)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
def get_local_shape(mesh_shape, sharding=None):
|
||||
""" Helper function to get the local size of a mesh given the global size.
|
||||
"""
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
if gpu_mesh is None or gpu_mesh.empty:
|
||||
return mesh_shape
|
||||
else:
|
||||
pdims = gpu_mesh.devices.shape
|
||||
return [
|
||||
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1],
|
||||
*mesh_shape[2:]
|
||||
]
|
||||
|
||||
|
||||
def _axis_names(spec):
|
||||
if len(spec) == 1:
|
||||
x_axis, = spec
|
||||
y_axis = None
|
||||
single_axis = True
|
||||
elif len(spec) == 2:
|
||||
x_axis, y_axis = spec
|
||||
if y_axis == None:
|
||||
single_axis = True
|
||||
elif x_axis == None:
|
||||
x_axis = y_axis
|
||||
single_axis = True
|
||||
else:
|
||||
single_axis = False
|
||||
else:
|
||||
raise ValueError("Only 1 or 2 axis sharding is supported")
|
||||
return x_axis, y_axis, single_axis
|
||||
|
||||
|
||||
def uniform_particles(mesh_shape, sharding=None):
|
||||
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
if gpu_mesh is not None and not (gpu_mesh.empty):
|
||||
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
||||
spec = sharding.spec
|
||||
x_axis, y_axis, single_axis = _axis_names(spec)
|
||||
|
||||
def particles():
|
||||
x_indx = lax.axis_index(x_axis)
|
||||
y_indx = 0 if single_axis else lax.axis_index(y_axis)
|
||||
|
||||
x = jnp.arange(local_mesh_shape[0]) + x_indx * local_mesh_shape[0]
|
||||
y = jnp.arange(local_mesh_shape[1]) + y_indx * local_mesh_shape[1]
|
||||
z = jnp.arange(local_mesh_shape[2])
|
||||
return jnp.stack(jnp.meshgrid(x, y, z, indexing='ij'), axis=-1)
|
||||
|
||||
return shard_map(particles, mesh=gpu_mesh, in_specs=(),
|
||||
out_specs=spec)()
|
||||
else:
|
||||
return jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape],
|
||||
indexing='ij'),
|
||||
axis=-1)
|
||||
|
||||
|
||||
def normal_field(mesh_shape, seed, sharding=None):
|
||||
"""Generate a Gaussian random field with the given power spectrum."""
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
if gpu_mesh is not None and not (gpu_mesh.empty):
|
||||
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
||||
|
||||
size = jax.device_count()
|
||||
# rank = jax.process_index()
|
||||
# process_index is multi_host only
|
||||
# to make the code work both in multi host and single controller we can do this trick
|
||||
keys = jax.random.split(seed, size)
|
||||
spec = sharding.spec
|
||||
x_axis, y_axis, single_axis = _axis_names(spec)
|
||||
|
||||
def normal(keys, shape, dtype):
|
||||
idx = lax.axis_index(x_axis)
|
||||
if not single_axis:
|
||||
y_index = lax.axis_index(y_axis)
|
||||
x_size = lax.psum(1, axis_name=x_axis)
|
||||
idx += y_index * x_size
|
||||
|
||||
return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)
|
||||
|
||||
return shard_map(
|
||||
partial(normal, shape=local_mesh_shape, dtype='float32'),
|
||||
mesh=gpu_mesh,
|
||||
in_specs=P(None),
|
||||
out_specs=spec)(keys) # yapf: disable
|
||||
else:
|
||||
return jax.random.normal(shape=mesh_shape, key=seed)
|
|
@ -1,6 +1,6 @@
|
|||
import jax.numpy as np
|
||||
from jax.numpy import interp
|
||||
from jax_cosmo.background import *
|
||||
from jax_cosmo.scipy.interpolate import interp
|
||||
from jax_cosmo.scipy.ode import odeint
|
||||
|
||||
|
||||
|
@ -587,5 +587,6 @@ def dGf2a(cosmo, a):
|
|||
cache = cosmo._workspace['background.growth_factor']
|
||||
f2p = cache['h2'] / cache['a'] * cache['g2']
|
||||
f2p = interp(np.log(a), np.log(cache['a']), f2p)
|
||||
E = E(cosmo, a)
|
||||
return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f)
|
||||
E_a = E(cosmo, a)
|
||||
return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) +
|
||||
3 * a**2 * E_a * D2f)
|
||||
|
|
|
@ -1,30 +1,46 @@
|
|||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax.lax import FftType
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jaxdecomp import fftfreq3d, get_output_specs
|
||||
|
||||
from jaxpm.distributed import autoshmap
|
||||
|
||||
|
||||
def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
|
||||
def fftk(k_array):
|
||||
"""
|
||||
Return wave-vectors for a given shape
|
||||
"""
|
||||
k = []
|
||||
for d in range(len(shape)):
|
||||
kd = np.fft.fftfreq(shape[d])
|
||||
kd *= 2 * np.pi
|
||||
kdshape = np.ones(len(shape), dtype='int')
|
||||
if symmetric and d == len(shape) - 1:
|
||||
kd = kd[:shape[d] // 2 + 1]
|
||||
kdshape[d] = len(kd)
|
||||
kd = kd.reshape(kdshape)
|
||||
Generate Fourier transform wave numbers for a given mesh.
|
||||
|
||||
k.append(kd.astype(dtype))
|
||||
del kd, kdshape
|
||||
return k
|
||||
Args:
|
||||
nc (int): Shape of the mesh grid.
|
||||
|
||||
Returns:
|
||||
list: List of wave number arrays for each dimension in
|
||||
the order [kx, ky, kz].
|
||||
"""
|
||||
kx, ky, kz = fftfreq3d(k_array)
|
||||
# to the order of dimensions in the transposed FFT
|
||||
return kx, ky, kz
|
||||
|
||||
|
||||
def interpolate_power_spectrum(input, k, pk, sharding=None):
|
||||
|
||||
pk_fn = lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape)
|
||||
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
specs = sharding.spec if sharding is not None else P()
|
||||
out_specs = P(*get_output_specs(
|
||||
FftType.FFT, specs, mesh=gpu_mesh)) if gpu_mesh is not None else P()
|
||||
|
||||
return autoshmap(pk_fn,
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=out_specs,
|
||||
out_specs=out_specs)(input)
|
||||
|
||||
|
||||
def gradient_kernel(kvec, direction, order=1):
|
||||
"""
|
||||
Computes the gradient kernel in the requested direction
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
kvec: list
|
||||
|
@ -50,23 +66,30 @@ def gradient_kernel(kvec, direction, order=1):
|
|||
return wts
|
||||
|
||||
|
||||
def invlaplace_kernel(kvec):
|
||||
def invlaplace_kernel(kvec, fd=False):
|
||||
"""
|
||||
Compute the inverse Laplace kernel
|
||||
Compute the inverse Laplace kernel.
|
||||
|
||||
cf. [Feng+2016](https://arxiv.org/pdf/1603.00476)
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
kvec: list
|
||||
List of wave-vectors
|
||||
fd: bool
|
||||
Finite difference kernel
|
||||
|
||||
Returns
|
||||
--------
|
||||
wts: array
|
||||
Complex kernel values
|
||||
"""
|
||||
if fd:
|
||||
kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
|
||||
else:
|
||||
kk = sum(ki**2 for ki in kvec)
|
||||
kk_nozeros = jnp.where(kk==0, 1, kk)
|
||||
return - jnp.where(kk==0, 0, 1 / kk_nozeros)
|
||||
kk_nozeros = jnp.where(kk == 0, 1, kk)
|
||||
return -jnp.where(kk == 0, 0, 1 / kk_nozeros)
|
||||
|
||||
|
||||
def longrange_kernel(kvec, r_split):
|
||||
|
@ -79,12 +102,10 @@ def longrange_kernel(kvec, r_split):
|
|||
List of wave-vectors
|
||||
r_split: float
|
||||
Splitting radius
|
||||
|
||||
Returns
|
||||
--------
|
||||
wts: array
|
||||
Complex kernel values
|
||||
|
||||
TODO: @modichirag add documentation
|
||||
"""
|
||||
if r_split != 0:
|
||||
|
@ -105,13 +126,12 @@ def cic_compensation(kvec):
|
|||
-----------
|
||||
kvec: list
|
||||
List of wave-vectors
|
||||
|
||||
Returns:
|
||||
--------
|
||||
wts: array
|
||||
Complex kernel values
|
||||
"""
|
||||
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
|
||||
kwts = [jnp.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
|
||||
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
|
||||
return wts
|
||||
|
||||
|
|
|
@ -1,15 +1,24 @@
|
|||
from functools import partial
|
||||
|
||||
import jax
|
||||
import jax.lax as lax
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jaxpm.distributed import (autoshmap, fft3d, get_halo_size, halo_exchange,
|
||||
ifft3d, slice_pad, slice_unpad)
|
||||
from jaxpm.kernels import cic_compensation, fftk
|
||||
from jaxpm.painting_utils import gather, scatter
|
||||
|
||||
|
||||
def cic_paint(mesh, positions, weight=None):
|
||||
def _cic_paint_impl(grid_mesh, positions, weight=None):
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
positions: [npart, 3]
|
||||
displacement field: [nx, ny, nz, 3]
|
||||
"""
|
||||
|
||||
positions = positions.reshape([-1, 3])
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||
|
@ -19,40 +28,98 @@ def cic_paint(mesh, positions, weight=None):
|
|||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
if weight is not None:
|
||||
if jnp.isscalar(weight):
|
||||
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
|
||||
else:
|
||||
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
|
||||
kernel)
|
||||
|
||||
neighboor_coords = jnp.mod(
|
||||
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
||||
jnp.array(mesh.shape))
|
||||
jnp.array(grid_mesh.shape))
|
||||
|
||||
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
|
||||
inserted_window_dims=(0, 1, 2),
|
||||
scatter_dims_to_operand_dims=(0, 1,
|
||||
2))
|
||||
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 8]),
|
||||
dnums)
|
||||
mesh = lax.scatter_add(grid_mesh, neighboor_coords,
|
||||
kernel.reshape([-1, 8]), dnums)
|
||||
return mesh
|
||||
|
||||
|
||||
def cic_read(mesh, positions):
|
||||
@partial(jax.jit, static_argnums=(3, 4))
|
||||
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
|
||||
|
||||
positions = positions.reshape((*grid_mesh.shape, 3))
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size, sharding)
|
||||
grid_mesh = slice_pad(grid_mesh, halo_size, sharding)
|
||||
|
||||
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
||||
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
||||
grid_mesh = autoshmap(_cic_paint_impl,
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(spec, spec, P()),
|
||||
out_specs=spec)(grid_mesh, positions, weight)
|
||||
grid_mesh = halo_exchange(grid_mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True))
|
||||
grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)
|
||||
|
||||
return grid_mesh
|
||||
|
||||
|
||||
def _cic_read_impl(grid_mesh, positions):
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
positions: [npart, 3]
|
||||
positions: [nx,ny,nz, 3]
|
||||
"""
|
||||
# Save original shape for reshaping output later
|
||||
original_shape = positions.shape
|
||||
# Reshape positions to a flat list of 3D coordinates
|
||||
positions = positions.reshape([-1, 3])
|
||||
# Expand dimensions to calculate neighbor coordinates
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
# Floor the positions to get the base grid cell for each particle
|
||||
floor = jnp.floor(positions)
|
||||
# Define connections to calculate all neighbor coordinates
|
||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
||||
|
||||
# Calculate the 8 neighboring coordinates
|
||||
neighboor_coords = floor + connection
|
||||
# Calculate kernel weights based on distance from each neighboring coordinate
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
|
||||
# Modulo operation to wrap around edges if necessary
|
||||
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
|
||||
jnp.array(mesh.shape))
|
||||
jnp.array(grid_mesh.shape))
|
||||
# Ensure grid_mesh shape is as expected
|
||||
# Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel
|
||||
return (grid_mesh[neighboor_coords[..., 0],
|
||||
neighboor_coords[..., 1],
|
||||
neighboor_coords[..., 2]] * kernel).sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable
|
||||
|
||||
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
|
||||
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
|
||||
|
||||
@partial(jax.jit, static_argnums=(2, 3))
|
||||
def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
|
||||
|
||||
original_shape = positions.shape
|
||||
positions = positions.reshape((*grid_mesh.shape, 3))
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
|
||||
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
|
||||
grid_mesh = halo_exchange(grid_mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True))
|
||||
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
||||
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
||||
|
||||
displacement = autoshmap(_cic_read_impl,
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(spec, spec),
|
||||
out_specs=spec)(grid_mesh, positions)
|
||||
|
||||
return displacement.reshape(original_shape[:-1])
|
||||
|
||||
|
||||
def cic_paint_2d(mesh, positions, weight):
|
||||
|
@ -84,6 +151,99 @@ def cic_paint_2d(mesh, positions, weight):
|
|||
return mesh
|
||||
|
||||
|
||||
def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
|
||||
|
||||
halo_x, _ = halo_size[0]
|
||||
halo_y, _ = halo_size[1]
|
||||
|
||||
original_shape = displacements.shape
|
||||
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
|
||||
if not jnp.isscalar(weight):
|
||||
if weight.shape != original_shape[:-1]:
|
||||
raise ValueError("Weight shape must match particle shape")
|
||||
else:
|
||||
weight = weight.flatten()
|
||||
# Padding is forced to be zero in a single gpu run
|
||||
|
||||
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
|
||||
jnp.arange(particle_mesh.shape[1]),
|
||||
jnp.arange(particle_mesh.shape[2]),
|
||||
indexing='ij')
|
||||
|
||||
particle_mesh = jnp.pad(particle_mesh, halo_size)
|
||||
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
|
||||
return scatter(pmid.reshape([-1, 3]),
|
||||
displacements.reshape([-1, 3]),
|
||||
particle_mesh,
|
||||
chunk_size=2**24,
|
||||
val=weight)
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnums=(1, 2, 4))
|
||||
def cic_paint_dx(displacements,
|
||||
halo_size=0,
|
||||
sharding=None,
|
||||
weight=1.0,
|
||||
chunk_size=2**24):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
|
||||
|
||||
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
||||
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
||||
grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
|
||||
halo_size=halo_size,
|
||||
weight=weight,
|
||||
chunk_size=chunk_size),
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=spec,
|
||||
out_specs=spec)(displacements)
|
||||
|
||||
grid_mesh = halo_exchange(grid_mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True))
|
||||
grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)
|
||||
return grid_mesh
|
||||
|
||||
|
||||
def _cic_read_dx_impl(grid_mesh, disp, halo_size):
|
||||
|
||||
halo_x, _ = halo_size[0]
|
||||
halo_y, _ = halo_size[1]
|
||||
|
||||
original_shape = [
|
||||
dim - 2 * halo[0] for dim, halo in zip(grid_mesh.shape, halo_size)
|
||||
]
|
||||
a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]),
|
||||
jnp.arange(original_shape[1]),
|
||||
jnp.arange(original_shape[2]),
|
||||
indexing='ij')
|
||||
|
||||
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
|
||||
|
||||
pmid = pmid.reshape([-1, 3])
|
||||
disp = disp.reshape([-1, 3])
|
||||
|
||||
return gather(pmid, disp, grid_mesh).reshape(original_shape)
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnums=(2, 3))
|
||||
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
|
||||
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
|
||||
grid_mesh = halo_exchange(grid_mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True))
|
||||
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
||||
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
||||
displacements = autoshmap(partial(_cic_read_dx_impl, halo_size=halo_size),
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(spec),
|
||||
out_specs=spec)(grid_mesh, disp)
|
||||
|
||||
return displacements
|
||||
|
||||
|
||||
def compensate_cic(field):
|
||||
"""
|
||||
Compensate for CiC painting
|
||||
|
@ -92,9 +252,8 @@ def compensate_cic(field):
|
|||
Returns:
|
||||
compensated_field
|
||||
"""
|
||||
nc = field.shape
|
||||
kvec = fftk(nc)
|
||||
delta_k = fft3d(field)
|
||||
|
||||
delta_k = jnp.fft.rfftn(field)
|
||||
kvec = fftk(delta_k)
|
||||
delta_k = cic_compensation(kvec) * delta_k
|
||||
return jnp.fft.irfftn(delta_k)
|
||||
return ifft3d(delta_k)
|
||||
|
|
190
jaxpm/painting_utils.py
Normal file
190
jaxpm/painting_utils.py
Normal file
|
@ -0,0 +1,190 @@
|
|||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.lax import scan
|
||||
|
||||
|
||||
def _chunk_split(ptcl_num, chunk_size, *arrays):
|
||||
"""Split and reshape particle arrays into chunks and remainders, with the remainders
|
||||
preceding the chunks. 0D ones are duplicated as full arrays in the chunks."""
|
||||
chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
|
||||
remainder_size = ptcl_num % chunk_size
|
||||
chunk_num = ptcl_num // chunk_size
|
||||
|
||||
remainder = None
|
||||
chunks = arrays
|
||||
if remainder_size:
|
||||
remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays]
|
||||
chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays]
|
||||
|
||||
# `scan` triggers errors in scatter and gather without the `full`
|
||||
chunks = [
|
||||
x.reshape(chunk_num, chunk_size, *x.shape[1:])
|
||||
if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks
|
||||
]
|
||||
|
||||
return remainder, chunks
|
||||
|
||||
|
||||
def enmesh(base_indices, displacements, cell_size, base_shape, offset,
|
||||
new_cell_size, new_shape):
|
||||
"""Multilinear enmeshing."""
|
||||
base_indices = jnp.asarray(base_indices)
|
||||
displacements = jnp.asarray(displacements)
|
||||
with jax.experimental.enable_x64():
|
||||
cell_size = jnp.float64(
|
||||
cell_size) if new_cell_size is not None else jnp.array(
|
||||
cell_size, dtype=displacements.dtype)
|
||||
if base_shape is not None:
|
||||
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
|
||||
offset = jnp.float64(offset)
|
||||
if new_cell_size is not None:
|
||||
new_cell_size = jnp.float64(new_cell_size)
|
||||
if new_shape is not None:
|
||||
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
|
||||
|
||||
spatial_dim = base_indices.shape[1]
|
||||
neighbor_offsets = (
|
||||
jnp.arange(2**spatial_dim, dtype=base_indices.dtype)[:, jnp.newaxis] >>
|
||||
jnp.arange(spatial_dim, dtype=base_indices.dtype)) & 1
|
||||
|
||||
if new_cell_size is not None:
|
||||
particle_positions = base_indices * cell_size + displacements - offset
|
||||
particle_positions = particle_positions[:, jnp.
|
||||
newaxis] # insert neighbor axis
|
||||
new_indices = particle_positions + neighbor_offsets * new_cell_size # multilinear
|
||||
|
||||
if base_shape is not None:
|
||||
grid_length = base_shape * cell_size
|
||||
new_indices %= grid_length
|
||||
|
||||
new_indices //= new_cell_size
|
||||
new_displacements = particle_positions - new_indices * new_cell_size
|
||||
|
||||
if base_shape is not None:
|
||||
new_displacements -= jnp.rint(
|
||||
new_displacements / grid_length
|
||||
) * grid_length # also abs(new_displacements) < new_cell_size is expected
|
||||
|
||||
new_indices = new_indices.astype(base_indices.dtype)
|
||||
new_displacements = new_displacements.astype(displacements.dtype)
|
||||
new_cell_size = new_cell_size.astype(displacements.dtype)
|
||||
|
||||
new_displacements /= new_cell_size
|
||||
else:
|
||||
offset_indices, offset_displacements = jnp.divmod(offset, cell_size)
|
||||
base_indices -= offset_indices.astype(base_indices.dtype)
|
||||
displacements -= offset_displacements.astype(displacements.dtype)
|
||||
|
||||
# insert neighbor axis
|
||||
base_indices = base_indices[:, jnp.newaxis]
|
||||
displacements = displacements[:, jnp.newaxis]
|
||||
|
||||
# multilinear
|
||||
displacements /= cell_size
|
||||
new_indices = jnp.floor(displacements).astype(base_indices.dtype)
|
||||
new_indices += neighbor_offsets
|
||||
new_displacements = displacements - new_indices
|
||||
new_indices += base_indices
|
||||
|
||||
if base_shape is not None:
|
||||
new_indices %= base_shape
|
||||
|
||||
weights = 1 - jnp.abs(new_displacements)
|
||||
|
||||
if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None
|
||||
new_indices = jnp.where(new_indices < 0, new_shape, new_indices)
|
||||
|
||||
weights = weights.prod(axis=-1)
|
||||
|
||||
return new_indices, weights
|
||||
|
||||
|
||||
def _scatter_chunk(carry, chunk):
|
||||
mesh, offset, cell_size, mesh_shape = carry
|
||||
pmid, disp, val = chunk
|
||||
spatial_ndim = pmid.shape[1]
|
||||
spatial_shape = mesh.shape
|
||||
|
||||
# multilinear mesh indices and fractions
|
||||
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
|
||||
spatial_shape)
|
||||
# scatter
|
||||
ind = tuple(ind[..., i] for i in range(spatial_ndim))
|
||||
mesh = mesh.at[ind].add(jnp.multiply(jnp.expand_dims(val, axis=-1), frac))
|
||||
carry = mesh, offset, cell_size, mesh_shape
|
||||
return carry, None
|
||||
|
||||
|
||||
def scatter(pmid,
|
||||
disp,
|
||||
mesh,
|
||||
chunk_size=2**24,
|
||||
val=1.,
|
||||
offset=0,
|
||||
cell_size=1.):
|
||||
ptcl_num, spatial_ndim = pmid.shape
|
||||
val = jnp.asarray(val)
|
||||
mesh = jnp.asarray(mesh)
|
||||
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
|
||||
carry = mesh, offset, cell_size, mesh.shape
|
||||
if remainder is not None:
|
||||
carry = _scatter_chunk(carry, remainder)[0]
|
||||
carry = scan(_scatter_chunk, carry, chunks)[0]
|
||||
mesh = carry[0]
|
||||
return mesh
|
||||
|
||||
|
||||
def _chunk_cat(remainder_array, chunked_array):
|
||||
"""Reshape and concatenate one remainder and one chunked particle arrays."""
|
||||
array = chunked_array.reshape(-1, *chunked_array.shape[2:])
|
||||
|
||||
if remainder_array is not None:
|
||||
array = jnp.concatenate((remainder_array, array), axis=0)
|
||||
|
||||
return array
|
||||
|
||||
|
||||
def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
|
||||
ptcl_num, spatial_ndim = pmid.shape
|
||||
|
||||
mesh = jnp.asarray(mesh)
|
||||
|
||||
val = jnp.asarray(val)
|
||||
|
||||
if mesh.shape[spatial_ndim:] != val.shape[1:]:
|
||||
raise ValueError('channel shape mismatch: '
|
||||
f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}')
|
||||
|
||||
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
|
||||
|
||||
carry = mesh, offset, cell_size, mesh.shape
|
||||
val_0 = None
|
||||
if remainder is not None:
|
||||
val_0 = _gather_chunk(carry, remainder)[1]
|
||||
val = scan(_gather_chunk, carry, chunks)[1]
|
||||
|
||||
val = _chunk_cat(val_0, val)
|
||||
|
||||
return val
|
||||
|
||||
|
||||
def _gather_chunk(carry, chunk):
|
||||
mesh, offset, cell_size, mesh_shape = carry
|
||||
pmid, disp, val = chunk
|
||||
|
||||
spatial_ndim = pmid.shape[1]
|
||||
|
||||
spatial_shape = mesh.shape[:spatial_ndim]
|
||||
chan_ndim = mesh.ndim - spatial_ndim
|
||||
chan_axis = tuple(range(-chan_ndim, 0))
|
||||
|
||||
# multilinear mesh indices and fractions
|
||||
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
|
||||
spatial_shape)
|
||||
|
||||
# gather
|
||||
ind = tuple(ind[..., i] for i in range(spatial_ndim))
|
||||
frac = jnp.expand_dims(frac, chan_axis)
|
||||
val += (mesh.at[ind].get(mode='drop', fill_value=0) * frac).sum(axis=1)
|
||||
|
||||
return carry, val
|
129
jaxpm/plotting.py
Normal file
129
jaxpm/plotting.py
Normal file
|
@ -0,0 +1,129 @@
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
|
||||
def plot_fields(fields_dict, sum_over=None):
|
||||
"""
|
||||
Plots sum projections of 3D fields along different axes,
|
||||
slicing only the first `sum_over` elements along each axis.
|
||||
|
||||
Args:
|
||||
- fields: list of 3D arrays representing fields to plot
|
||||
- names: list of names for each field, used in titles
|
||||
- sum_over: number of slices to sum along each axis (default: fields[0].shape[0] // 8)
|
||||
"""
|
||||
sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
|
||||
nb_rows = len(fields_dict)
|
||||
nb_cols = 3
|
||||
fig, axes = plt.subplots(nb_rows, nb_cols, figsize=(15, 5 * nb_rows))
|
||||
|
||||
def plot_subplots(proj_axis, field, row, title):
|
||||
slicing = [slice(None)] * field.ndim
|
||||
slicing[proj_axis] = slice(None, sum_over)
|
||||
slicing = tuple(slicing)
|
||||
|
||||
# Sum projection over the specified axis and plot
|
||||
axes[row, proj_axis].imshow(
|
||||
field[slicing].sum(axis=proj_axis) + 1,
|
||||
cmap='magma',
|
||||
extent=[0, field.shape[proj_axis], 0, field.shape[proj_axis]])
|
||||
axes[row, proj_axis].set_xlabel('Mpc/h')
|
||||
axes[row, proj_axis].set_ylabel('Mpc/h')
|
||||
axes[row, proj_axis].set_title(title)
|
||||
|
||||
# Plot each field across the three axes
|
||||
for i, (name, field) in enumerate(fields_dict.items()):
|
||||
for proj_axis in range(3):
|
||||
plot_subplots(proj_axis, field, i,
|
||||
f"{name} projection {proj_axis}")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_fields_single_projection(fields_dict,
|
||||
sum_over=None,
|
||||
project_axis=0,
|
||||
vmin=None,
|
||||
vmax=None,
|
||||
colorbar=False):
|
||||
"""
|
||||
Plots a single projection (along axis 0) of 3D fields in a grid,
|
||||
summing over the first `sum_over` elements along the 0-axis, with 4 images per row.
|
||||
|
||||
Args:
|
||||
- fields_dict: dictionary where keys are field names and values are 3D arrays
|
||||
- sum_over: number of slices to sum along the projection axis (default: fields[0].shape[0] // 8)
|
||||
"""
|
||||
sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
|
||||
nb_fields = len(fields_dict)
|
||||
nb_cols = 4 # Set number of images per row
|
||||
nb_rows = (nb_fields + nb_cols - 1) // nb_cols # Calculate required rows
|
||||
|
||||
fig, axes = plt.subplots(nb_rows,
|
||||
nb_cols,
|
||||
figsize=(5 * nb_cols, 5 * nb_rows))
|
||||
axes = np.atleast_2d(axes) # Ensure axes is always a 2D array
|
||||
|
||||
for i, (name, field) in enumerate(fields_dict.items()):
|
||||
row, col = divmod(i, nb_cols)
|
||||
|
||||
# Define the slice for the 0-axis projection
|
||||
slicing = [slice(None)] * field.ndim
|
||||
slicing[project_axis] = slice(None, sum_over)
|
||||
slicing = tuple(slicing)
|
||||
|
||||
# Sum projection over axis 0 and plot
|
||||
a = axes[row,
|
||||
col].imshow(field[slicing].sum(axis=project_axis) + 1,
|
||||
cmap='magma',
|
||||
extent=[0, field.shape[1], 0, field.shape[2]],
|
||||
vmin=vmin,
|
||||
vmax=vmax)
|
||||
axes[row, col].set_xlabel('Mpc/h')
|
||||
axes[row, col].set_ylabel('Mpc/h')
|
||||
axes[row, col].set_title(f"{name} projection 0")
|
||||
if colorbar:
|
||||
fig.colorbar(a, ax=axes[row, col], shrink=0.7)
|
||||
|
||||
# Remove any empty subplots
|
||||
for j in range(i + 1, nb_rows * nb_cols):
|
||||
fig.delaxes(axes.flatten()[j])
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
def stack_slices(array):
|
||||
"""
|
||||
Stacks 2D slices of an array into a single array based on provided partition dimensions.
|
||||
|
||||
Args:
|
||||
- array_slices: a 2D list of array slices (list of lists format) where
|
||||
array_slices[i][j] is the slice located at row i, column j in the grid.
|
||||
- pdims: a tuple representing the grid dimensions (rows, columns).
|
||||
|
||||
Returns:
|
||||
- A single array constructed by stacking the slices.
|
||||
"""
|
||||
# Initialize an empty list to store the vertically stacked rows
|
||||
pdims = array.sharding.mesh.devices.shape
|
||||
|
||||
field_slices = []
|
||||
|
||||
# Iterate over rows in pdims[0]
|
||||
for i in range(pdims[0]):
|
||||
row_slices = []
|
||||
|
||||
# Iterate over columns in pdims[1]
|
||||
for j in range(pdims[1]):
|
||||
slice_index = i * pdims[0] + j
|
||||
row_slices.append(array.addressable_data(slice_index))
|
||||
# Stack the current row of slices vertically
|
||||
stacked_row = np.hstack(row_slices)
|
||||
field_slices.append(stacked_row)
|
||||
|
||||
# Stack all rows horizontally to form the full array
|
||||
full_array = np.vstack(field_slices)
|
||||
|
||||
return full_array
|
186
jaxpm/pm.py
186
jaxpm/pm.py
|
@ -1,50 +1,92 @@
|
|||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
from jax_cosmo import Cosmology
|
||||
|
||||
from jaxpm.growth import growth_factor, growth_rate, dGfa, growth_factor_second, growth_rate_second, dGf2a
|
||||
from jaxpm.kernels import PGD_kernel, fftk, gradient_kernel, invlaplace_kernel, longrange_kernel
|
||||
from jaxpm.painting import cic_paint, cic_read
|
||||
from jaxpm.distributed import fft3d, ifft3d, normal_field
|
||||
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
|
||||
growth_rate, growth_rate_second)
|
||||
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
|
||||
invlaplace_kernel, longrange_kernel)
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
|
||||
|
||||
|
||||
|
||||
def pm_forces(positions, mesh_shape, delta=None, r_split=0):
|
||||
def pm_forces(positions,
|
||||
mesh_shape=None,
|
||||
delta=None,
|
||||
r_split=0,
|
||||
paint_absolute_pos=True,
|
||||
halo_size=0,
|
||||
sharding=None):
|
||||
"""
|
||||
Computes gravitational forces on particles using a PM scheme
|
||||
"""
|
||||
if mesh_shape is None:
|
||||
assert (delta is not None),\
|
||||
"If mesh_shape is not provided, delta should be provided"
|
||||
mesh_shape = delta.shape
|
||||
|
||||
if paint_absolute_pos:
|
||||
paint_fn = lambda pos: cic_paint(jnp.zeros(shape=mesh_shape,
|
||||
device=sharding),
|
||||
pos,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
read_fn = lambda grid_mesh, pos: cic_read(
|
||||
grid_mesh, pos, halo_size=halo_size, sharding=sharding)
|
||||
else:
|
||||
paint_fn = lambda disp: cic_paint_dx(
|
||||
disp, halo_size=halo_size, sharding=sharding)
|
||||
read_fn = lambda grid_mesh, disp: cic_read_dx(
|
||||
grid_mesh, disp, halo_size=halo_size, sharding=sharding)
|
||||
|
||||
if delta is None:
|
||||
delta_k = jnp.fft.rfftn(cic_paint(jnp.zeros(mesh_shape), positions))
|
||||
field = paint_fn(positions)
|
||||
delta_k = fft3d(field)
|
||||
elif jnp.isrealobj(delta):
|
||||
delta_k = jnp.fft.rfftn(delta)
|
||||
delta_k = fft3d(delta)
|
||||
else:
|
||||
delta_k = delta
|
||||
|
||||
kvec = fftk(delta_k)
|
||||
# Computes gravitational potential
|
||||
kvec = fftk(mesh_shape)
|
||||
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split)
|
||||
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(
|
||||
kvec, r_split=r_split)
|
||||
# Computes gravitational forces
|
||||
return jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i) * pot_k), positions)
|
||||
for i in range(3)], axis=-1)
|
||||
forces = jnp.stack([
|
||||
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions
|
||||
) for i in range(3)], axis=-1) # yapf: disable
|
||||
|
||||
return forces
|
||||
|
||||
|
||||
def lpt(cosmo:Cosmology, init_mesh, positions, a, order=1):
|
||||
def lpt(cosmo,
|
||||
initial_conditions,
|
||||
particles=None,
|
||||
a=0.1,
|
||||
halo_size=0,
|
||||
sharding=None,
|
||||
order=1):
|
||||
"""
|
||||
Computes first and second order LPT displacement and momentum,
|
||||
e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
|
||||
"""
|
||||
paint_absolute_pos = particles is not None
|
||||
if particles is None:
|
||||
particles = jnp.zeros_like(initial_conditions,
|
||||
shape=(*initial_conditions.shape, 3))
|
||||
|
||||
a = jnp.atleast_1d(a)
|
||||
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
||||
delta_k = jnp.fft.rfftn(init_mesh) # TODO: pass the modes directly to save one or two fft?
|
||||
mesh_shape = init_mesh.shape
|
||||
|
||||
init_force = pm_forces(positions, mesh_shape, delta=delta_k)
|
||||
dx = growth_factor(cosmo, a) * init_force
|
||||
delta_k = fft3d(initial_conditions)
|
||||
initial_force = pm_forces(particles,
|
||||
delta=delta_k,
|
||||
paint_absolute_pos=paint_absolute_pos,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
dx = growth_factor(cosmo, a) * initial_force
|
||||
p = a**2 * growth_rate(cosmo, a) * E * dx
|
||||
f = a**2 * E * dGfa(cosmo, a) * init_force
|
||||
|
||||
f = a**2 * E * dGfa(cosmo, a) * initial_force
|
||||
if order == 2:
|
||||
kvec = fftk(mesh_shape)
|
||||
kvec = fftk(delta_k)
|
||||
pot_k = delta_k * invlaplace_kernel(kvec)
|
||||
|
||||
delta2 = 0
|
||||
|
@ -54,20 +96,26 @@ def lpt(cosmo:Cosmology, init_mesh, positions, a, order=1):
|
|||
# Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
|
||||
# shear_ii = jnp.fft.irfftn(- ki**2 * pot_k)
|
||||
nabla_i_nabla_i = gradient_kernel(kvec, i)**2
|
||||
shear_ii = jnp.fft.irfftn(nabla_i_nabla_i * pot_k)
|
||||
shear_ii = ifft3d(nabla_i_nabla_i * pot_k)
|
||||
delta2 += shear_ii * shear_acc
|
||||
shear_acc += shear_ii
|
||||
|
||||
# for kj in kvec[i+1:]:
|
||||
for j in range(i+1, 3):
|
||||
for j in range(i + 1, 3):
|
||||
# Substract squared strict-up-triangle terms
|
||||
# delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
|
||||
nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(kvec, j)
|
||||
delta2 -= jnp.fft.irfftn(nabla_i_nabla_j * pot_k)**2
|
||||
nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(
|
||||
kvec, j)
|
||||
delta2 -= ifft3d(nabla_i_nabla_j * pot_k)**2
|
||||
|
||||
init_force2 = pm_forces(positions, mesh_shape, delta=jnp.fft.rfftn(delta2))
|
||||
delta_k2 = fft3d(delta2)
|
||||
init_force2 = pm_forces(particles,
|
||||
delta=delta_k2,
|
||||
paint_absolute_pos=paint_absolute_pos,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
|
||||
dx2 = 3/7 * growth_factor_second(cosmo, a) * init_force2
|
||||
dx2 = 3 / 7 * growth_factor_second(cosmo, a) * init_force2
|
||||
p2 = a**2 * growth_rate_second(cosmo, a) * E * dx2
|
||||
f2 = a**2 * E * dGf2a(cosmo, a) * init_force2
|
||||
|
||||
|
@ -78,23 +126,28 @@ def lpt(cosmo:Cosmology, init_mesh, positions, a, order=1):
|
|||
return dx, p, f
|
||||
|
||||
|
||||
def linear_field(mesh_shape, box_size, pk, seed):
|
||||
def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
|
||||
"""
|
||||
Generate initial conditions.
|
||||
"""
|
||||
kvec = fftk(mesh_shape)
|
||||
# Initialize a random field with one slice on each gpu
|
||||
field = normal_field(mesh_shape, seed=seed, sharding=sharding)
|
||||
field = fft3d(field)
|
||||
kvec = fftk(field)
|
||||
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
|
||||
for i, kk in enumerate(kvec))**0.5
|
||||
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
|
||||
box_size[0] * box_size[1] * box_size[2])
|
||||
|
||||
field = jax.random.normal(seed, mesh_shape)
|
||||
field = jnp.fft.rfftn(field) * pkmesh**0.5
|
||||
field = jnp.fft.irfftn(field)
|
||||
field = field * (pkmesh)**0.5
|
||||
field = ifft3d(field)
|
||||
return field
|
||||
|
||||
|
||||
def make_ode_fn(mesh_shape):
|
||||
def make_ode_fn(mesh_shape,
|
||||
paint_absolute_pos=True,
|
||||
halo_size=0,
|
||||
sharding=None):
|
||||
|
||||
def nbody_ode(state, a, cosmo):
|
||||
"""
|
||||
|
@ -102,7 +155,11 @@ def make_ode_fn(mesh_shape):
|
|||
"""
|
||||
pos, vel = state
|
||||
|
||||
forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m
|
||||
forces = pm_forces(pos,
|
||||
mesh_shape=mesh_shape,
|
||||
paint_absolute_pos=paint_absolute_pos,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding) * 1.5 * cosmo.Omega_m
|
||||
|
||||
# Computes the update of position (drift)
|
||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||
|
@ -114,16 +171,24 @@ def make_ode_fn(mesh_shape):
|
|||
|
||||
return nbody_ode
|
||||
|
||||
def get_ode_fn(cosmo:Cosmology, mesh_shape):
|
||||
|
||||
def make_diffrax_ode(cosmo,
|
||||
mesh_shape,
|
||||
paint_absolute_pos=True,
|
||||
halo_size=0,
|
||||
sharding=None):
|
||||
|
||||
def nbody_ode(a, state, args):
|
||||
"""
|
||||
State is an array [position, velocities]
|
||||
|
||||
Compatible with [Diffrax API](https://docs.kidger.site/diffrax/)
|
||||
state is a tuple (position, velocities)
|
||||
"""
|
||||
pos, vel = state
|
||||
forces = pm_forces(pos, mesh_shape) * 1.5 * cosmo.Omega_m
|
||||
|
||||
forces = pm_forces(pos,
|
||||
mesh_shape=mesh_shape,
|
||||
paint_absolute_pos=paint_absolute_pos,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding) * 1.5 * cosmo.Omega_m
|
||||
|
||||
# Computes the update of position (drift)
|
||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||
|
@ -145,44 +210,50 @@ def pgd_correction(pos, mesh_shape, params):
|
|||
pos: particle positions [npart, 3]
|
||||
params: [alpha, kl, ks] pgd parameters
|
||||
"""
|
||||
kvec = fftk(mesh_shape)
|
||||
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
||||
delta_k = fft3d(delta)
|
||||
kvec = fftk(delta_k)
|
||||
alpha, kl, ks = params
|
||||
delta_k = jnp.fft.rfftn(delta)
|
||||
PGD_range=PGD_kernel(kvec, kl, ks)
|
||||
PGD_range = PGD_kernel(kvec, kl, ks)
|
||||
|
||||
pot_k_pgd=(delta_k * invlaplace_kernel(kvec))*PGD_range
|
||||
pot_k_pgd = (delta_k * invlaplace_kernel(kvec)) * PGD_range
|
||||
|
||||
forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i)*pot_k_pgd), pos)
|
||||
for i in range(3)],axis=-1)
|
||||
forces_pgd = jnp.stack([
|
||||
cic_read(fft3d(-gradient_kernel(kvec, i) * pot_k_pgd), pos)
|
||||
for i in range(3)
|
||||
],
|
||||
axis=-1)
|
||||
|
||||
dpos_pgd = forces_pgd*alpha
|
||||
dpos_pgd = forces_pgd * alpha
|
||||
|
||||
return dpos_pgd
|
||||
|
||||
|
||||
def make_neural_ode_fn(model, mesh_shape):
|
||||
def neural_nbody_ode(state, a, cosmo:Cosmology, params):
|
||||
|
||||
def neural_nbody_ode(state, a, cosmo: Cosmology, params):
|
||||
"""
|
||||
state is a tuple (position, velocities)
|
||||
"""
|
||||
pos, vel = state
|
||||
kvec = fftk(mesh_shape)
|
||||
|
||||
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
||||
|
||||
delta_k = jnp.fft.rfftn(delta)
|
||||
delta_k = fft3d(delta)
|
||||
kvec = fftk(delta_k)
|
||||
|
||||
# Computes gravitational potential
|
||||
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, r_split=0)
|
||||
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec,
|
||||
r_split=0)
|
||||
|
||||
# Apply a correction filter
|
||||
kk = jnp.sqrt(sum((ki/jnp.pi)**2 for ki in kvec))
|
||||
pot_k = pot_k *(1. + model.apply(params, kk, jnp.atleast_1d(a)))
|
||||
kk = jnp.sqrt(sum((ki / jnp.pi)**2 for ki in kvec))
|
||||
pot_k = pot_k * (1. + model.apply(params, kk, jnp.atleast_1d(a)))
|
||||
|
||||
# Computes gravitational forces
|
||||
forces = jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i)*pot_k), pos)
|
||||
for i in range(3)],axis=-1)
|
||||
forces = jnp.stack([
|
||||
cic_read(fft3d(-gradient_kernel(kvec, i) * pot_k), pos)
|
||||
for i in range(3)
|
||||
],
|
||||
axis=-1)
|
||||
|
||||
forces = forces * 1.5 * cosmo.Omega_m
|
||||
|
||||
|
@ -193,4 +264,5 @@ def make_neural_ode_fn(model, mesh_shape):
|
|||
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||
|
||||
return dpos, dvel
|
||||
|
||||
return neural_nbody_ode
|
||||
|
|
197
jaxpm/utils.py
197
jaxpm/utils.py
|
@ -1,89 +1,161 @@
|
|||
from functools import partial
|
||||
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax.scipy.stats import norm
|
||||
from scipy.special import legendre
|
||||
|
||||
__all__ = ['power_spectrum']
|
||||
__all__ = [
|
||||
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
|
||||
'cross_correlation_coefficients', 'gaussian_smoothing'
|
||||
]
|
||||
|
||||
|
||||
def _initialize_pk(shape, boxsize, kmin, dk):
|
||||
def _initialize_pk(mesh_shape, box_shape, kedges, los):
|
||||
"""
|
||||
Helper function to initialize various (fixed) values for powerspectra... not differentiable!
|
||||
Parameters
|
||||
----------
|
||||
mesh_shape : tuple of int
|
||||
Shape of the mesh grid.
|
||||
box_shape : tuple of float
|
||||
Physical dimensions of the box.
|
||||
kedges : None, int, float, or list
|
||||
If None, set dk to twice the minimum.
|
||||
If int, specifies number of edges.
|
||||
If float, specifies dk.
|
||||
los : array_like
|
||||
Line-of-sight vector.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dig : ndarray
|
||||
Indices of the bins to which each value in input array belongs.
|
||||
kcount : ndarray
|
||||
Count of values in each bin.
|
||||
kedges : ndarray
|
||||
Edges of the bins.
|
||||
mumesh : ndarray
|
||||
Mu values for the mesh grid.
|
||||
"""
|
||||
I = np.eye(len(shape), dtype='int') * -2 + 1
|
||||
kmax = np.pi * np.min(mesh_shape / box_shape) # = knyquist
|
||||
|
||||
W = np.empty(shape, dtype='f4')
|
||||
W[...] = 2.0
|
||||
W[..., 0] = 1.0
|
||||
W[..., -1] = 1.0
|
||||
if isinstance(kedges, None | int | float):
|
||||
if kedges is None:
|
||||
dk = 2 * np.pi / np.min(
|
||||
box_shape) * 2 # twice the minimum wavenumber
|
||||
if isinstance(kedges, int):
|
||||
dk = kmax / (kedges + 1) # final number of bins will be kedges-1
|
||||
elif isinstance(kedges, float):
|
||||
dk = kedges
|
||||
kedges = np.arange(dk, kmax, dk) + dk / 2 # from dk/2 to kmax-dk/2
|
||||
|
||||
kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
|
||||
kedges = np.arange(kmin, kmax, dk)
|
||||
kshapes = np.eye(len(mesh_shape), dtype=np.int32) * -2 + 1
|
||||
kvec = [(2 * np.pi * m / l) * np.fft.fftfreq(m).reshape(kshape)
|
||||
for m, l, kshape in zip(mesh_shape, box_shape, kshapes)]
|
||||
kmesh = sum(ki**2 for ki in kvec)**0.5
|
||||
|
||||
k = [
|
||||
np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
|
||||
for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
|
||||
]
|
||||
kmag = sum(ki**2 for ki in k)**0.5
|
||||
dig = np.digitize(kmesh.reshape(-1), kedges)
|
||||
kcount = np.bincount(dig, minlength=len(kedges) + 1)
|
||||
|
||||
xsum = np.zeros(len(kedges) + 1)
|
||||
Nsum = np.zeros(len(kedges) + 1)
|
||||
# Central value of each bin
|
||||
# kavg = (kedges[1:] + kedges[:-1]) / 2
|
||||
kavg = np.bincount(
|
||||
dig, weights=kmesh.reshape(-1), minlength=len(kedges) + 1) / kcount
|
||||
kavg = kavg[1:-1]
|
||||
|
||||
dig = np.digitize(kmag.flat, kedges)
|
||||
if los is None:
|
||||
mumesh = 1.
|
||||
else:
|
||||
mumesh = sum(ki * losi for ki, losi in zip(kvec, los))
|
||||
kmesh_nozeros = np.where(kmesh == 0, 1, kmesh)
|
||||
mumesh = np.where(kmesh == 0, 0, mumesh / kmesh_nozeros)
|
||||
|
||||
xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
|
||||
Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
|
||||
return dig, Nsum, xsum, W, k, kedges
|
||||
return dig, kcount, kavg, mumesh
|
||||
|
||||
|
||||
def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
|
||||
def power_spectrum(mesh,
|
||||
mesh2=None,
|
||||
box_shape=None,
|
||||
kedges: int | float | list = None,
|
||||
multipoles=0,
|
||||
los=[0., 0., 1.]):
|
||||
"""
|
||||
Calculate the powerspectra given real space field
|
||||
|
||||
Args:
|
||||
|
||||
field: real valued field
|
||||
kmin: minimum k-value for binned powerspectra
|
||||
dk: differential in each kbin
|
||||
boxsize: length of each boxlength (can be strangly shaped?)
|
||||
|
||||
Returns:
|
||||
|
||||
kbins: the central value of the bins for plotting
|
||||
power: real valued array of power in each bin
|
||||
|
||||
Compute the auto and cross spectrum of 3D fields, with multipoles.
|
||||
"""
|
||||
shape = field.shape
|
||||
nx, ny, nz = shape
|
||||
# Initialize
|
||||
mesh_shape = np.array(mesh.shape)
|
||||
if box_shape is None:
|
||||
box_shape = mesh_shape
|
||||
else:
|
||||
box_shape = np.asarray(box_shape)
|
||||
|
||||
#initialze values related to powerspectra (mode bins and weights)
|
||||
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
|
||||
if multipoles == 0:
|
||||
los = None
|
||||
else:
|
||||
los = np.asarray(los)
|
||||
los = los / np.linalg.norm(los)
|
||||
poles = np.atleast_1d(multipoles)
|
||||
dig, kcount, kavg, mumesh = _initialize_pk(mesh_shape, box_shape, kedges,
|
||||
los)
|
||||
n_bins = len(kavg) + 2
|
||||
|
||||
#fast fourier transform
|
||||
fft_image = jnp.fft.fftn(field)
|
||||
# FFTs
|
||||
meshk = jnp.fft.fftn(mesh, norm='ortho')
|
||||
if mesh2 is None:
|
||||
mmk = meshk.real**2 + meshk.imag**2
|
||||
else:
|
||||
mmk = meshk * jnp.fft.fftn(mesh2, norm='ortho').conj()
|
||||
|
||||
#absolute value of fast fourier transform
|
||||
pk = jnp.real(fft_image * jnp.conj(fft_image))
|
||||
# Sum powers
|
||||
pk = jnp.empty((len(poles), n_bins))
|
||||
for i_ell, ell in enumerate(poles):
|
||||
weights = (mmk * (2 * ell + 1) * legendre(ell)(mumesh)).reshape(-1)
|
||||
if mesh2 is None:
|
||||
psum = jnp.bincount(dig, weights=weights, length=n_bins)
|
||||
else: # XXX: bincount is really slow with complex numbers
|
||||
psum_real = jnp.bincount(dig, weights=weights.real, length=n_bins)
|
||||
psum_imag = jnp.bincount(dig, weights=weights.imag, length=n_bins)
|
||||
psum = (psum_real**2 + psum_imag**2)**.5
|
||||
pk = pk.at[i_ell].set(psum)
|
||||
|
||||
#calculating powerspectra
|
||||
real = jnp.real(pk).reshape([-1])
|
||||
imag = jnp.imag(pk).reshape([-1])
|
||||
# Normalization and conversion from cell units to [Mpc/h]^3
|
||||
pk = (pk / kcount)[:, 1:-1] * (box_shape / mesh_shape).prod()
|
||||
|
||||
Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
|
||||
length=xsum.size) * 1j
|
||||
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
|
||||
|
||||
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
|
||||
|
||||
#normalization for powerspectra
|
||||
norm = np.prod(np.array(shape[:])).astype('float32')**2
|
||||
|
||||
#find central values of each bin
|
||||
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
|
||||
|
||||
return kbins, P / norm
|
||||
# pk = jnp.concatenate([kavg[None], pk])
|
||||
if np.ndim(multipoles) == 0:
|
||||
return kavg, pk[0]
|
||||
else:
|
||||
return kavg, pk
|
||||
|
||||
|
||||
def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False):
|
||||
def transfer(mesh0, mesh1, box_shape, kedges: int | float | list = None):
|
||||
pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
|
||||
ks, pk0 = pk_fn(mesh0)
|
||||
ks, pk1 = pk_fn(mesh1)
|
||||
return ks, (pk1 / pk0)**.5
|
||||
|
||||
|
||||
def coherence(mesh0, mesh1, box_shape, kedges: int | float | list = None):
|
||||
pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
|
||||
ks, pk01 = pk_fn(mesh0, mesh1)
|
||||
ks, pk0 = pk_fn(mesh0)
|
||||
ks, pk1 = pk_fn(mesh1)
|
||||
return ks, pk01 / (pk0 * pk1)**.5
|
||||
|
||||
|
||||
def pktranscoh(mesh0, mesh1, box_shape, kedges: int | float | list = None):
|
||||
pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
|
||||
ks, pk01 = pk_fn(mesh0, mesh1)
|
||||
ks, pk0 = pk_fn(mesh0)
|
||||
ks, pk1 = pk_fn(mesh1)
|
||||
return ks, pk0, pk1, (pk1 / pk0)**.5, pk01 / (pk0 * pk1)**.5
|
||||
|
||||
|
||||
def cross_correlation_coefficients(field_a,
|
||||
field_b,
|
||||
kmin=5,
|
||||
dk=0.5,
|
||||
boxsize=False):
|
||||
"""
|
||||
Calculate the cross correlation coefficients given two real space field
|
||||
|
||||
|
@ -118,7 +190,8 @@ def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=Fals
|
|||
real = jnp.real(pk).reshape([-1])
|
||||
imag = jnp.imag(pk).reshape([-1])
|
||||
|
||||
Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j
|
||||
Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
|
||||
length=xsum.size) * 1j
|
||||
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
|
||||
|
||||
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
|
||||
|
|
179
notebooks/05-MultiHost_PM.py
Normal file
179
notebooks/05-MultiHost_PM.py
Normal file
|
@ -0,0 +1,179 @@
|
|||
import os
|
||||
|
||||
os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
|
||||
import jax
|
||||
|
||||
jax.distributed.initialize()
|
||||
rank = jax.process_index()
|
||||
size = jax.process_count()
|
||||
if rank == 0:
|
||||
print(f"SIZE is {jax.device_count()}")
|
||||
|
||||
import argparse
|
||||
from functools import partial
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
import numpy as np
|
||||
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
||||
PIDController, SaveAt, diffeqsolve)
|
||||
from jax.experimental.mesh_utils import create_device_mesh
|
||||
from jax.experimental.multihost_utils import process_allgather
|
||||
from jax.sharding import Mesh, NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jaxpm.kernels import interpolate_power_spectrum
|
||||
from jaxpm.painting import cic_paint_dx
|
||||
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
|
||||
|
||||
all_gather = partial(process_allgather, tiled=True)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run a cosmological simulation with JAX.")
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--pdims",
|
||||
type=int,
|
||||
nargs=2,
|
||||
default=[1, jax.devices()],
|
||||
help="Processor grid dimensions as two integers (e.g., 2 4).")
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--mesh_shape",
|
||||
type=int,
|
||||
nargs=3,
|
||||
default=[512, 512, 512],
|
||||
help="Shape of the simulation mesh as three values (e.g., 512 512 512)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--box_size",
|
||||
type=float,
|
||||
nargs=3,
|
||||
default=[500.0, 500.0, 500.0],
|
||||
help=
|
||||
"Box size of the simulation as three values (e.g., 500.0 500.0 1000.0)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-st",
|
||||
"--snapshots",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Number of snapshots to save during the simulation.")
|
||||
parser.add_argument("-H",
|
||||
"--halo_size",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Halo size for the simulation.")
|
||||
parser.add_argument("-s",
|
||||
"--solver",
|
||||
type=str,
|
||||
choices=['leapfrog', 'dopri8'],
|
||||
default='leapfrog',
|
||||
help="ODE solver choice: 'leapfrog' or 'dopri8'.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def create_mesh_and_sharding(pdims):
|
||||
devices = create_device_mesh(pdims)
|
||||
mesh = Mesh(devices, axis_names=('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
return mesh, sharding
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
|
||||
def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
||||
solver_choice, nb_snapshots, sharding):
|
||||
k = jnp.logspace(-4, 1, 128)
|
||||
pk = jc.power.linear_matter_power(
|
||||
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
||||
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
|
||||
|
||||
initial_conditions = linear_field(mesh_shape,
|
||||
box_size,
|
||||
pk_fn,
|
||||
seed=jax.random.PRNGKey(0),
|
||||
sharding=sharding)
|
||||
|
||||
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
|
||||
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
a=0.1,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
|
||||
# Choose solver
|
||||
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
|
||||
stepsize_controller = ConstantStepSize(
|
||||
) if solver_choice == "leapfrog" else PIDController(rtol=1e-5, atol=1e-5)
|
||||
res = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
t0=0.1,
|
||||
t1=1.,
|
||||
dt0=0.01,
|
||||
y0=jnp.stack([dx, p], axis=0),
|
||||
args=cosmo,
|
||||
saveat=SaveAt(ts=jnp.linspace(0.2, 1., nb_snapshots)),
|
||||
stepsize_controller=stepsize_controller)
|
||||
|
||||
ode_fields = [
|
||||
cic_paint_dx(sol[0], halo_size=halo_size, sharding=sharding)
|
||||
for sol in res.ys
|
||||
]
|
||||
lpt_field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding)
|
||||
return initial_conditions, lpt_field, ode_fields, res.stats
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
mesh_shape = args.mesh_shape
|
||||
box_size = args.box_size
|
||||
halo_size = args.halo_size
|
||||
solver_choice = args.solver
|
||||
nb_snapshots = args.snapshots
|
||||
|
||||
sharding = create_mesh_and_sharding(args.pdims)
|
||||
|
||||
initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(
|
||||
0.25, 0.8, tuple(mesh_shape), tuple(box_size), halo_size,
|
||||
solver_choice, nb_snapshots, sharding)
|
||||
|
||||
if rank == 0:
|
||||
os.makedirs("fields", exist_ok=True)
|
||||
print(f"[{rank}] Simulation done")
|
||||
print(f"Solver stats: {solver_stats}")
|
||||
|
||||
# Save initial conditions
|
||||
initial_conditions_g = all_gather(initial_conditions)
|
||||
if rank == 0:
|
||||
print(f"[{rank}] Saving initial_conditions")
|
||||
np.save("fields/initial_conditions.npy", initial_conditions_g)
|
||||
print(f"[{rank}] initial_conditions saved")
|
||||
del initial_conditions_g, initial_conditions
|
||||
|
||||
# Save LPT displacements
|
||||
lpt_displacements_g = all_gather(lpt_displacements)
|
||||
if rank == 0:
|
||||
print(f"[{rank}] Saving lpt_displacements")
|
||||
np.save("fields/lpt_displacements.npy", lpt_displacements_g)
|
||||
print(f"[{rank}] lpt_displacements saved")
|
||||
del lpt_displacements_g, lpt_displacements
|
||||
|
||||
# Save each ODE solution separately
|
||||
for i, sol in enumerate(ode_solutions):
|
||||
sol_g = all_gather(sol)
|
||||
if rank == 0:
|
||||
print(f"[{rank}] Saving ode_solution_{i}")
|
||||
np.save(f"fields/ode_solution_{i}.npy", sol_g)
|
||||
print(f"[{rank}] ode_solution_{i} saved")
|
||||
del sol_g
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
notebooks/06-Animating_PM_Fields.ipynb.REMOVED.git-id
Normal file
1
notebooks/06-Animating_PM_Fields.ipynb.REMOVED.git-id
Normal file
|
@ -0,0 +1 @@
|
|||
c4a44973e4f11841a8c14f4d200e7e87887419aa
|
39
notebooks/README.md
Normal file
39
notebooks/README.md
Normal file
|
@ -0,0 +1,39 @@
|
|||
# Particle Mesh Simulation with JAXPM on Multi-GPU and Multi-Host Systems
|
||||
|
||||
This collection of notebooks demonstrates how to perform Particle Mesh (PM) simulations using **JAXPM**, leveraging JAX for efficient computation on multi-GPU and multi-host systems. Each notebook progressively covers different setups, from single-GPU simulations to advanced, distributed, multi-host simulations across multiple nodes.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. **[Single-GPU Particle Mesh Simulation](01-Introduction.ipynb)**
|
||||
- Introduction to basic PM simulations on a single GPU.
|
||||
- Uses JAXPM to run simulations with absolute particle positions and Cloud-in-Cell (CIC) painting.
|
||||
|
||||
2. **[Advanced Particle Mesh Simulation on a Single GPU](02-Advanced_usage.ipynb)**
|
||||
- Explore using diffrax solvers in the ODE step.
|
||||
- Explores second order Lagrangian Perturbation Theory (LPT) simulations.
|
||||
- Introduces weighted density field projections
|
||||
|
||||
3. **[Multi-GPU Particle Mesh Simulation with Halo Exchange](03-MultiGPU_PM_Halo.ipynb)**
|
||||
- Extends PM simulation to multi-GPU setups with halo exchange.
|
||||
- Uses sharding and device mesh configurations to manage distributed data across GPUs.
|
||||
|
||||
4. **[Multi-GPU Particle Mesh Simulation with Advanced Solvers](04-MultiGPU_PM_Solvers.ipynb)**
|
||||
- Compares different ODE solvers (Leapfrog and Dopri5) in multi-GPU simulations.
|
||||
- Highlights performance, memory considerations, and solver impact on simulation quality.
|
||||
|
||||
5. **[Multi-Host Particle Mesh Simulation](05-MultiHost_PM.ipynb)**
|
||||
- Extends PM simulations to multi-host, multi-GPU setups for large-scale simulations.
|
||||
- Guides through job submission, device initialization, and retrieving results across nodes.
|
||||
|
||||
## Getting Started
|
||||
|
||||
Each notebook includes installation instructions and guidelines for configuring JAXPM and required dependencies. Follow the setup instructions in each notebook to ensure an optimal environment.
|
||||
|
||||
## Requirements
|
||||
|
||||
- **JAXPM** (included in the installation commands within notebooks)
|
||||
- **Diffrax** for ODE solvers
|
||||
- **JAX** with CUDA support for multi-GPU or TPU setups
|
||||
- **SLURM** for job scheduling on clusters (if running multi-host setups)
|
||||
|
||||
> **Note**: These notebooks are tested on the **Jean Zay** supercomputer and may require configuration changes for different HPC clusters.
|
30
pyproject.toml
Normal file
30
pyproject.toml
Normal file
|
@ -0,0 +1,30 @@
|
|||
[build-system]
|
||||
requires = ["setuptools", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "JaxPM"
|
||||
version = "0.0.1"
|
||||
description = "A dead simple FastPM implementation in JAX"
|
||||
authors = [{ name = "JaxPM developers" }]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
license = { file = "LICENSE" }
|
||||
urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/JaxPM" }
|
||||
dependencies = ["jax_cosmo", "jax>=0.4.30", "jaxdecomp>=0.2.2"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = [
|
||||
"jax>=0.4.30",
|
||||
"numpy",
|
||||
"jax_cosmo",
|
||||
"jaxdecomp>=0.2.2",
|
||||
"pytest>=8.0.0",
|
||||
"pfft-python @ git+https://github.com/MP-Gadget/pfft-python",
|
||||
"pmesh @ git+https://github.com/MP-Gadget/pmesh",
|
||||
"fastpm @ git+https://github.com/ASKabalan/fastpm-python",
|
||||
"diffrax"
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["jaxpm"]
|
4
pytest.ini
Normal file
4
pytest.ini
Normal file
|
@ -0,0 +1,4 @@
|
|||
[pytest]
|
||||
markers =
|
||||
distributed: mark a test as distributed
|
||||
single_device: mark a test as single_device
|
11
setup.py
11
setup.py
|
@ -1,11 +0,0 @@
|
|||
from setuptools import find_packages, setup
|
||||
|
||||
setup(
|
||||
name='JaxPM',
|
||||
version='0.0.1',
|
||||
url='https://github.com/DifferentiableUniverseInitiative/JaxPM',
|
||||
author='JaxPM developers',
|
||||
description='A dead simple FastPM implementation in JAX',
|
||||
packages=find_packages(),
|
||||
install_requires=['jax', 'jax_cosmo'],
|
||||
)
|
175
tests/conftest.py
Normal file
175
tests/conftest.py
Normal file
|
@ -0,0 +1,175 @@
|
|||
# Parameterized fixture for mesh_shape
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ["EQX_ON_ERROR"] = "nan"
|
||||
setup_done = False
|
||||
on_cluster = False
|
||||
|
||||
|
||||
def is_on_cluster():
|
||||
global on_cluster
|
||||
return on_cluster
|
||||
|
||||
|
||||
def initialize_distributed():
|
||||
global setup_done
|
||||
global on_cluster
|
||||
if not setup_done:
|
||||
if "SLURM_JOB_ID" in os.environ:
|
||||
on_cluster = True
|
||||
print("Running on cluster")
|
||||
import jax
|
||||
jax.distributed.initialize()
|
||||
setup_done = True
|
||||
on_cluster = True
|
||||
else:
|
||||
print("Running locally")
|
||||
setup_done = True
|
||||
on_cluster = False
|
||||
os.environ["JAX_PLATFORM_NAME"] = "cpu"
|
||||
os.environ[
|
||||
"XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
|
||||
import jax
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
scope="session",
|
||||
params=[
|
||||
((32, 32, 32), (256., 256., 256.)), # BOX
|
||||
((32, 32, 64), (256., 256., 512.)), # RECTANGULAR
|
||||
])
|
||||
def simulation_config(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", params=[0.1, 0.5, 0.8])
|
||||
def lpt_scale_factor(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def cosmo():
|
||||
from functools import partial
|
||||
|
||||
from jax_cosmo import Cosmology
|
||||
Planck18 = partial(
|
||||
Cosmology,
|
||||
# Omega_m = 0.3111
|
||||
Omega_c=0.2607,
|
||||
Omega_b=0.0490,
|
||||
Omega_k=0.0,
|
||||
h=0.6766,
|
||||
n_s=0.9665,
|
||||
sigma8=0.8102,
|
||||
w0=-1.0,
|
||||
wa=0.0,
|
||||
)
|
||||
|
||||
return Planck18()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def particle_mesh(simulation_config):
|
||||
from pmesh.pm import ParticleMesh
|
||||
mesh_shape, box_shape = simulation_config
|
||||
return ParticleMesh(BoxSize=box_shape, Nmesh=mesh_shape, dtype='f4')
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def fpm_initial_conditions(cosmo, particle_mesh):
|
||||
import jax_cosmo as jc
|
||||
import numpy as np
|
||||
from jax import numpy as jnp
|
||||
|
||||
# Generate initial particle positions
|
||||
grid = particle_mesh.generate_uniform_particle_grid(shift=0).astype(
|
||||
np.float32)
|
||||
# Interpolate with linear_matter spectrum to get initial density field
|
||||
k = jnp.logspace(-4, 1, 128)
|
||||
pk = jc.power.linear_matter_power(cosmo, k)
|
||||
|
||||
def pk_fn(x):
|
||||
return jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)
|
||||
|
||||
whitec = particle_mesh.generate_whitenoise(42,
|
||||
type='complex',
|
||||
unitary=False)
|
||||
lineark = whitec.apply(lambda k, v: pk_fn(sum(ki**2 for ki in k)**0.5)**0.5
|
||||
* v * (1 / v.BoxSize).prod()**0.5)
|
||||
init_mesh = lineark.c2r().value # XXX
|
||||
|
||||
return lineark, grid, init_mesh
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def initial_conditions(fpm_initial_conditions):
|
||||
_, _, init_mesh = fpm_initial_conditions
|
||||
return init_mesh
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def solver(cosmo, particle_mesh):
|
||||
from fastpm.core import Cosmology as FastPMCosmology
|
||||
from fastpm.core import Solver
|
||||
ref_cosmo = FastPMCosmology(cosmo)
|
||||
return Solver(particle_mesh, ref_cosmo, B=1)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def fpm_lpt1(solver, fpm_initial_conditions, lpt_scale_factor):
|
||||
|
||||
lineark, grid, _ = fpm_initial_conditions
|
||||
statelpt = solver.lpt(lineark, grid, lpt_scale_factor, order=1)
|
||||
return statelpt
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def fpm_lpt1_field(fpm_lpt1, particle_mesh):
|
||||
return particle_mesh.paint(fpm_lpt1.X).value
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def fpm_lpt2(solver, fpm_initial_conditions, lpt_scale_factor):
|
||||
|
||||
lineark, grid, _ = fpm_initial_conditions
|
||||
statelpt = solver.lpt(lineark, grid, lpt_scale_factor, order=2)
|
||||
return statelpt
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def fpm_lpt2_field(fpm_lpt2, particle_mesh):
|
||||
return particle_mesh.paint(fpm_lpt2.X).value
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def nbody_from_lpt1(solver, fpm_lpt1, particle_mesh, lpt_scale_factor):
|
||||
import numpy as np
|
||||
from fastpm.core import leapfrog
|
||||
|
||||
if lpt_scale_factor == 0.8:
|
||||
pytest.skip("Do not run nbody simulation from scale factor 0.8")
|
||||
|
||||
stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)
|
||||
|
||||
finalstate = solver.nbody(fpm_lpt1, leapfrog(stages))
|
||||
fpm_mesh = particle_mesh.paint(finalstate.X).value
|
||||
|
||||
return fpm_mesh
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
|
||||
import numpy as np
|
||||
from fastpm.core import leapfrog
|
||||
|
||||
if lpt_scale_factor == 0.8:
|
||||
pytest.skip("Do not run nbody simulation from scale factor 0.8")
|
||||
|
||||
stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)
|
||||
|
||||
finalstate = solver.nbody(fpm_lpt2, leapfrog(stages))
|
||||
fpm_mesh = particle_mesh.paint(finalstate.X).value
|
||||
|
||||
return fpm_mesh
|
13
tests/helpers.py
Normal file
13
tests/helpers.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
import jax.numpy as jnp
|
||||
|
||||
|
||||
def MSE(x, y):
|
||||
return jnp.mean((x - y)**2)
|
||||
|
||||
|
||||
def MSE_3D(x, y):
|
||||
return ((x - y)**2).mean(axis=0)
|
||||
|
||||
|
||||
def MSRE(x, y):
|
||||
return jnp.mean(((x - y) / y)**2)
|
155
tests/test_against_fpm.py
Normal file
155
tests/test_against_fpm.py
Normal file
|
@ -0,0 +1,155 @@
|
|||
import pytest
|
||||
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
|
||||
from helpers import MSE, MSRE
|
||||
from jax import numpy as jnp
|
||||
|
||||
from jaxpm.distributed import uniform_particles
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx
|
||||
from jaxpm.pm import lpt, make_diffrax_ode
|
||||
from jaxpm.utils import power_spectrum
|
||||
|
||||
_TOLERANCE = 1e-4
|
||||
_PM_TOLERANCE = 1e-3
|
||||
|
||||
|
||||
@pytest.mark.single_device
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
|
||||
fpm_lpt1_field, fpm_lpt2_field, cosmo, order):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
cosmo._workspace = {}
|
||||
particles = uniform_particles(mesh_shape)
|
||||
|
||||
# Initial displacement
|
||||
dx, _, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
particles,
|
||||
a=lpt_scale_factor,
|
||||
order=order)
|
||||
|
||||
fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field
|
||||
|
||||
lpt_field = cic_paint(jnp.zeros(mesh_shape), particles + dx)
|
||||
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
|
||||
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
||||
|
||||
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
|
||||
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
|
||||
|
||||
|
||||
@pytest.mark.single_device
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
|
||||
fpm_lpt1_field, fpm_lpt2_field, cosmo, order):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
cosmo._workspace = {}
|
||||
# Initial displacement
|
||||
dx, _, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
|
||||
|
||||
lpt_field = cic_paint_dx(dx)
|
||||
|
||||
fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field
|
||||
|
||||
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
|
||||
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
||||
|
||||
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
|
||||
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
|
||||
|
||||
|
||||
@pytest.mark.single_device
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
def test_nbody_absolute(simulation_config, initial_conditions,
|
||||
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
|
||||
cosmo, order):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
cosmo._workspace = {}
|
||||
particles = uniform_particles(mesh_shape)
|
||||
|
||||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
particles,
|
||||
a=lpt_scale_factor,
|
||||
order=order)
|
||||
|
||||
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-8,
|
||||
atol=1e-8,
|
||||
pcoeff=0.4,
|
||||
icoeff=1,
|
||||
dcoeff=0)
|
||||
|
||||
saveat = SaveAt(t1=True)
|
||||
|
||||
y0 = jnp.stack([particles + dx, p])
|
||||
|
||||
solutions = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
t0=lpt_scale_factor,
|
||||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0])
|
||||
|
||||
fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2
|
||||
|
||||
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
|
||||
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
||||
|
||||
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
|
||||
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
|
||||
|
||||
|
||||
@pytest.mark.single_device
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
def test_nbody_relative(simulation_config, initial_conditions,
|
||||
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
|
||||
cosmo, order):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
cosmo._workspace = {}
|
||||
|
||||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-9,
|
||||
atol=1e-9,
|
||||
pcoeff=0.4,
|
||||
icoeff=1,
|
||||
dcoeff=0)
|
||||
|
||||
saveat = SaveAt(t1=True)
|
||||
|
||||
y0 = jnp.stack([dx, p])
|
||||
|
||||
solutions = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
t0=lpt_scale_factor,
|
||||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
final_field = cic_paint_dx(solutions.ys[-1, 0])
|
||||
|
||||
fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2
|
||||
|
||||
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
|
||||
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
||||
|
||||
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
|
||||
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
|
152
tests/test_distributed_pm.py
Normal file
152
tests/test_distributed_pm.py
Normal file
|
@ -0,0 +1,152 @@
|
|||
from conftest import initialize_distributed
|
||||
|
||||
initialize_distributed() # ignore : E402
|
||||
|
||||
import jax # noqa : E402
|
||||
import jax.numpy as jnp # noqa : E402
|
||||
import pytest # noqa : E402
|
||||
from diffrax import SaveAt # noqa : E402
|
||||
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
|
||||
from helpers import MSE # noqa : E402
|
||||
from jax import lax # noqa : E402
|
||||
from jax.experimental.multihost_utils import process_allgather # noqa : E402
|
||||
from jax.sharding import NamedSharding
|
||||
from jax.sharding import PartitionSpec as P # noqa : E402
|
||||
|
||||
from jaxpm.distributed import uniform_particles # noqa : E402
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
|
||||
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
|
||||
|
||||
_TOLERANCE = 3.0 # 🙃🙃
|
||||
|
||||
|
||||
@pytest.mark.distributed
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
@pytest.mark.parametrize("absolute_painting", [True, False])
|
||||
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
||||
absolute_painting):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
# SINGLE DEVICE RUN
|
||||
cosmo._workspace = {}
|
||||
if absolute_painting:
|
||||
particles = uniform_particles(mesh_shape)
|
||||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
particles,
|
||||
a=0.1,
|
||||
order=order)
|
||||
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
|
||||
y0 = jnp.stack([particles + dx, p])
|
||||
else:
|
||||
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
y0 = jnp.stack([dx, p])
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-8,
|
||||
atol=1e-8,
|
||||
pcoeff=0.4,
|
||||
icoeff=1,
|
||||
dcoeff=0)
|
||||
|
||||
saveat = SaveAt(t1=True)
|
||||
|
||||
solutions = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
t0=0.1,
|
||||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
if absolute_painting:
|
||||
single_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
|
||||
solutions.ys[-1, 0])
|
||||
else:
|
||||
single_device_final_field = cic_paint_dx(solutions.ys[-1, 0])
|
||||
|
||||
print("Done with single device run")
|
||||
# MULTI DEVICE RUN
|
||||
|
||||
mesh = jax.make_mesh((1, 8), ('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
halo_size = mesh_shape[0] // 2
|
||||
|
||||
initial_conditions = lax.with_sharding_constraint(initial_conditions,
|
||||
sharding)
|
||||
|
||||
print(f"sharded initial conditions {initial_conditions.sharding}")
|
||||
|
||||
cosmo._workspace = {}
|
||||
if absolute_painting:
|
||||
particles = uniform_particles(mesh_shape, sharding=sharding)
|
||||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
particles,
|
||||
a=0.1,
|
||||
order=order,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo,
|
||||
mesh_shape,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
|
||||
y0 = jnp.stack([particles + dx, p])
|
||||
else:
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
a=0.1,
|
||||
order=order,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo,
|
||||
mesh_shape,
|
||||
paint_absolute_pos=False,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
y0 = jnp.stack([dx, p])
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-8,
|
||||
atol=1e-8,
|
||||
pcoeff=0.4,
|
||||
icoeff=1,
|
||||
dcoeff=0)
|
||||
|
||||
saveat = SaveAt(t1=True)
|
||||
|
||||
solutions = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
t0=0.1,
|
||||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
if absolute_painting:
|
||||
multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
|
||||
solutions.ys[-1, 0],
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
else:
|
||||
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
multi_device_final_field = process_allgather(multi_device_final_field,
|
||||
tiled=True)
|
||||
|
||||
mse = MSE(single_device_final_field, multi_device_final_field)
|
||||
print(f"MSE is {mse}")
|
||||
|
||||
assert mse < _TOLERANCE
|
Loading…
Add table
Reference in a new issue