mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-22 17:47:11 +00:00
jaxdecomp proto (#21)
* adding example of distributed solution
* put back old function
* update formatting
* add halo exchange and slice pad
* apply formatting
* implement distributed optimized cic_paint
* Use new cic_paint with halo
* Fix seed for distributed normal
* Wrap interpolation function to avoid all gather
* Return normal order frequencies for single GPU
* add example
* format
* add optimised bench script
* times in ms
* add lpt2
* update benchmark and add slurm
* Visualize only final field
* Update scripts/distributed_pm.py
  Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
* Adjust pencil type for frequencies
* fix painting issue with slabs
* Shared operation in Fourier space now takes inverted sharding axis for slabs
* add assert to make pyright happy
* adjust test for hpc-plotter
* add PMWD test
* bench
* format
* added github workflow
* fix formatting from main
* Update for jaxDecomp pure JAX
* revert single halo extent change
* update for latest jaxDecomp
* remove fourrier_space in autoshmap
* make normal_field work with single controller
* format
* make distributed pm work in single controller
* merge bench_pm
* update to leapfrog
* add a strict dependency on jaxdecomp
* global mesh no longer needed
* kernels.py no longer uses global mesh
* quick fix in distributed
* pm.py no longer uses global mesh
* painting.py no longer uses global mesh
* update demo script
* quick fix in kernels
* quick fix in distributed
* update demo
* merge Hugo's LPT2 code
* format
* Small fix
* format
* remove duplicate get_ode_fn
* update visualizer
* update compensate CIC
* By default check_rep is false for shard_map
* remove experimental distributed code
* update PGDCorrection and neural ode to use new fft3d
* jaxDecomp pfft3d promotes to complex automatically
* remove deprecated stuff
* fix painting issue with read_cic
* use jnp interp instead of jc interp
* delete old slurms
* add notebook examples
* apply formatting
* add distributed zeros
* fix code in LPT2
* jit cic_paint
* update notebooks
* apply formatting
* get local shape and zeros can be used by users
* add a user facing function to create uniform particle grid
* use jax interp instead of jax_cosmo
* use float64 for enmeshing
* Allow applying weights with relative cic paint
* Weights can be traced
* remove script folder
* update example notebooks
* delete outdated design file
* add readme for tutorials
* update readme
* fix small error
* forgot particles in multi host
* clarifying why cic_paint_dx is slower
* clarifying the halo size dependence on the box size
* ability to choose snapshots number with MultiHost script
* Adding animation notebook
* Put plotting in package
* Add finite difference laplace kernel + powerspec functions from Hugo
  Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>
* Put plotting utils in package
* By default use absolute painting with
* update code
* update notebooks
* add tests
* Upgrade setup.py to pyproject
* Format
* format tests
* update test dependencies
* add test workflow
* fix deprecated FftType in jaxpm.kernels
* Add aboucaud comments
* JAX version is 0.4.35 until Diffrax new release
* add numpy explicitly as dependency for tests
* fix install order for tests
* add numpy to be installed
* enforce no build isolation for fastpm
* pip install jaxpm test without build isolation
* bump jaxdecomp version
* revert test workflow
* remove outdated tests
---------
Co-authored-by: EiffL <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr>
Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>
Former-commit-id: 8c2e823d4669eac712089bf7f85ffb7912e8232d
parent
a0a79277e5
commit
df8602b318
26 changed files with 1871 additions and 434 deletions
21
.github/workflows/formatting.yml
vendored
Normal file
@ -0,0 +1,21 @@
name: Code Formatting

on:
  push:
    branches: [ "main" ]
  pull_request:
    branches: [ "main" ]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v3
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip isort
          python -m pip install pre-commit
      - name: Run pre-commit
        run: python -m pre_commit run --all-files
45
.github/workflows/tests.yml
vendored
Normal file
@ -0,0 +1,45 @@
name: Tests

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  run_tests:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10" , "3.11" , "3.12"]

    steps:
      - name: Checkout Source
        uses: actions/checkout@v2.3.1

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          sudo apt-get install -y libopenmpi-dev
          python -m pip install --upgrade pip
          pip install jax==0.4.35
          pip install numpy setuptools cython wheel
          pip install git+https://github.com/MP-Gadget/pfft-python
          pip install git+https://github.com/MP-Gadget/pmesh
          pip install git+https://github.com/ASKabalan/fastpm-python --no-build-isolation
          pip install .[test]

      - name: Run Single Device Tests
        run: |
          cd tests
          pytest -v -m "not distributed"
      - name: Run Distributed tests
        run: |
          pytest -v -m distributed
5
.gitignore
vendored
|
@ -98,6 +98,11 @@ __pypackages__/
|
||||||
celerybeat-schedule
|
celerybeat-schedule
|
||||||
celerybeat.pid
|
celerybeat.pid
|
||||||
|
|
||||||
|
|
||||||
|
out
|
||||||
|
traces
|
||||||
|
*.npy
|
||||||
|
*.out
|
||||||
# SageMath parsed files
|
# SageMath parsed files
|
||||||
*.sage.py
|
*.sage.py
|
||||||
|
|
||||||
|
|
|
@ -4,16 +4,15 @@
|
||||||
<!-- ALL-CONTRIBUTORS-BADGE:END -->
|
<!-- ALL-CONTRIBUTORS-BADGE:END -->
|
||||||
JAX-powered Cosmological Particle-Mesh N-body Solver
|
JAX-powered Cosmological Particle-Mesh N-body Solver
|
||||||
|
|
||||||
**This project is currently in an early design phase. All inputs are welcome on the [design document](https://github.com/DifferentiableUniverseInitiative/JaxPM/blob/main/design.md)**
|
|
||||||
|
|
||||||
## Goals
|
## Goals
|
||||||
|
|
||||||
Provide a modern infrastructure to support differentiable PM N-body simulations using JAX:
|
Provide a modern infrastructure to support differentiable PM N-body simulations using JAX:
|
||||||
- Keep implementation simple and readable, in pure NumPy API
|
- Keep implementation simple and readable, in pure NumPy API
|
||||||
- Transparent distribution using builtin `xmap`
|
|
||||||
- Any order forward and backward automatic differentiation
|
- Any order forward and backward automatic differentiation
|
||||||
- Support automated batching using `vmap`
|
- Support automated batching using `vmap`
|
||||||
- Compatibility with external optimizer libraries like `optax`
|
- Compatibility with external optimizer libraries like `optax`
|
||||||
|
- Now fully distributable on **multi-GPU and multi-node** systems using [jaxDecomp](https://github.com/DifferentiableUniverseInitiative/jaxDecomp), working with `JAX v0.4.35`
|
||||||
|
|
||||||
|
|
||||||
## Open development and use
|
## Open development and use
|
||||||
|
|
||||||
|
@ -23,6 +22,10 @@ Current expectations are:
|
||||||
- Everyone is welcome to contribute, and can join the JOSS publication (until it is submitted to the journal).
|
- Everyone is welcome to contribute, and can join the JOSS publication (until it is submitted to the journal).
|
||||||
- Anyone (including main contributors) can use this code as a framework to build and publish their own applications, with no expectation that they *need* to extend authorship to all jaxpm developers.
|
- Anyone (including main contributors) can use this code as a framework to build and publish their own applications, with no expectation that they *need* to extend authorship to all jaxpm developers.
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
To dive into JaxPM’s capabilities, please explore the **notebook section** for detailed tutorials and examples on various setups, from single-device simulations to multi-host configurations. You can find the notebooks' [README here](notebooks/README.md) for a structured guide through each tutorial.
|
||||||
|
|
||||||
|
|
||||||
## Contributors ✨
|
## Contributors ✨
|
||||||
|
|
||||||
|
|
52
design.md
|
@ -1,52 +0,0 @@
|
||||||
# Design Document for JaxPM
|
|
||||||
|
|
||||||
This document aims to detail some of the API, implementation choices, and internal mechanism.
|
|
||||||
|
|
||||||
## Objective
|
|
||||||
|
|
||||||
Provide a user-friendly framework for distributed Particle-Mesh N-body simulations.
|
|
||||||
|
|
||||||
## Related Work
|
|
||||||
|
|
||||||
This project would be the latest iteration of a number of past libraries that have provided differentiable N-body models.
|
|
||||||
|
|
||||||
- [FlowPM](https://github.com/DifferentiableUniverseInitiative/flowpm): TensorFlow
|
|
||||||
- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD
|
|
||||||
- Borg
|
|
||||||
|
|
||||||
|
|
||||||
In addition, a number of fast N-body simulation projects exist out there:
|
|
||||||
- [FastPM](https://github.com/fastpm/fastpm)
|
|
||||||
- ...
|
|
||||||
|
|
||||||
## Design Overview
|
|
||||||
|
|
||||||
### Coding principles
|
|
||||||
|
|
||||||
Following recent trends and JAX philosophy, the library should have a functional programming type of interface.
|
|
||||||
|
|
||||||
|
|
||||||
### Illustration of API
|
|
||||||
|
|
||||||
Here is a potential illustration of what the user interface could be for the simulation code:
|
|
||||||
```python
|
|
||||||
import jaxpm as jpm
|
|
||||||
import jax_cosmo as jc
|
|
||||||
|
|
||||||
# Instantiate differentiable cosmology object
|
|
||||||
cosmo = jc.Planck()
|
|
||||||
|
|
||||||
# Creates initial conditions
|
|
||||||
initial_conditions = jpm.generate_ic(cosmo, boxsize, nmesh, dtype='float32')
|
|
||||||
|
|
||||||
# Create a particular solver
|
|
||||||
solver = jpm.solvers.fastpm(cosmo, B=1)
|
|
||||||
|
|
||||||
# Initialize and run the simulation
|
|
||||||
state = solver.init(initial_conditions)
|
|
||||||
state = solver.nbody(state)
|
|
||||||
|
|
||||||
# Painting the results
|
|
||||||
density = jpm.zeros(boxsize, nmesh)
|
|
||||||
density = jpm.paint(density, state.positions)
|
|
||||||
```
|
|
|
@ -1,96 +0,0 @@
|
||||||
# Can be executed with:
|
|
||||||
# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import jax
|
|
||||||
import jax.lax as lax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import numpy as np
|
|
||||||
from jax.experimental.maps import Mesh, xmap
|
|
||||||
from jax.experimental.pjit import PartitionSpec, pjit
|
|
||||||
|
|
||||||
jax.distributed.initialize()
|
|
||||||
|
|
||||||
cube_size = 2048
|
|
||||||
|
|
||||||
|
|
||||||
@partial(xmap,
|
|
||||||
in_axes=[...],
|
|
||||||
out_axes=['x', 'y', ...],
|
|
||||||
axis_sizes={
|
|
||||||
'x': cube_size,
|
|
||||||
'y': cube_size
|
|
||||||
},
|
|
||||||
axis_resources={
|
|
||||||
'x': 'nx',
|
|
||||||
'y': 'ny',
|
|
||||||
'key_x': 'nx',
|
|
||||||
'key_y': 'ny'
|
|
||||||
})
|
|
||||||
def pnormal(key):
|
|
||||||
return jax.random.normal(key, shape=[cube_size])
|
|
||||||
|
|
||||||
|
|
||||||
@partial(xmap,
|
|
||||||
in_axes={
|
|
||||||
0: 'x',
|
|
||||||
1: 'y'
|
|
||||||
},
|
|
||||||
out_axes=['x', 'y', ...],
|
|
||||||
axis_resources={
|
|
||||||
'x': 'nx',
|
|
||||||
'y': 'ny'
|
|
||||||
})
|
|
||||||
@jax.jit
|
|
||||||
def pfft3d(mesh):
|
|
||||||
# [x, y, z]
|
|
||||||
mesh = jnp.fft.fft(mesh) # Transform on z
|
|
||||||
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x]
|
|
||||||
mesh = jnp.fft.fft(mesh) # Transform on x
|
|
||||||
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y]
|
|
||||||
mesh = jnp.fft.fft(mesh) # Transform on y
|
|
||||||
# [z, x, y]
|
|
||||||
return mesh
|
|
||||||
|
|
||||||
|
|
||||||
@partial(xmap,
|
|
||||||
in_axes={
|
|
||||||
0: 'x',
|
|
||||||
1: 'y'
|
|
||||||
},
|
|
||||||
out_axes=['x', 'y', ...],
|
|
||||||
axis_resources={
|
|
||||||
'x': 'nx',
|
|
||||||
'y': 'ny'
|
|
||||||
})
|
|
||||||
@jax.jit
|
|
||||||
def pifft3d(mesh):
|
|
||||||
# [z, x, y]
|
|
||||||
mesh = jnp.fft.ifft(mesh) # Transform on y
|
|
||||||
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x]
|
|
||||||
mesh = jnp.fft.ifft(mesh) # Transform on x
|
|
||||||
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z]
|
|
||||||
mesh = jnp.fft.ifft(mesh) # Transform on z
|
|
||||||
# [x, y, z]
|
|
||||||
return mesh
|
|
||||||
|
|
||||||
|
|
||||||
key = jax.random.PRNGKey(42)
|
|
||||||
# keys = jax.random.split(key, 4).reshape((2,2,2))
|
|
||||||
|
|
||||||
# We reshape all our devices to the mesh shape we want
|
|
||||||
devices = np.array(jax.devices()).reshape((2, 4))
|
|
||||||
|
|
||||||
with Mesh(devices, ('nx', 'ny')):
|
|
||||||
mesh = pnormal(key)
|
|
||||||
kmesh = pfft3d(mesh)
|
|
||||||
kmesh.block_until_ready()
|
|
||||||
|
|
||||||
# jax.profiler.start_trace("tensorboard")
|
|
||||||
# with Mesh(devices, ('nx', 'ny')):
|
|
||||||
# mesh = pnormal(key)
|
|
||||||
# kmesh = pfft3d(mesh)
|
|
||||||
# kmesh.block_until_ready()
|
|
||||||
# jax.profiler.stop_trace()
|
|
||||||
|
|
||||||
print('Done')
|
|
|
@ -1,68 +0,0 @@
|
||||||
# Start this script with:
|
|
||||||
# mpirun -np 4 python test_script.py
|
|
||||||
import os
|
|
||||||
|
|
||||||
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
|
|
||||||
import jax
|
|
||||||
import jax.lax as lax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import matplotlib.pylab as plt
|
|
||||||
import numpy as np
|
|
||||||
import tensorflow_probability as tfp
|
|
||||||
from jax.experimental.maps import mesh, xmap
|
|
||||||
from jax.experimental.pjit import PartitionSpec, pjit
|
|
||||||
|
|
||||||
tfp = tfp.substrates.jax
|
|
||||||
tfd = tfp.distributions
|
|
||||||
|
|
||||||
|
|
||||||
def cic_paint(mesh, positions):
|
|
||||||
""" Paints positions onto mesh
|
|
||||||
mesh: [nx, ny, nz]
|
|
||||||
positions: [npart, 3]
|
|
||||||
"""
|
|
||||||
positions = jnp.expand_dims(positions, 1)
|
|
||||||
floor = jnp.floor(positions)
|
|
||||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
|
||||||
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
|
||||||
|
|
||||||
neighboor_coords = floor + connection
|
|
||||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
|
||||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
|
||||||
|
|
||||||
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
|
|
||||||
inserted_window_dims=(0, 1, 2),
|
|
||||||
scatter_dims_to_operand_dims=(0, 1,
|
|
||||||
2))
|
|
||||||
mesh = lax.scatter_add(
|
|
||||||
mesh,
|
|
||||||
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
|
||||||
kernel.reshape([-1, 8]), dnums)
|
|
||||||
return mesh
|
|
||||||
|
|
||||||
|
|
||||||
# And let's draw some points from some 3D distribution
|
|
||||||
dist = tfd.MultivariateNormalDiag(loc=[16., 16., 16.],
|
|
||||||
scale_identity_multiplier=3.)
|
|
||||||
pos = dist.sample(1e4, seed=jax.random.PRNGKey(0))
|
|
||||||
|
|
||||||
f = pjit(lambda x: cic_paint(x, pos),
|
|
||||||
in_axis_resources=PartitionSpec('x', 'y', 'z'),
|
|
||||||
out_axis_resources=None)
|
|
||||||
|
|
||||||
devices = np.array(jax.devices()).reshape((2, 2, 1))
|
|
||||||
|
|
||||||
# Let's import the mesh
|
|
||||||
m = jnp.zeros([32, 32, 32])
|
|
||||||
|
|
||||||
with mesh(devices, ('x', 'y', 'z')):
|
|
||||||
# Shard the mesh, I'm not sure this is absolutely necessary
|
|
||||||
m = pjit(lambda x: x,
|
|
||||||
in_axis_resources=None,
|
|
||||||
out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)
|
|
||||||
|
|
||||||
# Apply the sharded CiC function
|
|
||||||
res = f(m)
|
|
||||||
|
|
||||||
plt.imshow(res.sum(axis=2))
|
|
||||||
plt.show()
|
|
198
jaxpm/distributed.py
Normal file
|
@ -0,0 +1,198 @@
|
||||||
|
from typing import Any, Callable, Hashable
|
||||||
|
|
||||||
|
Specs = Any
|
||||||
|
AxisName = Hashable
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jaxdecomp
|
||||||
|
from jax import lax
|
||||||
|
from jax.experimental.shard_map import shard_map
|
||||||
|
from jax.sharding import AbstractMesh, Mesh
|
||||||
|
from jax.sharding import PartitionSpec as P
|
||||||
|
|
||||||
|
|
||||||
|
def autoshmap(
|
||||||
|
f: Callable,
|
||||||
|
gpu_mesh: Mesh | AbstractMesh | None,
|
||||||
|
in_specs: Specs,
|
||||||
|
out_specs: Specs,
|
||||||
|
check_rep: bool = False,
|
||||||
|
auto: frozenset[AxisName] = frozenset()) -> Callable:
|
||||||
|
"""Helper function to wrap the provided function in a shard map if
|
||||||
|
the code is being executed in a mesh context."""
|
||||||
|
if gpu_mesh is None or gpu_mesh.empty:
|
||||||
|
return f
|
||||||
|
else:
|
||||||
|
return shard_map(f, gpu_mesh, in_specs, out_specs, check_rep, auto)
|
||||||
|
|
||||||
|
|
||||||
|
def fft3d(x):
|
||||||
|
return jaxdecomp.pfft3d(x)
|
||||||
|
|
||||||
|
|
||||||
|
def ifft3d(x):
|
||||||
|
return jaxdecomp.pifft3d(x).real
|
||||||
|
|
||||||
|
|
||||||
|
def get_halo_size(halo_size, sharding):
|
||||||
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
|
if gpu_mesh is None or gpu_mesh.empty:
|
||||||
|
zero_ext = (0, 0)
|
||||||
|
zero_tuple = (0, 0)
|
||||||
|
return (zero_tuple, zero_tuple, zero_tuple), zero_ext
|
||||||
|
else:
|
||||||
|
pdims = gpu_mesh.devices.shape
|
||||||
|
halo_x = (0, 0) if pdims[0] == 1 else (halo_size, halo_size)
|
||||||
|
halo_y = (0, 0) if pdims[1] == 1 else (halo_size, halo_size)
|
||||||
|
|
||||||
|
halo_x_ext = 0 if pdims[0] == 1 else halo_size // 2
|
||||||
|
halo_y_ext = 0 if pdims[1] == 1 else halo_size // 2
|
||||||
|
return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext))
|
||||||
|
|
||||||
|
|
||||||
|
def halo_exchange(x, halo_extents, halo_periods=(True, True)):
|
||||||
|
if (halo_extents[0] > 0 or halo_extents[1] > 0):
|
||||||
|
return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def slice_unpad_impl(x, pad_width):
|
||||||
|
|
||||||
|
halo_x, _ = pad_width[0]
|
||||||
|
halo_y, _ = pad_width[1]
|
||||||
|
# Apply corrections along x
|
||||||
|
x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2])
|
||||||
|
x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:])
|
||||||
|
# Apply corrections along y
|
||||||
|
x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
|
||||||
|
x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])
|
||||||
|
|
||||||
|
unpad_slice = [slice(None)] * 3
|
||||||
|
if halo_x > 0:
|
||||||
|
unpad_slice[0] = slice(halo_x, -halo_x)
|
||||||
|
if halo_y > 0:
|
||||||
|
unpad_slice[1] = slice(halo_y, -halo_y)
|
||||||
|
|
||||||
|
return x[tuple(unpad_slice)]
|
||||||
|
|
||||||
|
|
||||||
|
def slice_pad(x, pad_width, sharding):
|
||||||
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
|
if gpu_mesh is not None and not (gpu_mesh.empty) and (
|
||||||
|
pad_width[0][0] > 0 or pad_width[1][0] > 0):
|
||||||
|
assert sharding is not None
|
||||||
|
spec = sharding.spec
|
||||||
|
return shard_map((partial(jnp.pad, pad_width=pad_width)),
|
||||||
|
mesh=gpu_mesh,
|
||||||
|
in_specs=spec,
|
||||||
|
out_specs=spec)(x)
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def slice_unpad(x, pad_width, sharding):
|
||||||
|
mesh = sharding.mesh if sharding is not None else None
|
||||||
|
if mesh is not None and not (mesh.empty) and (pad_width[0][0] > 0
|
||||||
|
or pad_width[1][0] > 0):
|
||||||
|
assert sharding is not None
|
||||||
|
spec = sharding.spec
|
||||||
|
return shard_map(partial(slice_unpad_impl, pad_width=pad_width),
|
||||||
|
mesh=mesh,
|
||||||
|
in_specs=spec,
|
||||||
|
out_specs=spec)(x)
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def get_local_shape(mesh_shape, sharding=None):
|
||||||
|
""" Helper function to get the local size of a mesh given the global size.
|
||||||
|
"""
|
||||||
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
|
if gpu_mesh is None or gpu_mesh.empty:
|
||||||
|
return mesh_shape
|
||||||
|
else:
|
||||||
|
pdims = gpu_mesh.devices.shape
|
||||||
|
return [
|
||||||
|
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1],
|
||||||
|
*mesh_shape[2:]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _axis_names(spec):
|
||||||
|
if len(spec) == 1:
|
||||||
|
x_axis, = spec
|
||||||
|
y_axis = None
|
||||||
|
single_axis = True
|
||||||
|
elif len(spec) == 2:
|
||||||
|
x_axis, y_axis = spec
|
||||||
|
if y_axis is None:
|
||||||
|
single_axis = True
|
||||||
|
elif x_axis is None:
|
||||||
|
x_axis = y_axis
|
||||||
|
single_axis = True
|
||||||
|
else:
|
||||||
|
single_axis = False
|
||||||
|
else:
|
||||||
|
raise ValueError("Only 1 or 2 axis sharding is supported")
|
||||||
|
return x_axis, y_axis, single_axis
|
||||||
|
|
||||||
|
|
||||||
|
def uniform_particles(mesh_shape, sharding=None):
|
||||||
|
|
||||||
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
|
if gpu_mesh is not None and not (gpu_mesh.empty):
|
||||||
|
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
||||||
|
spec = sharding.spec
|
||||||
|
x_axis, y_axis, single_axis = _axis_names(spec)
|
||||||
|
|
||||||
|
def particles():
|
||||||
|
x_indx = lax.axis_index(x_axis)
|
||||||
|
y_indx = 0 if single_axis else lax.axis_index(y_axis)
|
||||||
|
|
||||||
|
x = jnp.arange(local_mesh_shape[0]) + x_indx * local_mesh_shape[0]
|
||||||
|
y = jnp.arange(local_mesh_shape[1]) + y_indx * local_mesh_shape[1]
|
||||||
|
z = jnp.arange(local_mesh_shape[2])
|
||||||
|
return jnp.stack(jnp.meshgrid(x, y, z, indexing='ij'), axis=-1)
|
||||||
|
|
||||||
|
return shard_map(particles, mesh=gpu_mesh, in_specs=(),
|
||||||
|
out_specs=spec)()
|
||||||
|
else:
|
||||||
|
return jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape],
|
||||||
|
indexing='ij'),
|
||||||
|
axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def normal_field(mesh_shape, seed, sharding=None):
|
||||||
|
"""Generate a Gaussian white-noise field (unit-variance normal samples) on the mesh."""
|
||||||
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
|
if gpu_mesh is not None and not (gpu_mesh.empty):
|
||||||
|
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
||||||
|
|
||||||
|
size = jax.device_count()
|
||||||
|
# rank = jax.process_index()
|
||||||
|
# process_index is multi_host only
|
||||||
|
# to make the code work both in multi host and single controller we can do this trick
|
||||||
|
keys = jax.random.split(seed, size)
|
||||||
|
spec = sharding.spec
|
||||||
|
x_axis, y_axis, single_axis = _axis_names(spec)
|
||||||
|
|
||||||
|
def normal(keys, shape, dtype):
|
||||||
|
idx = lax.axis_index(x_axis)
|
||||||
|
if not single_axis:
|
||||||
|
y_index = lax.axis_index(y_axis)
|
||||||
|
x_size = lax.psum(1, axis_name=x_axis)
|
||||||
|
idx += y_index * x_size
|
||||||
|
|
||||||
|
return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)
|
||||||
|
|
||||||
|
return shard_map(
|
||||||
|
partial(normal, shape=local_mesh_shape, dtype='float32'),
|
||||||
|
mesh=gpu_mesh,
|
||||||
|
in_specs=P(None),
|
||||||
|
out_specs=spec)(keys) # yapf: disable
|
||||||
|
else:
|
||||||
|
return jax.random.normal(shape=mesh_shape, key=seed)
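The shape bookkeeping in `get_local_shape`, `get_halo_size` and the key lookup in `normal_field` depends only on the device grid, so it can be followed without any devices. A minimal pure-Python sketch, assuming a plain tuple `pdims` stands in for `gpu_mesh.devices.shape`; the helper names `local_shape`, `halo_widths` and `flat_device_index` are hypothetical and not part of this commit:

```python
# Hypothetical stand-ins for get_local_shape / get_halo_size, written
# against a plain `pdims` tuple instead of a jax.sharding.Mesh.


def local_shape(mesh_shape, pdims):
    """Per-device shape when the first two axes are split over pdims."""
    return [
        mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1],
        *mesh_shape[2:]
    ]


def halo_widths(halo_size, pdims):
    """Padding per axis and exchange extents; an unsplit axis gets no halo."""
    halo_x = (0, 0) if pdims[0] == 1 else (halo_size, halo_size)
    halo_y = (0, 0) if pdims[1] == 1 else (halo_size, halo_size)
    ext_x = 0 if pdims[0] == 1 else halo_size // 2
    ext_y = 0 if pdims[1] == 1 else halo_size // 2
    return (halo_x, halo_y, (0, 0)), (ext_x, ext_y)


def flat_device_index(x_index, y_index, x_size):
    """Index into the split keys, as in normal_field: x varies fastest."""
    return x_index + y_index * x_size


# Slab decomposition: 4 devices along x, 1 along y.
pad, ext = halo_widths(16, pdims=(4, 1))
shape = local_shape((256, 256, 256), pdims=(4, 1))
```

With a slab decomposition only the split axis is padded, which is why `slice_pad` and `halo_exchange` can return their input unchanged when both leading halo widths are zero.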
|
|
@ -1,6 +1,6 @@
|
||||||
import jax.numpy as np
|
import jax.numpy as np
|
||||||
|
from jax.numpy import interp
|
||||||
from jax_cosmo.background import *
|
from jax_cosmo.background import *
|
||||||
from jax_cosmo.scipy.interpolate import interp
|
|
||||||
from jax_cosmo.scipy.ode import odeint
|
from jax_cosmo.scipy.ode import odeint
|
||||||
|
|
||||||
|
|
||||||
|
@ -587,5 +587,6 @@ def dGf2a(cosmo, a):
|
||||||
cache = cosmo._workspace['background.growth_factor']
|
cache = cosmo._workspace['background.growth_factor']
|
||||||
f2p = cache['h2'] / cache['a'] * cache['g2']
|
f2p = cache['h2'] / cache['a'] * cache['g2']
|
||||||
f2p = interp(np.log(a), np.log(cache['a']), f2p)
|
f2p = interp(np.log(a), np.log(cache['a']), f2p)
|
||||||
E = E(cosmo, a)
|
E_a = E(cosmo, a)
|
||||||
return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f)
|
return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) +
|
||||||
|
3 * a**2 * E_a * D2f)
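The rename from `E` to `E_a` in this hunk matters because assigning to a name anywhere in a Python function makes that name local for the whole body, so the old line `E = E(cosmo, a)` would raise `UnboundLocalError` before the module-level `E` could be called. A minimal sketch of the effect, using a hypothetical one-argument stand-in for `E` (not the jax_cosmo function):

```python
def E(a):
    """Hypothetical stand-in for a module-level function named E."""
    return 2.0 * a


def shadowed(a):
    # `E` is assigned below, so Python treats it as a local variable
    # throughout this function; the call on the right-hand side fails.
    E = E(a)
    return E


def renamed(a):
    # A distinct local name leaves the module-level E reachable.
    E_a = E(a)
    return E_a


try:
    shadowed(1.0)
    raised = False
except UnboundLocalError:
    raised = True
```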
|
||||||
|
|
|
@ -1,30 +1,46 @@
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from jax.lax import FftType
|
||||||
|
from jax.sharding import PartitionSpec as P
|
||||||
|
from jaxdecomp import fftfreq3d, get_output_specs
|
||||||
|
|
||||||
|
from jaxpm.distributed import autoshmap
|
||||||
|
|
||||||
|
|
||||||
def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
|
def fftk(k_array):
|
||||||
"""
|
"""
|
||||||
Return wave-vectors for a given shape
|
Generate Fourier transform wave numbers for a given mesh.
|
||||||
"""
|
|
||||||
k = []
|
|
||||||
for d in range(len(shape)):
|
|
||||||
kd = np.fft.fftfreq(shape[d])
|
|
||||||
kd *= 2 * np.pi
|
|
||||||
kdshape = np.ones(len(shape), dtype='int')
|
|
||||||
if symmetric and d == len(shape) - 1:
|
|
||||||
kd = kd[:shape[d] // 2 + 1]
|
|
||||||
kdshape[d] = len(kd)
|
|
||||||
kd = kd.reshape(kdshape)
|
|
||||||
|
|
||||||
k.append(kd.astype(dtype))
|
Args:
|
||||||
del kd, kdshape
|
k_array (jax.Array): Fourier-space array whose shape defines the mesh grid.
|
||||||
return k
|
|
||||||
|
Returns:
|
||||||
|
list: List of wave number arrays for each dimension in
|
||||||
|
the order [kx, ky, kz].
|
||||||
|
"""
|
||||||
|
kx, ky, kz = fftfreq3d(k_array)
|
||||||
|
# to the order of dimensions in the transposed FFT
|
||||||
|
return kx, ky, kz
|
||||||
|
|
||||||
|
|
||||||
|
def interpolate_power_spectrum(input, k, pk, sharding=None):
|
||||||
|
|
||||||
|
pk_fn = lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape)
|
||||||
|
|
||||||
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
|
specs = sharding.spec if sharding is not None else P()
|
||||||
|
out_specs = P(*get_output_specs(
|
||||||
|
FftType.FFT, specs, mesh=gpu_mesh)) if gpu_mesh is not None else P()
|
||||||
|
|
||||||
|
return autoshmap(pk_fn,
|
||||||
|
gpu_mesh=gpu_mesh,
|
||||||
|
in_specs=out_specs,
|
||||||
|
out_specs=out_specs)(input)
|
||||||
|
|
||||||
|
|
||||||
def gradient_kernel(kvec, direction, order=1):
|
def gradient_kernel(kvec, direction, order=1):
|
||||||
"""
|
"""
|
||||||
Computes the gradient kernel in the requested direction
|
Computes the gradient kernel in the requested direction
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
-----------
|
-----------
|
||||||
kvec: list
|
kvec: list
|
||||||
|
@ -50,20 +66,27 @@ def gradient_kernel(kvec, direction, order=1):
|
||||||
return wts
|
return wts
|
||||||
|
|
||||||
|
|
||||||
def invlaplace_kernel(kvec):
|
def invlaplace_kernel(kvec, fd=False):
|
||||||
"""
|
"""
|
||||||
Compute the inverse Laplace kernel
|
Compute the inverse Laplace kernel.
|
||||||
|
|
||||||
|
cf. [Feng+2016](https://arxiv.org/pdf/1603.00476)
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
-----------
|
-----------
|
||||||
kvec: list
|
kvec: list
|
||||||
List of wave-vectors
|
List of wave-vectors
|
||||||
|
fd: bool
|
||||||
|
Finite difference kernel
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
--------
|
--------
|
||||||
wts: array
|
wts: array
|
||||||
Complex kernel values
|
Complex kernel values
|
||||||
"""
|
"""
|
||||||
|
if fd:
|
||||||
|
kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
|
||||||
|
else:
|
||||||
kk = sum(ki**2 for ki in kvec)
|
kk = sum(ki**2 for ki in kvec)
|
||||||
kk_nozeros = jnp.where(kk == 0, 1, kk)
|
kk_nozeros = jnp.where(kk == 0, 1, kk)
|
||||||
return -jnp.where(kk == 0, 0, 1 / kk_nozeros)
|
return -jnp.where(kk == 0, 0, 1 / kk_nozeros)
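A short check of what the `fd=True` branch computes. Because `jnp.sinc(x)` is the normalized sinc, sin(πx)/(πx), the factor `ki * jnp.sinc(ki / (2 * jnp.pi))` equals 2 sin(ki/2), and its square, 4 sin²(ki/2), is (up to sign) the eigenvalue of the three-point finite-difference Laplacian on a unit-spaced grid. The sketch below verifies this with the standard library only; `sinc`, `fd_factor` and `stencil_eigenvalue` are illustrative helpers, not part of this commit:

```python
import math


def sinc(x):
    """Normalized sinc, sin(pi x) / (pi x), the numpy / jax.numpy convention."""
    return 1.0 if x == 0 else math.sin(math.pi * x) / (math.pi * x)


def fd_factor(k):
    """Per-axis factor used in the fd branch: k * sinc(k / 2pi)."""
    return k * sinc(k / (2 * math.pi))


def stencil_eigenvalue(k):
    """Apply the stencil f(x+1) - 2 f(x) + f(x-1) to cos(k x) at x = 0."""
    return math.cos(k) - 2.0 + math.cos(-k)


k = 2 * math.pi * 3 / 16  # a sample wave number on a 16-cell grid
lhs = fd_factor(k)**2  # what the fd branch sums over axes
rhs = -stencil_eigenvalue(k)  # minus the discrete Laplacian eigenvalue
```

For small `k` both expressions approach `k**2`, so the two branches of `invlaplace_kernel` agree on large scales and differ mainly near the Nyquist frequency.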
|
||||||
|
@ -79,12 +102,10 @@ def longrange_kernel(kvec, r_split):
|
||||||
List of wave-vectors
|
List of wave-vectors
|
||||||
r_split: float
|
r_split: float
|
||||||
Splitting radius
|
Splitting radius
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
--------
|
--------
|
||||||
wts: array
|
wts: array
|
||||||
Complex kernel values
|
Complex kernel values
|
||||||
|
|
||||||
TODO: @modichirag add documentation
|
TODO: @modichirag add documentation
|
||||||
"""
|
"""
|
||||||
if r_split != 0:
|
if r_split != 0:
|
||||||
|
@ -105,13 +126,12 @@ def cic_compensation(kvec):
|
||||||
-----------
|
-----------
|
||||||
kvec: list
|
kvec: list
|
||||||
List of wave-vectors
|
List of wave-vectors
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
--------
|
--------
|
||||||
wts: array
|
wts: array
|
||||||
Complex kernel values
|
Complex kernel values
|
||||||
"""
|
"""
|
||||||
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
|
kwts = [jnp.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
|
||||||
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
|
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
|
||||||
return wts
|
return wts
|
||||||
|
|
||||||
|
|
|
@ -1,15 +1,24 @@
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.lax as lax
|
import jax.lax as lax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
from jax.sharding import NamedSharding
|
||||||
|
from jax.sharding import PartitionSpec as P
|
||||||
|
|
||||||
|
from jaxpm.distributed import (autoshmap, fft3d, get_halo_size, halo_exchange,
|
||||||
|
ifft3d, slice_pad, slice_unpad)
|
||||||
from jaxpm.kernels import cic_compensation, fftk
|
from jaxpm.kernels import cic_compensation, fftk
|
||||||
|
from jaxpm.painting_utils import gather, scatter
|
||||||
|
|
||||||
|
|
||||||
def cic_paint(mesh, positions, weight=None):
|
def _cic_paint_impl(grid_mesh, positions, weight=None):
|
||||||
""" Paints positions onto mesh
|
""" Paints positions onto mesh
|
||||||
mesh: [nx, ny, nz]
|
mesh: [nx, ny, nz]
|
||||||
positions: [npart, 3]
|
displacement field: [nx, ny, nz, 3]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
positions = positions.reshape([-1, 3])
|
||||||
positions = jnp.expand_dims(positions, 1)
|
positions = jnp.expand_dims(positions, 1)
|
||||||
floor = jnp.floor(positions)
|
floor = jnp.floor(positions)
|
||||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||||
|
@ -19,40 +28,98 @@ def cic_paint(mesh, positions, weight=None):
|
||||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||||
if weight is not None:
|
if weight is not None:
|
||||||
|
if jnp.isscalar(weight):
|
||||||
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
|
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
|
||||||
|
else:
|
||||||
|
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
|
||||||
|
kernel)
|
||||||
|
|
||||||
neighboor_coords = jnp.mod(
|
neighboor_coords = jnp.mod(
|
||||||
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
||||||
jnp.array(mesh.shape))
|
jnp.array(grid_mesh.shape))
|
||||||
|
|
||||||
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
|
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
|
||||||
inserted_window_dims=(0, 1, 2),
|
inserted_window_dims=(0, 1, 2),
|
||||||
scatter_dims_to_operand_dims=(0, 1,
|
scatter_dims_to_operand_dims=(0, 1,
|
||||||
2))
|
2))
|
||||||
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 8]),
|
mesh = lax.scatter_add(grid_mesh, neighboor_coords,
|
||||||
dnums)
|
kernel.reshape([-1, 8]), dnums)
|
||||||
return mesh
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
def cic_read(mesh, positions):
|
@partial(jax.jit, static_argnums=(3, 4))
|
||||||
|
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
|
||||||
|
|
||||||
|
positions = positions.reshape((*grid_mesh.shape, 3))
|
||||||
|
|
||||||
|
halo_size, halo_extents = get_halo_size(halo_size, sharding)
|
||||||
|
grid_mesh = slice_pad(grid_mesh, halo_size, sharding)
|
||||||
|
|
||||||
|
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
||||||
|
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
||||||
|
grid_mesh = autoshmap(_cic_paint_impl,
|
||||||
|
gpu_mesh=gpu_mesh,
|
||||||
|
in_specs=(spec, spec, P()),
|
||||||
|
out_specs=spec)(grid_mesh, positions, weight)
|
||||||
|
grid_mesh = halo_exchange(grid_mesh,
|
||||||
|
halo_extents=halo_extents,
|
||||||
|
halo_periods=(True, True))
|
||||||
|
grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)
|
||||||
|
|
||||||
|
return grid_mesh
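Reviewer note: the pad, paint, exchange, unpad sequence in `cic_paint` can be illustrated in 1D. This is only a conceptual NumPy sketch under the assumption that mass deposited in a slab's halo is added to the neighbouring slab's edge cells with periodic wrap-around; it does not call jaxDecomp, and `paint_1d`, `paint_with_halo` and `paint_periodic` are hypothetical helpers.

```python
import numpy as np

def paint_1d(mesh, x):
    """CIC deposit onto a 1D array without wrapping (x must stay inside)."""
    i = np.floor(x).astype(int)
    w = x - i
    np.add.at(mesh, i, 1 - w)
    np.add.at(mesh, i + 1, w)
    return mesh

def paint_with_halo(x_global, n=8, nslabs=2, halo=1):
    width = n // nslabs
    padded = []
    for s in range(nslabs):
        # Particles owned by slab s, in slab-local coordinates
        mask = (x_global >= s * width) & (x_global < (s + 1) * width)
        local = x_global[mask] - s * width
        padded.append(paint_1d(np.zeros(width + 2 * halo), local + halo))
    out = []
    for s in range(nslabs):
        interior = padded[s][halo:-halo].copy()
        left, right = padded[(s - 1) % nslabs], padded[(s + 1) % nslabs]
        interior[:halo] += left[-halo:]   # left neighbour's right halo
        interior[-halo:] += right[:halo]  # right neighbour's left halo
        out.append(interior)
    return np.concatenate(out)

def paint_periodic(x, n):
    """Reference: CIC deposit on the full periodic mesh."""
    mesh = np.zeros(n)
    i = np.floor(x).astype(int)
    w = x - i
    np.add.at(mesh, i % n, 1 - w)
    np.add.at(mesh, (i + 1) % n, w)
    return mesh

x = np.array([0.25, 3.5, 3.9, 7.75])
halo_result = paint_with_halo(x)
reference = paint_periodic(x, 8)
```

With a halo of one cell, the slab-wise result matches the global periodic deposit, which is the property the halo exchange is there to preserve.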
|
||||||
|
|
||||||
|
|
||||||
|
def _cic_read_impl(grid_mesh, positions):
|
||||||
""" Reads mesh values at the given positions (CIC interpolation)
|
""" Reads mesh values at the given positions (CIC interpolation)
|
||||||
mesh: [nx, ny, nz]
|
mesh: [nx, ny, nz]
|
||||||
positions: [npart, 3]
|
positions: [nx,ny,nz, 3]
|
||||||
"""
|
"""
|
||||||
|
# Save original shape for reshaping output later
|
||||||
|
original_shape = positions.shape
|
||||||
|
# Reshape positions to a flat list of 3D coordinates
|
||||||
|
positions = positions.reshape([-1, 3])
|
||||||
|
# Expand dimensions to calculate neighbor coordinates
|
||||||
positions = jnp.expand_dims(positions, 1)
|
positions = jnp.expand_dims(positions, 1)
|
||||||
|
# Floor the positions to get the base grid cell for each particle
|
||||||
floor = jnp.floor(positions)
|
floor = jnp.floor(positions)
|
||||||
|
# Define connections to calculate all neighbor coordinates
|
||||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||||
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
||||||
|
# Calculate the 8 neighboring coordinates
|
||||||
neighboor_coords = floor + connection
|
neighboor_coords = floor + connection
|
||||||
|
# Calculate kernel weights based on distance from each neighboring coordinate
|
||||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||||
|
# Modulo operation to wrap around edges if necessary
|
||||||
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
|
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
|
||||||
jnp.array(mesh.shape))
|
jnp.array(grid_mesh.shape))
|
||||||
|
# Ensure grid_mesh shape is as expected
|
||||||
|
# Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel
|
||||||
|
return (grid_mesh[neighboor_coords[..., 0],
|
||||||
|
neighboor_coords[..., 1],
|
||||||
|
neighboor_coords[..., 2]] * kernel).sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable
|
||||||
|
|
||||||
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
|
|
||||||
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
|
@partial(jax.jit, static_argnums=(2, 3))
|
||||||
|
def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
|
||||||
|
|
||||||
|
original_shape = positions.shape
|
||||||
|
positions = positions.reshape((*grid_mesh.shape, 3))
|
||||||
|
|
||||||
|
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
|
||||||
|
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
|
||||||
|
grid_mesh = halo_exchange(grid_mesh,
|
||||||
|
halo_extents=halo_extents,
|
||||||
|
halo_periods=(True, True))
|
||||||
|
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
||||||
|
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
||||||
|
|
||||||
|
displacement = autoshmap(_cic_read_impl,
|
||||||
|
gpu_mesh=gpu_mesh,
|
||||||
|
in_specs=(spec, spec),
|
||||||
|
out_specs=spec)(grid_mesh, positions)
|
||||||
|
|
||||||
|
return displacement.reshape(original_shape[:-1])
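Reviewer note: for readers following `_cic_paint_impl` and `_cic_read_impl`, here is a single-device NumPy sketch of the same cloud-in-cell weights. It assumes positions are given in mesh-cell units on a periodic grid; the function names are hypothetical and the code is not the jaxpm implementation.

```python
import numpy as np

# Offsets from a particle's base cell to the 8 corners of that cell
CORNERS = np.array([[i, j, k] for i in (0, 1) for j in (0, 1) for k in (0, 1)])

def _neighbors_and_weights(positions, shape):
    floor = np.floor(positions)[:, None, :]           # [npart, 1, 3]
    neighbors = floor + CORNERS[None]                 # [npart, 8, 3]
    # Trilinear weight: product of (1 - distance) along each axis
    kernel = 1.0 - np.abs(positions[:, None, :] - neighbors)
    kernel = kernel.prod(axis=-1)                     # [npart, 8]
    # Wrap neighbour indices around the periodic box
    idx = np.mod(neighbors.astype(np.int64), np.array(shape))
    return idx, kernel

def cic_paint_np(mesh, positions):
    idx, kernel = _neighbors_and_weights(positions, mesh.shape)
    out = mesh.copy()
    np.add.at(out, (idx[..., 0], idx[..., 1], idx[..., 2]), kernel)
    return out

def cic_read_np(mesh, positions):
    idx, kernel = _neighbors_and_weights(positions, mesh.shape)
    return (mesh[idx[..., 0], idx[..., 1], idx[..., 2]] * kernel).sum(axis=-1)

rng = np.random.default_rng(0)
positions = rng.uniform(0, 8, size=(100, 3))
painted = cic_paint_np(np.zeros((8, 8, 8)), positions)
read_back = cic_read_np(np.ones((8, 8, 8)), positions)
```

Because the 8 weights of each particle sum to one, painting conserves total mass and reading from a uniform mesh returns the uniform value.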
|
||||||
|
|
||||||
|
|
||||||
def cic_paint_2d(mesh, positions, weight):
|
def cic_paint_2d(mesh, positions, weight):
|
||||||
|
@ -84,6 +151,99 @@ def cic_paint_2d(mesh, positions, weight):
|
||||||
return mesh
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
|
def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
|
||||||
|
|
||||||
|
halo_x, _ = halo_size[0]
|
||||||
|
halo_y, _ = halo_size[1]
|
||||||
|
|
||||||
|
original_shape = displacements.shape
|
||||||
|
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
|
||||||
|
if not jnp.isscalar(weight):
|
||||||
|
if weight.shape != original_shape[:-1]:
|
||||||
|
raise ValueError("Weight shape must match particle shape")
|
||||||
|
else:
|
||||||
|
weight = weight.flatten()
|
||||||
|
# Padding is forced to be zero in a single gpu run
|
||||||
|
|
||||||
|
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
|
||||||
|
jnp.arange(particle_mesh.shape[1]),
|
||||||
|
jnp.arange(particle_mesh.shape[2]),
|
||||||
|
indexing='ij')
|
||||||
|
|
||||||
|
particle_mesh = jnp.pad(particle_mesh, halo_size)
|
||||||
|
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
|
||||||
|
return scatter(pmid.reshape([-1, 3]),
|
||||||
|
displacements.reshape([-1, 3]),
|
||||||
|
particle_mesh,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
val=weight)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnums=(1, 2, 4))
|
||||||
|
def cic_paint_dx(displacements,
|
||||||
|
halo_size=0,
|
||||||
|
sharding=None,
|
||||||
|
weight=1.0,
|
||||||
|
chunk_size=2**24):
|
||||||
|
|
||||||
|
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
|
||||||
|
|
||||||
|
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
||||||
|
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
||||||
|
grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
|
||||||
|
halo_size=halo_size,
|
||||||
|
weight=weight,
|
||||||
|
chunk_size=chunk_size),
|
||||||
|
gpu_mesh=gpu_mesh,
|
||||||
|
in_specs=spec,
|
||||||
|
out_specs=spec)(displacements)
|
||||||
|
|
||||||
|
grid_mesh = halo_exchange(grid_mesh,
|
||||||
|
halo_extents=halo_extents,
|
||||||
|
halo_periods=(True, True))
|
||||||
|
grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)
|
||||||
|
return grid_mesh
|
||||||
|
|
||||||
|
|
||||||
|
def _cic_read_dx_impl(grid_mesh, disp, halo_size):
|
||||||
|
|
||||||
|
halo_x, _ = halo_size[0]
|
||||||
|
halo_y, _ = halo_size[1]
|
||||||
|
|
||||||
|
original_shape = [
|
||||||
|
dim - 2 * halo[0] for dim, halo in zip(grid_mesh.shape, halo_size)
|
||||||
|
]
|
||||||
|
a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]),
|
||||||
|
jnp.arange(original_shape[1]),
|
||||||
|
jnp.arange(original_shape[2]),
|
||||||
|
indexing='ij')
|
||||||
|
|
||||||
|
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
|
||||||
|
|
||||||
|
pmid = pmid.reshape([-1, 3])
|
||||||
|
disp = disp.reshape([-1, 3])
|
||||||
|
|
||||||
|
return gather(pmid, disp, grid_mesh).reshape(original_shape)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnums=(2, 3))
|
||||||
|
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
|
||||||
|
|
||||||
|
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
|
||||||
|
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
|
||||||
|
grid_mesh = halo_exchange(grid_mesh,
|
||||||
|
halo_extents=halo_extents,
|
||||||
|
halo_periods=(True, True))
|
||||||
|
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
||||||
|
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
||||||
|
displacements = autoshmap(partial(_cic_read_dx_impl, halo_size=halo_size),
|
||||||
|
gpu_mesh=gpu_mesh,
|
||||||
|
in_specs=(spec, spec),
|
||||||
|
out_specs=spec)(grid_mesh, disp)
|
||||||
|
|
||||||
|
return displacements
|
||||||
|
|
||||||
|
|
||||||
def compensate_cic(field):
|
def compensate_cic(field):
|
||||||
"""
|
"""
|
||||||
Compensate for CiC painting
|
Compensate for CiC painting
|
||||||
|
@ -92,9 +252,8 @@ def compensate_cic(field):
|
||||||
Returns:
|
Returns:
|
||||||
compensated_field
|
compensated_field
|
||||||
"""
|
"""
|
||||||
nc = field.shape
|
delta_k = fft3d(field)
|
||||||
kvec = fftk(nc)
|
|
||||||
|
|
||||||
delta_k = jnp.fft.rfftn(field)
|
kvec = fftk(delta_k)
|
||||||
delta_k = cic_compensation(kvec) * delta_k
|
delta_k = cic_compensation(kvec) * delta_k
|
||||||
return jnp.fft.irfftn(delta_k)
|
return ifft3d(delta_k)
|
||||||
|
|
190
jaxpm/painting_utils.py
Normal file
|
@ -0,0 +1,190 @@
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax.lax import scan
|
||||||
|
|
||||||
|
|
||||||
|
def _chunk_split(ptcl_num, chunk_size, *arrays):
|
||||||
|
"""Split and reshape particle arrays into chunks and remainders, with the remainders
|
||||||
|
preceding the chunks. 0D ones are duplicated as full arrays in the chunks."""
|
||||||
|
chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
|
||||||
|
remainder_size = ptcl_num % chunk_size
|
||||||
|
chunk_num = ptcl_num // chunk_size
|
||||||
|
|
||||||
|
remainder = None
|
||||||
|
chunks = arrays
|
||||||
|
if remainder_size:
|
||||||
|
remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays]
|
||||||
|
chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays]
|
||||||
|
|
||||||
|
# `scan` triggers errors in scatter and gather without the `full`
|
||||||
|
chunks = [
|
||||||
|
x.reshape(chunk_num, chunk_size, *x.shape[1:])
|
||||||
|
if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks
|
||||||
|
]
|
||||||
|
|
||||||
|
return remainder, chunks
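Reviewer note: the remainder-first layout produced by `_chunk_split` is easy to misread, so here is a plain-Python sketch of the index arithmetic only. It works on particle indices rather than arrays, and `chunk_split_indices` is a hypothetical name.

```python
def chunk_split_indices(ptcl_num, chunk_size):
    """Return (remainder, chunks) of particle indices, remainder first."""
    chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
    remainder_size = ptcl_num % chunk_size
    chunk_num = ptcl_num // chunk_size

    # The first `remainder_size` particles are handled outside the scan
    remainder = list(range(remainder_size))
    # The rest is split into `chunk_num` equally sized chunks
    chunks = [
        list(range(remainder_size + c * chunk_size,
                   remainder_size + (c + 1) * chunk_size))
        for c in range(chunk_num)
    ]
    return remainder, chunks

remainder, chunks = chunk_split_indices(10, 4)
single_remainder, single_chunk = chunk_split_indices(5, None)
```

Equal chunk sizes are what allow the chunks to be stacked into one array and iterated with `scan`; the leftover particles are processed once, separately.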
|
||||||
|
|
||||||
|
|
||||||
|
def enmesh(base_indices, displacements, cell_size, base_shape, offset,
|
||||||
|
new_cell_size, new_shape):
|
||||||
|
"""Multilinear enmeshing."""
|
||||||
|
base_indices = jnp.asarray(base_indices)
|
||||||
|
displacements = jnp.asarray(displacements)
|
||||||
|
with jax.experimental.enable_x64():
|
||||||
|
cell_size = jnp.float64(
|
||||||
|
cell_size) if new_cell_size is not None else jnp.array(
|
||||||
|
cell_size, dtype=displacements.dtype)
|
||||||
|
if base_shape is not None:
|
||||||
|
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
|
||||||
|
offset = jnp.float64(offset)
|
||||||
|
if new_cell_size is not None:
|
||||||
|
new_cell_size = jnp.float64(new_cell_size)
|
||||||
|
if new_shape is not None:
|
||||||
|
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
|
||||||
|
|
||||||
|
spatial_dim = base_indices.shape[1]
|
||||||
|
neighbor_offsets = (
|
||||||
|
jnp.arange(2**spatial_dim, dtype=base_indices.dtype)[:, jnp.newaxis] >>
|
||||||
|
jnp.arange(spatial_dim, dtype=base_indices.dtype)) & 1
|
||||||
|
|
||||||
|
if new_cell_size is not None:
|
||||||
|
particle_positions = base_indices * cell_size + displacements - offset
|
||||||
|
particle_positions = particle_positions[:, jnp.
|
||||||
|
newaxis] # insert neighbor axis
|
||||||
|
new_indices = particle_positions + neighbor_offsets * new_cell_size # multilinear
|
||||||
|
|
||||||
|
if base_shape is not None:
|
||||||
|
grid_length = base_shape * cell_size
|
||||||
|
new_indices %= grid_length
|
||||||
|
|
||||||
|
new_indices //= new_cell_size
|
||||||
|
new_displacements = particle_positions - new_indices * new_cell_size
|
||||||
|
|
||||||
|
if base_shape is not None:
|
||||||
|
new_displacements -= jnp.rint(
|
||||||
|
new_displacements / grid_length
|
||||||
|
) * grid_length # also abs(new_displacements) < new_cell_size is expected
|
||||||
|
|
||||||
|
new_indices = new_indices.astype(base_indices.dtype)
|
||||||
|
new_displacements = new_displacements.astype(displacements.dtype)
|
||||||
|
new_cell_size = new_cell_size.astype(displacements.dtype)
|
||||||
|
|
||||||
|
new_displacements /= new_cell_size
|
||||||
|
else:
|
||||||
|
offset_indices, offset_displacements = jnp.divmod(offset, cell_size)
|
||||||
|
base_indices -= offset_indices.astype(base_indices.dtype)
|
||||||
|
displacements -= offset_displacements.astype(displacements.dtype)
|
||||||
|
|
||||||
|
# insert neighbor axis
|
||||||
|
base_indices = base_indices[:, jnp.newaxis]
|
||||||
|
displacements = displacements[:, jnp.newaxis]
|
||||||
|
|
||||||
|
# multilinear
|
||||||
|
displacements /= cell_size
|
||||||
|
new_indices = jnp.floor(displacements).astype(base_indices.dtype)
|
||||||
|
new_indices += neighbor_offsets
|
||||||
|
new_displacements = displacements - new_indices
|
||||||
|
new_indices += base_indices
|
||||||
|
|
||||||
|
if base_shape is not None:
|
||||||
|
new_indices %= base_shape
|
||||||
|
|
||||||
|
weights = 1 - jnp.abs(new_displacements)
|
||||||
|
|
||||||
|
if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None
|
||||||
|
new_indices = jnp.where(new_indices < 0, new_shape, new_indices)
|
||||||
|
|
||||||
|
weights = weights.prod(axis=-1)
|
||||||
|
|
||||||
|
return new_indices, weights
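Reviewer note: the `neighbor_offsets` expression in `enmesh` enumerates the corners of a cell by reading the bits of an integer. A small NumPy sketch of that trick and of the resulting multilinear weights, for a single particle with an illustrative displacement chosen here:

```python
import numpy as np

spatial_dim = 3
# Bit d of the integer n selects offset 0 or 1 along axis d
neighbor_offsets = (
    np.arange(2**spatial_dim)[:, np.newaxis] >> np.arange(spatial_dim)) & 1

# Fractional displacement of one particle inside its cell (illustrative)
disp = np.array([0.25, 0.5, 0.0])
weights = (1 - np.abs(disp - neighbor_offsets)).prod(axis=-1)
```

Each row of `neighbor_offsets` is one of the 8 cell corners, and the weights over those corners sum to one.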
|
||||||
|
|
||||||
|
|
||||||
|
def _scatter_chunk(carry, chunk):
|
||||||
|
mesh, offset, cell_size, mesh_shape = carry
|
||||||
|
pmid, disp, val = chunk
|
||||||
|
spatial_ndim = pmid.shape[1]
|
||||||
|
spatial_shape = mesh.shape
|
||||||
|
|
||||||
|
# multilinear mesh indices and fractions
|
||||||
|
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
|
||||||
|
spatial_shape)
|
||||||
|
# scatter
|
||||||
|
ind = tuple(ind[..., i] for i in range(spatial_ndim))
|
||||||
|
mesh = mesh.at[ind].add(jnp.multiply(jnp.expand_dims(val, axis=-1), frac))
|
||||||
|
carry = mesh, offset, cell_size, mesh_shape
|
||||||
|
return carry, None
|
||||||
|
|
||||||
|
|
||||||
|
def scatter(pmid,
|
||||||
|
disp,
|
||||||
|
mesh,
|
||||||
|
chunk_size=2**24,
|
||||||
|
val=1.,
|
||||||
|
offset=0,
|
||||||
|
cell_size=1.):
|
||||||
|
ptcl_num, spatial_ndim = pmid.shape
|
||||||
|
val = jnp.asarray(val)
|
||||||
|
mesh = jnp.asarray(mesh)
|
||||||
|
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
|
||||||
|
carry = mesh, offset, cell_size, mesh.shape
|
||||||
|
if remainder is not None:
|
||||||
|
carry = _scatter_chunk(carry, remainder)[0]
|
||||||
|
carry = scan(_scatter_chunk, carry, chunks)[0]
|
||||||
|
mesh = carry[0]
|
||||||
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
|
def _chunk_cat(remainder_array, chunked_array):
|
||||||
|
"""Reshape a chunked particle array and concatenate it after the remainder array, if any."""
|
||||||
|
array = chunked_array.reshape(-1, *chunked_array.shape[2:])
|
||||||
|
|
||||||
|
if remainder_array is not None:
|
||||||
|
array = jnp.concatenate((remainder_array, array), axis=0)
|
||||||
|
|
||||||
|
return array
|
||||||
|
|
||||||
|
|
||||||
|
def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
|
||||||
|
ptcl_num, spatial_ndim = pmid.shape
|
||||||
|
|
||||||
|
mesh = jnp.asarray(mesh)
|
||||||
|
|
||||||
|
val = jnp.asarray(val)
|
||||||
|
|
||||||
|
if mesh.shape[spatial_ndim:] != val.shape[1:]:
|
||||||
|
raise ValueError('channel shape mismatch: '
|
||||||
|
f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}')
|
||||||
|
|
||||||
|
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
|
||||||
|
|
||||||
|
carry = mesh, offset, cell_size, mesh.shape
|
||||||
|
val_0 = None
|
||||||
|
if remainder is not None:
|
||||||
|
val_0 = _gather_chunk(carry, remainder)[1]
|
||||||
|
val = scan(_gather_chunk, carry, chunks)[1]
|
||||||
|
|
||||||
|
val = _chunk_cat(val_0, val)
|
||||||
|
|
||||||
|
return val
|
||||||
|
|
||||||
|
|
||||||
|
def _gather_chunk(carry, chunk):
|
||||||
|
mesh, offset, cell_size, mesh_shape = carry
|
||||||
|
pmid, disp, val = chunk
|
||||||
|
|
||||||
|
spatial_ndim = pmid.shape[1]
|
||||||
|
|
||||||
|
spatial_shape = mesh.shape[:spatial_ndim]
|
||||||
|
chan_ndim = mesh.ndim - spatial_ndim
|
||||||
|
chan_axis = tuple(range(-chan_ndim, 0))
|
||||||
|
|
||||||
|
# multilinear mesh indices and fractions
|
||||||
|
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
|
||||||
|
spatial_shape)
|
||||||
|
|
||||||
|
# gather
|
||||||
|
ind = tuple(ind[..., i] for i in range(spatial_ndim))
|
||||||
|
frac = jnp.expand_dims(frac, chan_axis)
|
||||||
|
val += (mesh.at[ind].get(mode='drop', fill_value=0) * frac).sum(axis=1)
|
||||||
|
|
||||||
|
return carry, val
|
129
jaxpm/plotting.py
Normal file
|
@ -0,0 +1,129 @@
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def plot_fields(fields_dict, sum_over=None):
|
||||||
|
"""
|
||||||
|
Plots sum projections of 3D fields along different axes,
|
||||||
|
slicing only the first `sum_over` elements along each axis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- fields_dict: dictionary mapping field names (used in titles) to the 3D arrays to plot
|
||||||
|
- sum_over: number of slices to sum along each axis (default: fields[0].shape[0] // 8)
|
||||||
|
"""
|
||||||
|
sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
|
||||||
|
nb_rows = len(fields_dict)
|
||||||
|
nb_cols = 3
|
||||||
|
fig, axes = plt.subplots(nb_rows, nb_cols, figsize=(15, 5 * nb_rows))
|
||||||
|
|
||||||
|
def plot_subplots(proj_axis, field, row, title):
|
||||||
|
slicing = [slice(None)] * field.ndim
|
||||||
|
slicing[proj_axis] = slice(None, sum_over)
|
||||||
|
slicing = tuple(slicing)
|
||||||
|
|
||||||
|
# Sum projection over the specified axis and plot
|
||||||
|
axes[row, proj_axis].imshow(
|
||||||
|
field[slicing].sum(axis=proj_axis) + 1,
|
||||||
|
cmap='magma',
|
||||||
|
extent=[0, field.shape[proj_axis], 0, field.shape[proj_axis]])
|
||||||
|
axes[row, proj_axis].set_xlabel('Mpc/h')
|
||||||
|
axes[row, proj_axis].set_ylabel('Mpc/h')
|
||||||
|
axes[row, proj_axis].set_title(title)
|
||||||
|
|
||||||
|
# Plot each field across the three axes
|
||||||
|
for i, (name, field) in enumerate(fields_dict.items()):
|
||||||
|
for proj_axis in range(3):
|
||||||
|
plot_subplots(proj_axis, field, i,
|
||||||
|
f"{name} projection {proj_axis}")
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def plot_fields_single_projection(fields_dict,
|
||||||
|
sum_over=None,
|
||||||
|
project_axis=0,
|
||||||
|
vmin=None,
|
||||||
|
vmax=None,
|
||||||
|
colorbar=False):
|
||||||
|
"""
|
||||||
|
Plots a single projection (along `project_axis`) of 3D fields in a grid,
|
||||||
|
summing over the first `sum_over` elements along that axis, with 4 images per row.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- fields_dict: dictionary where keys are field names and values are 3D arrays
|
||||||
|
- sum_over: number of slices to sum along the projection axis (default: fields[0].shape[0] // 8)
|
||||||
|
"""
|
||||||
|
sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
|
||||||
|
nb_fields = len(fields_dict)
|
||||||
|
nb_cols = 4 # Set number of images per row
|
||||||
|
nb_rows = (nb_fields + nb_cols - 1) // nb_cols # Calculate required rows
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(nb_rows,
|
||||||
|
nb_cols,
|
||||||
|
figsize=(5 * nb_cols, 5 * nb_rows))
|
||||||
|
axes = np.atleast_2d(axes) # Ensure axes is always a 2D array
|
||||||
|
|
||||||
|
for i, (name, field) in enumerate(fields_dict.items()):
|
||||||
|
row, col = divmod(i, nb_cols)
|
||||||
|
|
||||||
|
# Define the slice along the projection axis
|
||||||
|
slicing = [slice(None)] * field.ndim
|
||||||
|
slicing[project_axis] = slice(None, sum_over)
|
||||||
|
slicing = tuple(slicing)
|
||||||
|
|
||||||
|
# Sum projection over project_axis and plot
|
||||||
|
a = axes[row,
|
||||||
|
col].imshow(field[slicing].sum(axis=project_axis) + 1,
|
||||||
|
cmap='magma',
|
||||||
|
extent=[0, field.shape[1], 0, field.shape[2]],
|
||||||
|
vmin=vmin,
|
||||||
|
vmax=vmax)
|
||||||
|
axes[row, col].set_xlabel('Mpc/h')
|
||||||
|
axes[row, col].set_ylabel('Mpc/h')
|
||||||
|
axes[row, col].set_title(f"{name} projection {project_axis}")
|
||||||
|
if colorbar:
|
||||||
|
fig.colorbar(a, ax=axes[row, col], shrink=0.7)
|
||||||
|
|
||||||
|
# Remove any empty subplots
|
||||||
|
for j in range(i + 1, nb_rows * nb_cols):
|
||||||
|
fig.delaxes(axes.flatten()[j])
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def stack_slices(array):
|
||||||
|
"""
|
||||||
|
Stacks the addressable 2D shards of a sharded array into a single array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- array: a sharded array whose shards are laid out on a 2D device grid;
|
||||||
|
the shard at row i, column j of the grid becomes block (i, j) of the output.
|
||||||
|
The grid dimensions (rows, columns) are read from array.sharding.mesh.devices.shape.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- A single array constructed by stacking the slices.
|
||||||
|
"""
|
||||||
|
# Initialize an empty list to store the vertically stacked rows
|
||||||
|
pdims = array.sharding.mesh.devices.shape
|
||||||
|
|
||||||
|
field_slices = []
|
||||||
|
|
||||||
|
# Iterate over rows in pdims[0]
|
||||||
|
for i in range(pdims[0]):
|
||||||
|
row_slices = []
|
||||||
|
|
||||||
|
# Iterate over columns in pdims[1]
|
||||||
|
for j in range(pdims[1]):
|
||||||
|
slice_index = i * pdims[1] + j
|
||||||
|
row_slices.append(array.addressable_data(slice_index))
|
||||||
|
# Stack the current row of slices horizontally
|
||||||
|
stacked_row = np.hstack(row_slices)
|
||||||
|
field_slices.append(stacked_row)
|
||||||
|
|
||||||
|
# Stack all rows vertically to form the full array
|
||||||
|
full_array = np.vstack(field_slices)
|
||||||
|
|
||||||
|
return full_array
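Reviewer note: a NumPy sketch of reassembling a grid of blocks from a flat list, assuming the list is ordered row-major over a `(rows, columns)` grid. Whether `array.addressable_data` returns shards in that order is an assumption not verified here, and `stack_blocks` is a hypothetical helper.

```python
import numpy as np

def stack_blocks(blocks, pdims):
    """Reassemble a row-major flat list of 2D blocks laid out on a pdims grid."""
    nrows, ncols = pdims
    rows = [
        # Block (i, j) sits at flat index i * ncols + j
        np.hstack([blocks[i * ncols + j] for j in range(ncols)])
        for i in range(nrows)
    ]
    return np.vstack(rows)

full = np.arange(4 * 6).reshape(4, 6)
pdims = (2, 3)
blocks = [
    full[2 * i:2 * (i + 1), 2 * j:2 * (j + 1)]
    for i in range(pdims[0]) for j in range(pdims[1])
]
restored = stack_blocks(blocks, pdims)
```

The non-square grid `(2, 3)` is deliberate: with a square grid, multiplying the row index by the number of rows or by the number of columns gives the same result, so an indexing mistake would go unnoticed.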
|
170
jaxpm/pm.py
|
@ -1,50 +1,92 @@
|
||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
from jax_cosmo import Cosmology
|
|
||||||
|
|
||||||
from jaxpm.growth import growth_factor, growth_rate, dGfa, growth_factor_second, growth_rate_second, dGf2a
|
from jaxpm.distributed import fft3d, ifft3d, normal_field
|
||||||
from jaxpm.kernels import PGD_kernel, fftk, gradient_kernel, invlaplace_kernel, longrange_kernel
|
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
|
||||||
from jaxpm.painting import cic_paint, cic_read
|
growth_rate, growth_rate_second)
|
||||||
|
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
|
||||||
|
invlaplace_kernel, longrange_kernel)
|
||||||
|
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
|
||||||
|
|
||||||
|
|
||||||
|
def pm_forces(positions,
|
||||||
def pm_forces(positions, mesh_shape, delta=None, r_split=0):
|
mesh_shape=None,
|
||||||
|
delta=None,
|
||||||
|
r_split=0,
|
||||||
|
paint_absolute_pos=True,
|
||||||
|
halo_size=0,
|
||||||
|
sharding=None):
|
||||||
"""
|
"""
|
||||||
Computes gravitational forces on particles using a PM scheme
|
Computes gravitational forces on particles using a PM scheme
|
||||||
"""
|
"""
|
||||||
|
if mesh_shape is None:
|
||||||
|
assert (delta is not None),\
|
||||||
|
"If mesh_shape is not provided, delta should be provided"
|
||||||
|
mesh_shape = delta.shape
|
||||||
|
|
||||||
|
if paint_absolute_pos:
|
||||||
|
paint_fn = lambda pos: cic_paint(jnp.zeros(shape=mesh_shape,
|
||||||
|
device=sharding),
|
||||||
|
pos,
|
||||||
|
halo_size=halo_size,
|
||||||
|
sharding=sharding)
|
||||||
|
read_fn = lambda grid_mesh, pos: cic_read(
|
||||||
|
grid_mesh, pos, halo_size=halo_size, sharding=sharding)
|
||||||
|
else:
|
||||||
|
paint_fn = lambda disp: cic_paint_dx(
|
||||||
|
disp, halo_size=halo_size, sharding=sharding)
|
||||||
|
read_fn = lambda grid_mesh, disp: cic_read_dx(
|
||||||
|
grid_mesh, disp, halo_size=halo_size, sharding=sharding)
|
||||||
|
|
||||||
if delta is None:
|
if delta is None:
|
||||||
delta_k = jnp.fft.rfftn(cic_paint(jnp.zeros(mesh_shape), positions))
|
field = paint_fn(positions)
|
||||||
|
delta_k = fft3d(field)
|
||||||
elif jnp.isrealobj(delta):
|
elif jnp.isrealobj(delta):
|
||||||
delta_k = jnp.fft.rfftn(delta)
|
delta_k = fft3d(delta)
|
||||||
else:
|
else:
|
||||||
delta_k = delta
|
delta_k = delta
|
||||||
|
|
||||||
|
kvec = fftk(delta_k)
|
||||||
# Computes gravitational potential
|
# Computes gravitational potential
|
||||||
kvec = fftk(mesh_shape)
|
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(
|
||||||
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split)
|
kvec, r_split=r_split)
|
||||||
# Computes gravitational forces
|
# Computes gravitational forces
|
||||||
return jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i) * pot_k), positions)
|
forces = jnp.stack([
|
||||||
for i in range(3)], axis=-1)
|
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions
|
||||||
|
) for i in range(3)], axis=-1) # yapf: disable
|
||||||
|
|
||||||
|
return forces
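Reviewer note: the Fourier-space steps of `pm_forces` (density to potential to force) can be checked on a single device with plain NumPy FFTs. This sketch assumes the sign convention that the potential solves the Poisson equation with the density as source and that the force is minus its gradient. It uses exact spectral derivatives, which may differ from jaxpm's `gradient_kernel`, and `pm_force_np` is a hypothetical name.

```python
import numpy as np

def pm_force_np(delta):
    """Force field from a density field on a periodic grid (cell units)."""
    shape = delta.shape
    delta_k = np.fft.fftn(delta)
    kvec = np.meshgrid(*[2 * np.pi * np.fft.fftfreq(m) for m in shape],
                       indexing="ij")
    k2 = sum(k**2 for k in kvec)
    k2[0, 0, 0] = 1.0            # avoid dividing by zero at the mean mode
    pot_k = -delta_k / k2        # inverse Laplacian
    pot_k[0, 0, 0] = 0.0         # the mean mode carries no force
    # force_i = -d(pot)/dx_i, i.e. multiply by -i k_i in Fourier space
    return np.stack([np.fft.ifftn(-1j * k * pot_k).real for k in kvec],
                    axis=-1)

n = 16
k0 = 2 * np.pi / n
x = np.arange(n)
delta = np.cos(k0 * x)[:, None, None] * np.ones((n, n, n))
forces = pm_force_np(delta)
expected_fx = (-np.sin(k0 * x) / k0)[:, None, None] * np.ones((n, n, n))
```

For a single cosine mode along x, the force is a sine along x and vanishes along the other two axes, which gives a quick analytic check of signs and normalisation.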
|
||||||
|
|
||||||
|
|
||||||
def lpt(cosmo:Cosmology, init_mesh, positions, a, order=1):
|
def lpt(cosmo,
|
||||||
|
initial_conditions,
|
||||||
|
particles=None,
|
||||||
|
a=0.1,
|
||||||
|
halo_size=0,
|
||||||
|
sharding=None,
|
||||||
|
order=1):
|
||||||
"""
|
"""
|
||||||
Computes first and second order LPT displacement and momentum,
|
Computes first and second order LPT displacement and momentum,
|
||||||
e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
|
e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
|
||||||
"""
|
"""
|
||||||
|
paint_absolute_pos = particles is not None
|
||||||
|
if particles is None:
|
||||||
|
particles = jnp.zeros_like(initial_conditions,
|
||||||
|
shape=(*initial_conditions.shape, 3))
|
||||||
|
|
||||||
a = jnp.atleast_1d(a)
|
a = jnp.atleast_1d(a)
|
||||||
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
||||||
delta_k = jnp.fft.rfftn(init_mesh) # TODO: pass the modes directly to save one or two fft?
|
delta_k = fft3d(initial_conditions)
|
||||||
mesh_shape = init_mesh.shape
|
initial_force = pm_forces(particles,
|
||||||
|
delta=delta_k,
|
||||||
init_force = pm_forces(positions, mesh_shape, delta=delta_k)
|
paint_absolute_pos=paint_absolute_pos,
|
||||||
dx = growth_factor(cosmo, a) * init_force
|
halo_size=halo_size,
|
||||||
|
sharding=sharding)
|
||||||
|
dx = growth_factor(cosmo, a) * initial_force
|
||||||
p = a**2 * growth_rate(cosmo, a) * E * dx
|
p = a**2 * growth_rate(cosmo, a) * E * dx
|
||||||
f = a**2 * E * dGfa(cosmo, a) * init_force
|
f = a**2 * E * dGfa(cosmo, a) * initial_force
|
||||||
|
|
||||||
if order == 2:
|
if order == 2:
|
||||||
kvec = fftk(mesh_shape)
|
kvec = fftk(delta_k)
|
||||||
pot_k = delta_k * invlaplace_kernel(kvec)
|
pot_k = delta_k * invlaplace_kernel(kvec)
|
||||||
|
|
||||||
delta2 = 0
|
delta2 = 0
|
||||||
|
@ -54,7 +96,7 @@ def lpt(cosmo:Cosmology, init_mesh, positions, a, order=1):
|
||||||
# Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
|
# Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
|
||||||
# shear_ii = jnp.fft.irfftn(- ki**2 * pot_k)
|
# shear_ii = jnp.fft.irfftn(- ki**2 * pot_k)
|
||||||
nabla_i_nabla_i = gradient_kernel(kvec, i)**2
|
nabla_i_nabla_i = gradient_kernel(kvec, i)**2
|
||||||
shear_ii = jnp.fft.irfftn(nabla_i_nabla_i * pot_k)
|
shear_ii = ifft3d(nabla_i_nabla_i * pot_k)
|
||||||
delta2 += shear_ii * shear_acc
|
delta2 += shear_ii * shear_acc
|
||||||
shear_acc += shear_ii
|
shear_acc += shear_ii
|
||||||
|
|
||||||
|
@ -62,10 +104,16 @@ def lpt(cosmo:Cosmology, init_mesh, positions, a, order=1):
|
||||||
for j in range(i + 1, 3):
|
for j in range(i + 1, 3):
|
||||||
# Subtract squared strict-up-triangle terms
|
# Subtract squared strict-up-triangle terms
|
||||||
# delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
|
# delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
|
||||||
nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(kvec, j)
|
nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(
|
||||||
delta2 -= jnp.fft.irfftn(nabla_i_nabla_j * pot_k)**2
|
kvec, j)
|
||||||
|
delta2 -= ifft3d(nabla_i_nabla_j * pot_k)**2
|
||||||
|
|
||||||
init_force2 = pm_forces(positions, mesh_shape, delta=jnp.fft.rfftn(delta2))
|
delta_k2 = fft3d(delta2)
|
||||||
|
init_force2 = pm_forces(particles,
|
||||||
|
delta=delta_k2,
|
||||||
|
paint_absolute_pos=paint_absolute_pos,
|
||||||
|
halo_size=halo_size,
|
||||||
|
sharding=sharding)
|
||||||
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
|
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
|
||||||
dx2 = 3 / 7 * growth_factor_second(cosmo, a) * init_force2
|
dx2 = 3 / 7 * growth_factor_second(cosmo, a) * init_force2
|
||||||
p2 = a**2 * growth_rate_second(cosmo, a) * E * dx2
|
p2 = a**2 * growth_rate_second(cosmo, a) * E * dx2
|
||||||
|
@ -78,23 +126,28 @@ def lpt(cosmo:Cosmology, init_mesh, positions, a, order=1):
|
||||||
return dx, p, f
|
return dx, p, f
|
||||||
|
|
||||||
|
|
||||||
def linear_field(mesh_shape, box_size, pk, seed):
|
def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
|
||||||
"""
|
"""
|
||||||
Generate initial conditions.
|
Generate initial conditions.
|
||||||
"""
|
"""
|
||||||
kvec = fftk(mesh_shape)
|
# Initialize a random field with one slice on each gpu
|
||||||
|
field = normal_field(mesh_shape, seed=seed, sharding=sharding)
|
||||||
|
field = fft3d(field)
|
||||||
|
kvec = fftk(field)
|
||||||
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
|
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
|
||||||
for i, kk in enumerate(kvec))**0.5
|
for i, kk in enumerate(kvec))**0.5
|
||||||
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
|
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
|
||||||
box_size[0] * box_size[1] * box_size[2])
|
box_size[0] * box_size[1] * box_size[2])
|
||||||
|
|
||||||
field = jax.random.normal(seed, mesh_shape)
|
field = field * (pkmesh)**0.5
|
||||||
field = jnp.fft.rfftn(field) * pkmesh**0.5
|
field = ifft3d(field)
|
||||||
field = jnp.fft.irfftn(field)
|
|
||||||
return field
|
return field
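Reviewer note: a single-device NumPy sketch of the recipe in `linear_field` (white noise, forward FFT, scaling by the square root of the power spectrum, inverse FFT). It reuses the `kmesh` and `pkmesh` expressions from the diff but builds the wavenumbers with NumPy's FFT helpers instead of `fftk`; `linear_field_np` and `flat_pk` are hypothetical names.

```python
import numpy as np

def linear_field_np(mesh_shape, box_size, pk, seed):
    """Gaussian random field with power spectrum pk on a periodic mesh."""
    rng = np.random.default_rng(seed)
    white = rng.standard_normal(mesh_shape)
    white_k = np.fft.rfftn(white)

    # Wavenumbers in radians per cell, matching the rfftn output layout
    kx = 2 * np.pi * np.fft.fftfreq(mesh_shape[0])
    ky = 2 * np.pi * np.fft.fftfreq(mesh_shape[1])
    kz = 2 * np.pi * np.fft.rfftfreq(mesh_shape[2])
    kvec = np.meshgrid(kx, ky, kz, indexing="ij")

    # Convert to physical wavenumbers and evaluate the power spectrum
    kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
                for i, kk in enumerate(kvec))**0.5
    pkmesh = pk(kmesh) * np.prod(mesh_shape) / np.prod(box_size)

    return np.fft.irfftn(white_k * pkmesh**0.5, s=mesh_shape)

mesh_shape, box_size = (8, 8, 8), (100.0, 100.0, 100.0)
# A flat spectrum chosen so that the scaling factor is exactly one
flat_pk = lambda k: np.full_like(k, np.prod(box_size) / np.prod(mesh_shape))
field = linear_field_np(mesh_shape, box_size, flat_pk, seed=0)
```

With the flat spectrum above the scaling is the identity, so the output reproduces the input white noise; this is a cheap sanity check of the FFT round trip and the normalisation factor.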
|
||||||
|
|
||||||
|
|
||||||
def make_ode_fn(mesh_shape):
|
def make_ode_fn(mesh_shape,
|
||||||
|
paint_absolute_pos=True,
|
||||||
|
halo_size=0,
|
||||||
|
sharding=None):
|
||||||
|
|
||||||
def nbody_ode(state, a, cosmo):
|
def nbody_ode(state, a, cosmo):
|
||||||
"""
|
"""
|
||||||
|
@ -102,7 +155,11 @@ def make_ode_fn(mesh_shape):
|
||||||
"""
|
"""
|
||||||
pos, vel = state
|
pos, vel = state
|
||||||
|
|
||||||
forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m
|
forces = pm_forces(pos,
|
||||||
|
mesh_shape=mesh_shape,
|
||||||
|
paint_absolute_pos=paint_absolute_pos,
|
||||||
|
halo_size=halo_size,
|
||||||
|
sharding=sharding) * 1.5 * cosmo.Omega_m
|
||||||
|
|
||||||
# Computes the update of position (drift)
|
# Computes the update of position (drift)
|
||||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||||
|
@ -114,16 +171,24 @@ def make_ode_fn(mesh_shape):
|
||||||
|
|
||||||
return nbody_ode
|
return nbody_ode
|
||||||
|
|
||||||
def get_ode_fn(cosmo:Cosmology, mesh_shape):
|
|
||||||
|
def make_diffrax_ode(cosmo,
|
||||||
|
mesh_shape,
|
||||||
|
paint_absolute_pos=True,
|
||||||
|
halo_size=0,
|
||||||
|
sharding=None):
|
||||||
|
|
||||||
def nbody_ode(a, state, args):
|
def nbody_ode(a, state, args):
|
||||||
"""
|
"""
|
||||||
State is an array [position, velocities]
|
state is a tuple (position, velocities)
|
||||||
|
|
||||||
Compatible with [Diffrax API](https://docs.kidger.site/diffrax/)
|
|
||||||
"""
|
"""
|
||||||
pos, vel = state
|
pos, vel = state
|
||||||
forces = pm_forces(pos, mesh_shape) * 1.5 * cosmo.Omega_m
|
|
||||||
|
forces = pm_forces(pos,
|
||||||
|
mesh_shape=mesh_shape,
|
||||||
|
paint_absolute_pos=paint_absolute_pos,
|
||||||
|
halo_size=halo_size,
|
||||||
|
sharding=sharding) * 1.5 * cosmo.Omega_m
|
||||||
|
|
||||||
# Computes the update of position (drift)
|
# Computes the update of position (drift)
|
||||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||||
|
@ -145,16 +210,19 @@ def pgd_correction(pos, mesh_shape, params):
|
||||||
pos: particle positions [npart, 3]
|
pos: particle positions [npart, 3]
|
||||||
params: [alpha, kl, ks] pgd parameters
|
params: [alpha, kl, ks] pgd parameters
|
||||||
"""
|
"""
|
||||||
kvec = fftk(mesh_shape)
|
|
||||||
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
||||||
|
delta_k = fft3d(delta)
|
||||||
|
kvec = fftk(delta_k)
|
||||||
alpha, kl, ks = params
|
alpha, kl, ks = params
|
||||||
delta_k = jnp.fft.rfftn(delta)
|
|
||||||
PGD_range = PGD_kernel(kvec, kl, ks)
|
PGD_range = PGD_kernel(kvec, kl, ks)
|
||||||
|
|
||||||
pot_k_pgd = (delta_k * invlaplace_kernel(kvec)) * PGD_range
|
pot_k_pgd = (delta_k * invlaplace_kernel(kvec)) * PGD_range
|
||||||
|
|
||||||
forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i)*pot_k_pgd), pos)
|
forces_pgd = jnp.stack([
|
||||||
for i in range(3)],axis=-1)
|
cic_read(ifft3d(-gradient_kernel(kvec, i) * pot_k_pgd), pos)
|
||||||
|
for i in range(3)
|
||||||
|
],
|
||||||
|
axis=-1)
|
||||||
|
|
||||||
dpos_pgd = forces_pgd * alpha
|
dpos_pgd = forces_pgd * alpha
|
||||||
|
|
||||||
|
@ -162,27 +230,30 @@ def pgd_correction(pos, mesh_shape, params):
|
||||||
|
|
||||||
|
|
||||||
def make_neural_ode_fn(model, mesh_shape):
|
def make_neural_ode_fn(model, mesh_shape):
|
||||||
|
|
||||||
def neural_nbody_ode(state, a, cosmo: Cosmology, params):
|
def neural_nbody_ode(state, a, cosmo: Cosmology, params):
|
||||||
"""
|
"""
|
||||||
state is a tuple (position, velocities)
|
state is a tuple (position, velocities)
|
||||||
"""
|
"""
|
||||||
pos, vel = state
|
pos, vel = state
|
||||||
kvec = fftk(mesh_shape)
|
|
||||||
|
|
||||||
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
||||||
|
delta_k = fft3d(delta)
|
||||||
delta_k = jnp.fft.rfftn(delta)
|
kvec = fftk(delta_k)
|
||||||
|
|
||||||
# Computes gravitational potential
|
# Computes gravitational potential
|
||||||
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, r_split=0)
|
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec,
|
||||||
|
r_split=0)
|
||||||
|
|
||||||
# Apply a correction filter
|
# Apply a correction filter
|
||||||
kk = jnp.sqrt(sum((ki / jnp.pi)**2 for ki in kvec))
|
kk = jnp.sqrt(sum((ki / jnp.pi)**2 for ki in kvec))
|
||||||
pot_k = pot_k * (1. + model.apply(params, kk, jnp.atleast_1d(a)))
|
pot_k = pot_k * (1. + model.apply(params, kk, jnp.atleast_1d(a)))
|
||||||
|
|
||||||
# Computes gravitational forces
|
# Computes gravitational forces
|
||||||
forces = jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i)*pot_k), pos)
|
forces = jnp.stack([
|
||||||
for i in range(3)],axis=-1)
|
cic_read(ifft3d(-gradient_kernel(kvec, i) * pot_k), pos)
|
||||||
|
for i in range(3)
|
||||||
|
],
|
||||||
|
axis=-1)
|
||||||
|
|
||||||
forces = forces * 1.5 * cosmo.Omega_m
|
forces = forces * 1.5 * cosmo.Omega_m
|
||||||
|
|
||||||
|
@ -193,4 +264,5 @@ def make_neural_ode_fn(model, mesh_shape):
|
||||||
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||||
|
|
||||||
return dpos, dvel
|
return dpos, dvel
|
||||||
|
|
||||||
return neural_nbody_ode
|
return neural_nbody_ode
|
||||||
|
|
211
jaxpm/utils.py
|
@ -1,89 +1,161 @@
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from jax.scipy.stats import norm
|
from jax.scipy.stats import norm
|
||||||
|
from scipy.special import legendre
|
||||||
|
|
||||||
__all__ = ['power_spectrum']
|
__all__ = [
|
||||||
|
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
|
||||||
|
'cross_correlation_coefficients', 'gaussian_smoothing'
|
||||||
def _initialize_pk(shape, boxsize, kmin, dk):
|
|
||||||
"""
|
|
||||||
Helper function to initialize various (fixed) values for powerspectra... not differentiable!
|
|
||||||
"""
|
|
||||||
I = np.eye(len(shape), dtype='int') * -2 + 1
|
|
||||||
|
|
||||||
W = np.empty(shape, dtype='f4')
|
|
||||||
W[...] = 2.0
|
|
||||||
W[..., 0] = 1.0
|
|
||||||
W[..., -1] = 1.0
|
|
||||||
|
|
||||||
kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
|
|
||||||
kedges = np.arange(kmin, kmax, dk)
|
|
||||||
|
|
||||||
k = [
|
|
||||||
np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
|
|
||||||
for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
|
|
||||||
]
|
]
|
||||||
kmag = sum(ki**2 for ki in k)**0.5
|
|
||||||
|
|
||||||
xsum = np.zeros(len(kedges) + 1)
|
|
||||||
Nsum = np.zeros(len(kedges) + 1)
|
|
||||||
|
|
||||||
dig = np.digitize(kmag.flat, kedges)
|
|
||||||
|
|
||||||
xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
|
|
||||||
Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
|
|
||||||
return dig, Nsum, xsum, W, k, kedges
|
|
||||||
|
|
||||||
|
|
||||||
def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
|
def _initialize_pk(mesh_shape, box_shape, kedges, los):
|
||||||
"""
|
"""
|
||||||
Calculate the powerspectra given real space field
|
Parameters
|
||||||
|
----------
|
||||||
Args:
|
mesh_shape : tuple of int
|
||||||
|
Shape of the mesh grid.
|
||||||
field: real valued field
|
box_shape : tuple of float
|
||||||
kmin: minimum k-value for binned powerspectra
|
Physical dimensions of the box.
|
||||||
dk: differential in each kbin
|
kedges : None, int, float, or list
|
||||||
boxsize: length of each boxlength (can be strangly shaped?)
|
If None, set dk to twice the minimum.
|
||||||
|
If int, specifies number of edges.
|
||||||
Returns:
|
If float, specifies dk.
|
||||||
|
los : array_like
|
||||||
kbins: the central value of the bins for plotting
|
Line-of-sight vector.
|
||||||
power: real valued array of power in each bin
|
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dig : ndarray
|
||||||
|
Indices of the bins to which each value in input array belongs.
|
||||||
|
kcount : ndarray
|
||||||
|
Count of values in each bin.
|
||||||
|
kedges : ndarray
|
||||||
|
Edges of the bins.
|
||||||
|
mumesh : ndarray
|
||||||
|
Mu values for the mesh grid.
|
||||||
"""
|
"""
|
||||||
shape = field.shape
|
kmax = np.pi * np.min(mesh_shape / box_shape) # = knyquist
|
||||||
nx, ny, nz = shape
|
|
||||||
|
|
||||||
#initialze values related to powerspectra (mode bins and weights)
|
if isinstance(kedges, None | int | float):
|
||||||
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
|
if kedges is None:
|
||||||
|
dk = 2 * np.pi / np.min(
|
||||||
|
box_shape) * 2 # twice the minimum wavenumber
|
||||||
|
if isinstance(kedges, int):
|
||||||
|
dk = kmax / (kedges + 1) # final number of bins will be kedges-1
|
||||||
|
elif isinstance(kedges, float):
|
||||||
|
dk = kedges
|
||||||
|
kedges = np.arange(dk, kmax, dk) + dk / 2 # from dk/2 to kmax-dk/2
|
||||||
|
|
||||||
#fast fourier transform
|
kshapes = np.eye(len(mesh_shape), dtype=np.int32) * -2 + 1
|
||||||
fft_image = jnp.fft.fftn(field)
|
kvec = [(2 * np.pi * m / l) * np.fft.fftfreq(m).reshape(kshape)
|
||||||
|
for m, l, kshape in zip(mesh_shape, box_shape, kshapes)]
|
||||||
|
kmesh = sum(ki**2 for ki in kvec)**0.5
|
||||||
|
|
||||||
#absolute value of fast fourier transform
|
dig = np.digitize(kmesh.reshape(-1), kedges)
|
||||||
pk = jnp.real(fft_image * jnp.conj(fft_image))
|
kcount = np.bincount(dig, minlength=len(kedges) + 1)
|
||||||
|
|
||||||
#calculating powerspectra
|
# Central value of each bin
|
||||||
real = jnp.real(pk).reshape([-1])
|
# kavg = (kedges[1:] + kedges[:-1]) / 2
|
||||||
imag = jnp.imag(pk).reshape([-1])
|
kavg = np.bincount(
|
||||||
|
dig, weights=kmesh.reshape(-1), minlength=len(kedges) + 1) / kcount
|
||||||
|
kavg = kavg[1:-1]
|
||||||
|
|
||||||
Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
|
if los is None:
|
||||||
length=xsum.size) * 1j
|
mumesh = 1.
|
||||||
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
|
else:
|
||||||
|
mumesh = sum(ki * losi for ki, losi in zip(kvec, los))
|
||||||
|
kmesh_nozeros = np.where(kmesh == 0, 1, kmesh)
|
||||||
|
mumesh = np.where(kmesh == 0, 0, mumesh / kmesh_nozeros)
|
||||||
|
|
||||||
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
|
return dig, kcount, kavg, mumesh
|
||||||
|
|
||||||
#normalization for powerspectra
|
|
||||||
norm = np.prod(np.array(shape[:])).astype('float32')**2
|
|
||||||
|
|
||||||
#find central values of each bin
|
|
||||||
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
|
|
||||||
|
|
||||||
return kbins, P / norm
|
|
||||||
|
|
||||||
|
|
||||||
def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False):
|
def power_spectrum(mesh,
|
||||||
|
mesh2=None,
|
||||||
|
box_shape=None,
|
||||||
|
kedges: int | float | list = None,
|
||||||
|
multipoles=0,
|
||||||
|
los=[0., 0., 1.]):
|
||||||
|
"""
|
||||||
|
Compute the auto and cross spectrum of 3D fields, with multipoles.
|
||||||
|
"""
|
||||||
|
# Initialize
|
||||||
|
mesh_shape = np.array(mesh.shape)
|
||||||
|
if box_shape is None:
|
||||||
|
box_shape = mesh_shape
|
||||||
|
else:
|
||||||
|
box_shape = np.asarray(box_shape)
|
||||||
|
|
||||||
|
if multipoles == 0:
|
||||||
|
los = None
|
||||||
|
else:
|
||||||
|
los = np.asarray(los)
|
||||||
|
los = los / np.linalg.norm(los)
|
||||||
|
poles = np.atleast_1d(multipoles)
|
||||||
|
dig, kcount, kavg, mumesh = _initialize_pk(mesh_shape, box_shape, kedges,
|
||||||
|
los)
|
||||||
|
n_bins = len(kavg) + 2
|
||||||
|
|
||||||
|
# FFTs
|
||||||
|
meshk = jnp.fft.fftn(mesh, norm='ortho')
|
||||||
|
if mesh2 is None:
|
||||||
|
mmk = meshk.real**2 + meshk.imag**2
|
||||||
|
else:
|
||||||
|
mmk = meshk * jnp.fft.fftn(mesh2, norm='ortho').conj()
|
||||||
|
|
||||||
|
# Sum powers
|
||||||
|
pk = jnp.empty((len(poles), n_bins))
|
||||||
|
for i_ell, ell in enumerate(poles):
|
||||||
|
weights = (mmk * (2 * ell + 1) * legendre(ell)(mumesh)).reshape(-1)
|
||||||
|
if mesh2 is None:
|
||||||
|
psum = jnp.bincount(dig, weights=weights, length=n_bins)
|
||||||
|
else: # XXX: bincount is really slow with complex numbers
|
||||||
|
psum_real = jnp.bincount(dig, weights=weights.real, length=n_bins)
|
||||||
|
psum_imag = jnp.bincount(dig, weights=weights.imag, length=n_bins)
|
||||||
|
psum = (psum_real**2 + psum_imag**2)**.5
|
||||||
|
pk = pk.at[i_ell].set(psum)
|
||||||
|
|
||||||
|
# Normalization and conversion from cell units to [Mpc/h]^3
|
||||||
|
pk = (pk / kcount)[:, 1:-1] * (box_shape / mesh_shape).prod()
|
||||||
|
|
||||||
|
# pk = jnp.concatenate([kavg[None], pk])
|
||||||
|
if np.ndim(multipoles) == 0:
|
||||||
|
return kavg, pk[0]
|
||||||
|
else:
|
||||||
|
return kavg, pk
|
||||||
|
|
||||||
|
|
||||||
|
def transfer(mesh0, mesh1, box_shape, kedges: int | float | list = None):
|
||||||
|
pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
|
||||||
|
ks, pk0 = pk_fn(mesh0)
|
||||||
|
ks, pk1 = pk_fn(mesh1)
|
||||||
|
return ks, (pk1 / pk0)**.5
|
||||||
|
|
||||||
|
|
||||||
|
def coherence(mesh0, mesh1, box_shape, kedges: int | float | list = None):
|
||||||
|
pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
|
||||||
|
ks, pk01 = pk_fn(mesh0, mesh1)
|
||||||
|
ks, pk0 = pk_fn(mesh0)
|
||||||
|
ks, pk1 = pk_fn(mesh1)
|
||||||
|
return ks, pk01 / (pk0 * pk1)**.5
|
||||||
|
|
||||||
|
|
||||||
|
def pktranscoh(mesh0, mesh1, box_shape, kedges: int | float | list = None):
|
||||||
|
pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
|
||||||
|
ks, pk01 = pk_fn(mesh0, mesh1)
|
||||||
|
ks, pk0 = pk_fn(mesh0)
|
||||||
|
ks, pk1 = pk_fn(mesh1)
|
||||||
|
return ks, pk0, pk1, (pk1 / pk0)**.5, pk01 / (pk0 * pk1)**.5
|
||||||
|
|
||||||
|
|
||||||
|
def cross_correlation_coefficients(field_a,
|
||||||
|
field_b,
|
||||||
|
kmin=5,
|
||||||
|
dk=0.5,
|
||||||
|
boxsize=False):
|
||||||
"""
|
"""
|
||||||
Calculate the cross correlation coefficients given two real space field
|
Calculate the cross correlation coefficients given two real space field
|
||||||
|
|
||||||
|
@ -118,7 +190,8 @@ def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=Fals
|
||||||
real = jnp.real(pk).reshape([-1])
|
real = jnp.real(pk).reshape([-1])
|
||||||
imag = jnp.imag(pk).reshape([-1])
|
imag = jnp.imag(pk).reshape([-1])
|
||||||
|
|
||||||
Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j
|
Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
|
||||||
|
length=xsum.size) * 1j
|
||||||
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
|
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
|
||||||
|
|
||||||
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
|
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
|
||||||
|
|
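The bin-edge rules described in the new `_initialize_pk` docstring (`None`, `int`, or `float` for `kedges`) can be sketched in isolation. This is a minimal illustration, not the library's implementation: the helper name `make_kedges` is hypothetical, and plain Python with `math` stands in for the NumPy calls used in the diff so the sketch runs without extra dependencies.

```python
import math


def make_kedges(mesh_shape, box_shape, kedges=None):
    """Illustrative helper (hypothetical name): build k-bin edges from a spec."""
    # Nyquist wavenumber, limited by the axis with the coarsest cells.
    kmax = math.pi * min(m / b for m, b in zip(mesh_shape, box_shape))

    # An explicit list of edges is used as given.
    if isinstance(kedges, (list, tuple)):
        return list(kedges)

    if kedges is None:
        # Twice the minimum wavenumber, as in the diff.
        dk = 2 * (2 * math.pi / min(box_shape))
    elif isinstance(kedges, int):
        # An int sets the bin width from the requested number of edges.
        dk = kmax / (kedges + 1)
    else:
        # A float is taken directly as the bin width.
        dk = float(kedges)

    # Equivalent in spirit to np.arange(dk, kmax, dk) + dk / 2.
    n = max(0, math.ceil((kmax - dk) / dk))
    return [dk * (i + 1) + dk / 2 for i in range(n)]


if __name__ == "__main__":
    # 32^3 mesh in a 256^3 box with a bin width of 0.1.
    print(make_kedges((32, 32, 32), (256.0, 256.0, 256.0), 0.1))
```

With a float spec the first edge sits at `1.5 * dk`, so modes below that wavenumber fall into the underflow bin, which `power_spectrum` drops through the `[:, 1:-1]` slice.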
179
notebooks/05-MultiHost_PM.py
Normal file
|
@ -0,0 +1,179 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
|
||||||
|
import jax
|
||||||
|
|
||||||
|
jax.distributed.initialize()
|
||||||
|
rank = jax.process_index()
|
||||||
|
size = jax.process_count()
|
||||||
|
if rank == 0:
|
||||||
|
print(f"SIZE is {jax.device_count()}")
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jax_cosmo as jc
|
||||||
|
import numpy as np
|
||||||
|
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
||||||
|
PIDController, SaveAt, diffeqsolve)
|
||||||
|
from jax.experimental.mesh_utils import create_device_mesh
|
||||||
|
from jax.experimental.multihost_utils import process_allgather
|
||||||
|
from jax.sharding import Mesh, NamedSharding
|
||||||
|
from jax.sharding import PartitionSpec as P
|
||||||
|
|
||||||
|
from jaxpm.kernels import interpolate_power_spectrum
|
||||||
|
from jaxpm.painting import cic_paint_dx
|
||||||
|
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
|
||||||
|
|
||||||
|
all_gather = partial(process_allgather, tiled=True)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_arguments():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run a cosmological simulation with JAX.")
|
||||||
|
parser.add_argument(
|
||||||
|
"-p",
|
||||||
|
"--pdims",
|
||||||
|
type=int,
|
||||||
|
nargs=2,
|
||||||
|
default=[1, jax.device_count()],
|
||||||
|
help="Processor grid dimensions as two integers (e.g., 2 4).")
|
||||||
|
parser.add_argument(
|
||||||
|
"-m",
|
||||||
|
"--mesh_shape",
|
||||||
|
type=int,
|
||||||
|
nargs=3,
|
||||||
|
default=[512, 512, 512],
|
||||||
|
help="Shape of the simulation mesh as three values (e.g., 512 512 512)."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-b",
|
||||||
|
"--box_size",
|
||||||
|
type=float,
|
||||||
|
nargs=3,
|
||||||
|
default=[500.0, 500.0, 500.0],
|
||||||
|
help=
|
||||||
|
"Box size of the simulation as three values (e.g., 500.0 500.0 1000.0)."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-st",
|
||||||
|
"--snapshots",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="Number of snapshots to save during the simulation.")
|
||||||
|
parser.add_argument("-H",
|
||||||
|
"--halo_size",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="Halo size for the simulation.")
|
||||||
|
parser.add_argument("-s",
|
||||||
|
"--solver",
|
||||||
|
type=str,
|
||||||
|
choices=['leapfrog', 'dopri5'],
|
||||||
|
default='leapfrog',
|
||||||
|
help="ODE solver choice: 'leapfrog' or 'dopri5'.")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def create_mesh_and_sharding(pdims):
|
||||||
|
devices = create_device_mesh(pdims)
|
||||||
|
mesh = Mesh(devices, axis_names=('x', 'y'))
|
||||||
|
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||||
|
return mesh, sharding
|
||||||
|
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6, 7))
|
||||||
|
def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
||||||
|
solver_choice, nb_snapshots, sharding):
|
||||||
|
k = jnp.logspace(-4, 1, 128)
|
||||||
|
pk = jc.power.linear_matter_power(
|
||||||
|
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
||||||
|
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
|
||||||
|
|
||||||
|
initial_conditions = linear_field(mesh_shape,
|
||||||
|
box_size,
|
||||||
|
pk_fn,
|
||||||
|
seed=jax.random.PRNGKey(0),
|
||||||
|
sharding=sharding)
|
||||||
|
|
||||||
|
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
|
||||||
|
|
||||||
|
dx, p, _ = lpt(cosmo,
|
||||||
|
initial_conditions,
|
||||||
|
a=0.1,
|
||||||
|
halo_size=halo_size,
|
||||||
|
sharding=sharding)
|
||||||
|
|
||||||
|
ode_fn = ODETerm(
|
||||||
|
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||||
|
|
||||||
|
# Choose solver
|
||||||
|
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
|
||||||
|
stepsize_controller = ConstantStepSize(
|
||||||
|
) if solver_choice == "leapfrog" else PIDController(rtol=1e-5, atol=1e-5)
|
||||||
|
res = diffeqsolve(ode_fn,
|
||||||
|
solver,
|
||||||
|
t0=0.1,
|
||||||
|
t1=1.,
|
||||||
|
dt0=0.01,
|
||||||
|
y0=jnp.stack([dx, p], axis=0),
|
||||||
|
args=cosmo,
|
||||||
|
saveat=SaveAt(ts=jnp.linspace(0.2, 1., nb_snapshots)),
|
||||||
|
stepsize_controller=stepsize_controller)
|
||||||
|
|
||||||
|
ode_fields = [
|
||||||
|
cic_paint_dx(sol[0], halo_size=halo_size, sharding=sharding)
|
||||||
|
for sol in res.ys
|
||||||
|
]
|
||||||
|
lpt_field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding)
|
||||||
|
return initial_conditions, lpt_field, ode_fields, res.stats
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_arguments()
|
||||||
|
mesh_shape = args.mesh_shape
|
||||||
|
box_size = args.box_size
|
||||||
|
halo_size = args.halo_size
|
||||||
|
solver_choice = args.solver
|
||||||
|
nb_snapshots = args.snapshots
|
||||||
|
|
||||||
|
_, sharding = create_mesh_and_sharding(args.pdims)
|
||||||
|
|
||||||
|
initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(
|
||||||
|
0.25, 0.8, tuple(mesh_shape), tuple(box_size), halo_size,
|
||||||
|
solver_choice, nb_snapshots, sharding)
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
os.makedirs("fields", exist_ok=True)
|
||||||
|
print(f"[{rank}] Simulation done")
|
||||||
|
print(f"Solver stats: {solver_stats}")
|
||||||
|
|
||||||
|
# Save initial conditions
|
||||||
|
initial_conditions_g = all_gather(initial_conditions)
|
||||||
|
if rank == 0:
|
||||||
|
print(f"[{rank}] Saving initial_conditions")
|
||||||
|
np.save("fields/initial_conditions.npy", initial_conditions_g)
|
||||||
|
print(f"[{rank}] initial_conditions saved")
|
||||||
|
del initial_conditions_g, initial_conditions
|
||||||
|
|
||||||
|
# Save LPT displacements
|
||||||
|
lpt_displacements_g = all_gather(lpt_displacements)
|
||||||
|
if rank == 0:
|
||||||
|
print(f"[{rank}] Saving lpt_displacements")
|
||||||
|
np.save("fields/lpt_displacements.npy", lpt_displacements_g)
|
||||||
|
print(f"[{rank}] lpt_displacements saved")
|
||||||
|
del lpt_displacements_g, lpt_displacements
|
||||||
|
|
||||||
|
# Save each ODE solution separately
|
||||||
|
for i, sol in enumerate(ode_solutions):
|
||||||
|
sol_g = all_gather(sol)
|
||||||
|
if rank == 0:
|
||||||
|
print(f"[{rank}] Saving ode_solution_{i}")
|
||||||
|
np.save(f"fields/ode_solution_{i}.npy", sol_g)
|
||||||
|
print(f"[{rank}] ode_solution_{i} saved")
|
||||||
|
del sol_g
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
1
notebooks/06-Animating_PM_Fields.ipynb.REMOVED.git-id
Normal file
|
@ -0,0 +1 @@
|
||||||
|
c4a44973e4f11841a8c14f4d200e7e87887419aa
|
39
notebooks/README.md
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
# Particle Mesh Simulation with JAXPM on Multi-GPU and Multi-Host Systems
|
||||||
|
|
||||||
|
This collection of notebooks demonstrates how to perform Particle Mesh (PM) simulations using **JAXPM**, leveraging JAX for efficient computation on multi-GPU and multi-host systems. Each notebook progressively covers different setups, from single-GPU simulations to advanced, distributed, multi-host simulations across multiple nodes.
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
1. **[Single-GPU Particle Mesh Simulation](01-Introduction.ipynb)**
|
||||||
|
- Introduction to basic PM simulations on a single GPU.
|
||||||
|
- Uses JAXPM to run simulations with absolute particle positions and Cloud-in-Cell (CIC) painting.
|
||||||
|
|
||||||
|
2. **[Advanced Particle Mesh Simulation on a Single GPU](02-Advanced_usage.ipynb)**
|
||||||
|
- Explores the use of Diffrax solvers in the ODE step.
|
||||||
|
- Explores second-order Lagrangian Perturbation Theory (LPT) simulations.
|
||||||
|
- Introduces weighted density field projections.
|
||||||
|
|
||||||
|
3. **[Multi-GPU Particle Mesh Simulation with Halo Exchange](03-MultiGPU_PM_Halo.ipynb)**
|
||||||
|
- Extends PM simulation to multi-GPU setups with halo exchange.
|
||||||
|
- Uses sharding and device mesh configurations to manage distributed data across GPUs.
|
||||||
|
|
||||||
|
4. **[Multi-GPU Particle Mesh Simulation with Advanced Solvers](04-MultiGPU_PM_Solvers.ipynb)**
|
||||||
|
- Compares different ODE solvers (Leapfrog and Dopri5) in multi-GPU simulations.
|
||||||
|
- Highlights performance, memory considerations, and solver impact on simulation quality.
|
||||||
|
|
||||||
|
5. **[Multi-Host Particle Mesh Simulation](05-MultiHost_PM.ipynb)**
|
||||||
|
- Extends PM simulations to multi-host, multi-GPU setups for large-scale simulations.
|
||||||
|
- Guides through job submission, device initialization, and retrieving results across nodes.
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
Each notebook includes installation instructions and guidelines for configuring JAXPM and required dependencies. Follow the setup instructions in each notebook to ensure an optimal environment.
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
- **JAXPM** (included in the installation commands within notebooks)
|
||||||
|
- **Diffrax** for ODE solvers
|
||||||
|
- **JAX** with CUDA support for multi-GPU or TPU setups
|
||||||
|
- **SLURM** for job scheduling on clusters (if running multi-host setups)
|
||||||
|
|
||||||
|
> **Note**: These notebooks are tested on the **Jean Zay** supercomputer and may require configuration changes for different HPC clusters.
|
30
pyproject.toml
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
[build-system]
|
||||||
|
requires = ["setuptools", "wheel"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "JaxPM"
|
||||||
|
version = "0.0.1"
|
||||||
|
description = "A dead simple FastPM implementation in JAX"
|
||||||
|
authors = [{ name = "JaxPM developers" }]
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.9"
|
||||||
|
license = { file = "LICENSE" }
|
||||||
|
urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/JaxPM" }
|
||||||
|
dependencies = ["jax_cosmo", "jax>=0.4.30", "jaxdecomp>=0.2.2"]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
test = [
|
||||||
|
"jax>=0.4.30",
|
||||||
|
"numpy",
|
||||||
|
"jax_cosmo",
|
||||||
|
"jaxdecomp>=0.2.2",
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
"pfft-python @ git+https://github.com/MP-Gadget/pfft-python",
|
||||||
|
"pmesh @ git+https://github.com/MP-Gadget/pmesh",
|
||||||
|
"fastpm @ git+https://github.com/ASKabalan/fastpm-python",
|
||||||
|
"diffrax"
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.setuptools]
|
||||||
|
packages = ["jaxpm"]
|
4
pytest.ini
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
[pytest]
|
||||||
|
markers =
|
||||||
|
distributed: mark a test as distributed
|
||||||
|
single_device: mark a test as single_device
|
11
setup.py
|
@ -1,11 +0,0 @@
|
||||||
from setuptools import find_packages, setup
|
|
||||||
|
|
||||||
setup(
|
|
||||||
name='JaxPM',
|
|
||||||
version='0.0.1',
|
|
||||||
url='https://github.com/DifferentiableUniverseInitiative/JaxPM',
|
|
||||||
author='JaxPM developers',
|
|
||||||
description='A dead simple FastPM implementation in JAX',
|
|
||||||
packages=find_packages(),
|
|
||||||
install_requires=['jax', 'jax_cosmo'],
|
|
||||||
)
|
|
175
tests/conftest.py
Normal file
|
@ -0,0 +1,175 @@
|
||||||
|
# Parameterized fixture for mesh_shape
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
os.environ["EQX_ON_ERROR"] = "nan"
|
||||||
|
setup_done = False
|
||||||
|
on_cluster = False
|
||||||
|
|
||||||
|
|
||||||
|
def is_on_cluster():
|
||||||
|
global on_cluster
|
||||||
|
return on_cluster
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_distributed():
|
||||||
|
global setup_done
|
||||||
|
global on_cluster
|
||||||
|
if not setup_done:
|
||||||
|
if "SLURM_JOB_ID" in os.environ:
|
||||||
|
on_cluster = True
|
||||||
|
print("Running on cluster")
|
||||||
|
import jax
|
||||||
|
jax.distributed.initialize()
|
||||||
|
setup_done = True
|
||||||
|
on_cluster = True
|
||||||
|
else:
|
||||||
|
print("Running locally")
|
||||||
|
setup_done = True
|
||||||
|
on_cluster = False
|
||||||
|
os.environ["JAX_PLATFORM_NAME"] = "cpu"
|
||||||
|
os.environ[
|
||||||
|
"XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
|
||||||
|
import jax
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(
|
||||||
|
scope="session",
|
||||||
|
params=[
|
||||||
|
((32, 32, 32), (256., 256., 256.)), # BOX
|
||||||
|
((32, 32, 64), (256., 256., 512.)), # RECTANGULAR
|
||||||
|
])
|
||||||
|
def simulation_config(request):
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", params=[0.1, 0.5, 0.8])
|
||||||
|
def lpt_scale_factor(request):
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def cosmo():
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from jax_cosmo import Cosmology
|
||||||
|
Planck18 = partial(
|
||||||
|
Cosmology,
|
||||||
|
# Omega_m = 0.3111
|
||||||
|
Omega_c=0.2607,
|
||||||
|
Omega_b=0.0490,
|
||||||
|
Omega_k=0.0,
|
||||||
|
h=0.6766,
|
||||||
|
n_s=0.9665,
|
||||||
|
sigma8=0.8102,
|
||||||
|
w0=-1.0,
|
||||||
|
wa=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
return Planck18()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def particle_mesh(simulation_config):
|
||||||
|
from pmesh.pm import ParticleMesh
|
||||||
|
mesh_shape, box_shape = simulation_config
|
||||||
|
return ParticleMesh(BoxSize=box_shape, Nmesh=mesh_shape, dtype='f4')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def fpm_initial_conditions(cosmo, particle_mesh):
|
||||||
|
import jax_cosmo as jc
|
||||||
|
import numpy as np
|
||||||
|
from jax import numpy as jnp
|
||||||
|
|
||||||
|
# Generate initial particle positions
|
||||||
|
grid = particle_mesh.generate_uniform_particle_grid(shift=0).astype(
|
||||||
|
np.float32)
|
||||||
|
# Interpolate with linear_matter spectrum to get initial density field
|
||||||
|
k = jnp.logspace(-4, 1, 128)
|
||||||
|
pk = jc.power.linear_matter_power(cosmo, k)
|
||||||
|
|
||||||
|
def pk_fn(x):
|
||||||
|
return jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)
|
||||||
|
|
||||||
|
whitec = particle_mesh.generate_whitenoise(42,
|
||||||
|
type='complex',
|
||||||
|
unitary=False)
|
||||||
|
lineark = whitec.apply(lambda k, v: pk_fn(sum(ki**2 for ki in k)**0.5)**0.5
|
||||||
|
* v * (1 / v.BoxSize).prod()**0.5)
|
||||||
|
init_mesh = lineark.c2r().value # XXX
|
||||||
|
|
||||||
|
return lineark, grid, init_mesh
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def initial_conditions(fpm_initial_conditions):
|
||||||
|
_, _, init_mesh = fpm_initial_conditions
|
||||||
|
return init_mesh
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def solver(cosmo, particle_mesh):
|
||||||
|
    from fastpm.core import Cosmology as FastPMCosmology
    from fastpm.core import Solver
    ref_cosmo = FastPMCosmology(cosmo)
    return Solver(particle_mesh, ref_cosmo, B=1)


@pytest.fixture(scope="session")
def fpm_lpt1(solver, fpm_initial_conditions, lpt_scale_factor):

    lineark, grid, _ = fpm_initial_conditions
    statelpt = solver.lpt(lineark, grid, lpt_scale_factor, order=1)
    return statelpt


@pytest.fixture(scope="session")
def fpm_lpt1_field(fpm_lpt1, particle_mesh):
    return particle_mesh.paint(fpm_lpt1.X).value


@pytest.fixture(scope="session")
def fpm_lpt2(solver, fpm_initial_conditions, lpt_scale_factor):

    lineark, grid, _ = fpm_initial_conditions
    statelpt = solver.lpt(lineark, grid, lpt_scale_factor, order=2)
    return statelpt


@pytest.fixture(scope="session")
def fpm_lpt2_field(fpm_lpt2, particle_mesh):
    return particle_mesh.paint(fpm_lpt2.X).value


@pytest.fixture(scope="session")
def nbody_from_lpt1(solver, fpm_lpt1, particle_mesh, lpt_scale_factor):
    import numpy as np
    from fastpm.core import leapfrog

    if lpt_scale_factor == 0.8:
        pytest.skip("Do not run nbody simulation from scale factor 0.8")

    stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)

    finalstate = solver.nbody(fpm_lpt1, leapfrog(stages))
    fpm_mesh = particle_mesh.paint(finalstate.X).value

    return fpm_mesh


@pytest.fixture(scope="session")
def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
    import numpy as np
    from fastpm.core import leapfrog

    if lpt_scale_factor == 0.8:
        pytest.skip("Do not run nbody simulation from scale factor 0.8")

    stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)

    finalstate = solver.nbody(fpm_lpt2, leapfrog(stages))
    fpm_mesh = particle_mesh.paint(finalstate.X).value

    return fpm_mesh
13
tests/helpers.py
Normal file
@@ -0,0 +1,13 @@
import jax.numpy as jnp


def MSE(x, y):
    return jnp.mean((x - y)**2)


def MSE_3D(x, y):
    return ((x - y)**2).mean(axis=0)


def MSRE(x, y):
    return jnp.mean(((x - y) / y)**2)
155
tests/test_against_fpm.py
Normal file
@@ -0,0 +1,155 @@
import pytest
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
from helpers import MSE, MSRE
from jax import numpy as jnp

from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_dx
from jaxpm.pm import lpt, make_diffrax_ode
from jaxpm.utils import power_spectrum

_TOLERANCE = 1e-4
_PM_TOLERANCE = 1e-3


@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
                      fpm_lpt1_field, fpm_lpt2_field, cosmo, order):

    mesh_shape, box_shape = simulation_config
    cosmo._workspace = {}
    particles = uniform_particles(mesh_shape)

    # Initial displacement
    dx, _, _ = lpt(cosmo,
                   initial_conditions,
                   particles,
                   a=lpt_scale_factor,
                   order=order)

    fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field

    lpt_field = cic_paint(jnp.zeros(mesh_shape), particles + dx)
    _, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
    _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)

    assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
    assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE


@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
                      fpm_lpt1_field, fpm_lpt2_field, cosmo, order):

    mesh_shape, box_shape = simulation_config
    cosmo._workspace = {}
    # Initial displacement
    dx, _, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)

    lpt_field = cic_paint_dx(dx)

    fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field

    _, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
    _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)

    assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
    assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE


@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
def test_nbody_absolute(simulation_config, initial_conditions,
                        lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
                        cosmo, order):

    mesh_shape, box_shape = simulation_config
    cosmo._workspace = {}
    particles = uniform_particles(mesh_shape)

    # Initial displacement
    dx, p, _ = lpt(cosmo,
                   initial_conditions,
                   particles,
                   a=lpt_scale_factor,
                   order=order)

    ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))

    solver = Dopri5()
    controller = PIDController(rtol=1e-8,
                               atol=1e-8,
                               pcoeff=0.4,
                               icoeff=1,
                               dcoeff=0)

    saveat = SaveAt(t1=True)

    y0 = jnp.stack([particles + dx, p])

    solutions = diffeqsolve(ode_fn,
                            solver,
                            t0=lpt_scale_factor,
                            t1=1.0,
                            dt0=None,
                            y0=y0,
                            stepsize_controller=controller,
                            saveat=saveat)

    final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0])

    fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2

    _, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
    _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)

    assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
    assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE


@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
def test_nbody_relative(simulation_config, initial_conditions,
                        lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
                        cosmo, order):

    mesh_shape, box_shape = simulation_config
    cosmo._workspace = {}

    # Initial displacement
    dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)

    ode_fn = ODETerm(
        make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))

    solver = Dopri5()
    controller = PIDController(rtol=1e-9,
                               atol=1e-9,
                               pcoeff=0.4,
                               icoeff=1,
                               dcoeff=0)

    saveat = SaveAt(t1=True)

    y0 = jnp.stack([dx, p])

    solutions = diffeqsolve(ode_fn,
                            solver,
                            t0=lpt_scale_factor,
                            t1=1.0,
                            dt0=None,
                            y0=y0,
                            stepsize_controller=controller,
                            saveat=saveat)

    final_field = cic_paint_dx(solutions.ys[-1, 0])

    fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2

    _, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
    _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)

    assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
    assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
152
tests/test_distributed_pm.py
Normal file
@@ -0,0 +1,152 @@
from conftest import initialize_distributed

initialize_distributed()  # noqa : E402

import jax  # noqa : E402
import jax.numpy as jnp  # noqa : E402
import pytest  # noqa : E402
from diffrax import SaveAt  # noqa : E402
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
from helpers import MSE  # noqa : E402
from jax import lax  # noqa : E402
from jax.experimental.multihost_utils import process_allgather  # noqa : E402
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P  # noqa : E402

from jaxpm.distributed import uniform_particles  # noqa : E402
from jaxpm.painting import cic_paint, cic_paint_dx  # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode  # noqa : E402

_TOLERANCE = 3.0  # 🙃🙃


@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("absolute_painting", [True, False])
def test_distributed_pm(simulation_config, initial_conditions, cosmo, order,
                        absolute_painting):

    mesh_shape, box_shape = simulation_config
    # SINGLE DEVICE RUN
    cosmo._workspace = {}
    if absolute_painting:
        particles = uniform_particles(mesh_shape)
        # Initial displacement
        dx, p, _ = lpt(cosmo,
                       initial_conditions,
                       particles,
                       a=0.1,
                       order=order)
        ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
        y0 = jnp.stack([particles + dx, p])
    else:
        dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order)
        ode_fn = ODETerm(
            make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
        y0 = jnp.stack([dx, p])

    solver = Dopri5()
    controller = PIDController(rtol=1e-8,
                               atol=1e-8,
                               pcoeff=0.4,
                               icoeff=1,
                               dcoeff=0)

    saveat = SaveAt(t1=True)

    solutions = diffeqsolve(ode_fn,
                            solver,
                            t0=0.1,
                            t1=1.0,
                            dt0=None,
                            y0=y0,
                            stepsize_controller=controller,
                            saveat=saveat)

    if absolute_painting:
        single_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
                                              solutions.ys[-1, 0])
    else:
        single_device_final_field = cic_paint_dx(solutions.ys[-1, 0])

    print("Done with single device run")
    # MULTI DEVICE RUN

    mesh = jax.make_mesh((1, 8), ('x', 'y'))
    sharding = NamedSharding(mesh, P('x', 'y'))
    halo_size = mesh_shape[0] // 2

    initial_conditions = lax.with_sharding_constraint(initial_conditions,
                                                      sharding)

    print(f"sharded initial conditions {initial_conditions.sharding}")

    cosmo._workspace = {}
    if absolute_painting:
        particles = uniform_particles(mesh_shape, sharding=sharding)
        # Initial displacement
        dx, p, _ = lpt(cosmo,
                       initial_conditions,
                       particles,
                       a=0.1,
                       order=order,
                       halo_size=halo_size,
                       sharding=sharding)

        ode_fn = ODETerm(
            make_diffrax_ode(cosmo,
                             mesh_shape,
                             halo_size=halo_size,
                             sharding=sharding))

        y0 = jnp.stack([particles + dx, p])
    else:
        dx, p, _ = lpt(cosmo,
                       initial_conditions,
                       a=0.1,
                       order=order,
                       halo_size=halo_size,
                       sharding=sharding)
        ode_fn = ODETerm(
            make_diffrax_ode(cosmo,
                             mesh_shape,
                             paint_absolute_pos=False,
                             halo_size=halo_size,
                             sharding=sharding))
        y0 = jnp.stack([dx, p])

    solver = Dopri5()
    controller = PIDController(rtol=1e-8,
                               atol=1e-8,
                               pcoeff=0.4,
                               icoeff=1,
                               dcoeff=0)

    saveat = SaveAt(t1=True)

    solutions = diffeqsolve(ode_fn,
                            solver,
                            t0=0.1,
                            t1=1.0,
                            dt0=None,
                            y0=y0,
                            stepsize_controller=controller,
                            saveat=saveat)

    if absolute_painting:
        multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
                                             solutions.ys[-1, 0],
                                             halo_size=halo_size,
                                             sharding=sharding)
    else:
        multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
                                                halo_size=halo_size,
                                                sharding=sharding)

    multi_device_final_field = process_allgather(multi_device_final_field,
                                                 tiled=True)

    mse = MSE(single_device_final_field, multi_device_final_field)
    print(f"MSE is {mse}")

    assert mse < _TOLERANCE