Mirror of https://github.com/DifferentiableUniverseInitiative/JaxPM.git (synced 2025-04-05 19:30:54 +00:00)
Compare commits
No commits in common. "main" and "v0.0.2" have entirely different histories.
37 changed files with 449 additions and 4304 deletions
21  .github/workflows/formatting.yml (vendored)
@@ -1,21 +0,0 @@
name: Code Formatting

on:
  push:
    branches: [ "main" ]
  pull_request:
    branches: [ "main" ]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v4
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v3
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip isort
        python -m pip install pre-commit
    - name: Run pre-commit
      run: python -m pre_commit run --all-files

55  .github/workflows/python-publish.yml (vendored)
@@ -1,55 +0,0 @@
name: Upload Python Package

on:
  release:
    types: [published]

permissions:
  contents: read

jobs:
  release-build:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: "3.x"

      - name: Build release distributions
        run: |
          # NOTE: put your own distribution build steps here.
          python -m pip install build
          python -m build

      - name: Upload distributions
        uses: actions/upload-artifact@v4
        with:
          name: release-dists
          path: dist/

  pypi-publish:
    runs-on: ubuntu-latest
    needs:
      - release-build
    permissions:
      # IMPORTANT: this permission is mandatory for trusted publishing
      id-token: write

    environment:
      name: pypi
      url: https://pypi.org/p/jaxpm

    steps:
      - name: Retrieve release distributions
        uses: actions/download-artifact@v4
        with:
          name: release-dists
          path: dist/

      - name: Publish release distributions to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          packages-dir: dist/

46  .github/workflows/tests.yml (vendored)
@@ -1,46 +0,0 @@
name: Tests

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  run_tests:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10", "3.11", "3.12"]

    steps:
      - name: Checkout Source
        uses: actions/checkout@v2.3.1

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          sudo apt-get install -y libopenmpi-dev
          python -m pip install --upgrade pip
          pip install jax==0.4.35
          pip install numpy setuptools cython wheel
          pip install git+https://github.com/MP-Gadget/pfft-python
          pip install git+https://github.com/MP-Gadget/pmesh
          pip install git+https://github.com/ASKabalan/fastpm-python --no-build-isolation
          pip install -r requirements-test.txt
          pip install .

      - name: Run Single Device Tests
        run: |
          cd tests
          pytest -v -m "not distributed"
      - name: Run Distributed tests
        run: |
          pytest -v -m distributed

8  .gitignore (vendored)
@@ -98,11 +98,6 @@ __pypackages__/
celerybeat-schedule
celerybeat.pid


out
traces
*.npy
*.out
# SageMath parsed files
*.sage.py


@@ -132,6 +127,3 @@ dmypy.json

# Pyre type checker
.pyre/

# Hide version file
_version.py

@@ -14,4 +14,4 @@ repos:
    rev: 5.13.2
    hooks:
      - id: isort
        name: isort (python)
        name: isort (python)

2  LICENSE
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2021-2025 Differentiable Universe Initiative
Copyright (c) 2021 Differentiable Universe Initiative

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

@@ -1,2 +0,0 @@
prune notebooks
prune tests

26  README.md
@@ -1,35 +1,19 @@
# JaxPM
[](https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/01-Introduction.ipynb)
[](https://pypi.org/project/jaxpm/) [](https://github.com/DifferentiableUniverseInitiative/JaxPM/actions/workflows/tests.yml) <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
[](#contributors-)
<!-- ALL-CONTRIBUTORS-BADGE:END -->
JAX-powered Cosmological Particle-Mesh N-body Solver

> ### Note
> **The new JaxPM v0.1.xx** supports multi-GPU model distribution while remaining compatible with previous releases. These significant changes are still under development and testing, so please report any issues you encounter.
> For the older but more stable version, install:
> ```bash
> pip install jaxpm==0.0.2
> ```

## Install

Basic installation can be done using pip:
```bash
pip install jaxpm
```
For a more advanced installation with optimized distribution on GPU clusters, please install jaxDecomp first. See instructions [here](https://github.com/DifferentiableUniverseInitiative/jaxDecomp).

**This project is currently in an early design phase. All inputs are welcome on the [design document](https://github.com/DifferentiableUniverseInitiative/JaxPM/blob/main/design.md)**

## Goals

Provide a modern infrastructure to support differentiable PM N-body simulations using JAX:
- Keep the implementation simple and readable, in a pure NumPy API
- Transparent distribution using the builtin `xmap`
- Any-order forward and backward automatic differentiation
- Support automated batching using `vmap` (see the sketch after this list)
- Compatibility with external optimizer libraries like `optax`
- Now fully distributable on **multi-GPU and multi-node** systems using [jaxDecomp](https://github.com/DifferentiableUniverseInitiative/jaxDecomp), working with `JAX v0.4.35`
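The following minimal sketch is not part of the README; it illustrates the `vmap` batching goal above using the v0.0.2-style API that appears later in this diff (`jaxpm.pm.linear_field`, `jaxpm.pm.lpt`, `jaxpm.painting.cic_paint`). The mesh size, box size, power-spectrum wrapper and summary pipeline are illustrative assumptions, not taken from the repository.

```python
# Hedged sketch: batch several simulations over random seeds with vmap,
# assuming jaxpm==0.0.2 (the old API side of this diff).
import jax
import jax.numpy as jnp
import jax_cosmo as jc

from jaxpm.pm import linear_field, lpt
from jaxpm.painting import cic_paint

mesh_shape = (64, 64, 64)
box_size = (128., 128., 128.)
cosmo = jc.Planck15()

# Tabulated linear matter power spectrum, wrapped as a callable for linear_field
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(cosmo, k)
pk_fn = lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape)

# Particles initially placed on the grid, in mesh units
particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape],
                                   indexing='ij'), axis=-1).reshape(-1, 3)

def simulate(seed):
    ic = linear_field(mesh_shape, box_size, pk_fn, seed)          # Gaussian initial conditions
    dx, p, f = lpt(cosmo, ic, particles, a=jnp.atleast_1d(0.5))   # 1LPT displacement at a=0.5
    return cic_paint(jnp.zeros(mesh_shape), particles + dx)       # painted density field

# Automated batching over four random seeds
seeds = jax.random.split(jax.random.PRNGKey(0), 4)
fields = jax.vmap(simulate)(seeds)  # shape (4, 64, 64, 64)
```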
## Open development and use


@@ -39,10 +23,6 @@ Current expectations are:
- Everyone is welcome to contribute, and can join the JOSS publication (until it is submitted to the journal).
- Anyone (including main contributors) can use this code as a framework to build and publish their own applications, with no expectation that they *need* to extend authorship to all jaxpm developers.

## Getting Started

To dive into JaxPM’s capabilities, please explore the **notebook section** for detailed tutorials and examples on various setups, from single-device simulations to multi-host configurations. You can find the notebooks' [README here](notebooks/README.md) for a structured guide through each tutorial.


## Contributors ✨

52  design.md (new file)
@@ -0,0 +1,52 @@
# Design Document for JaxPM

This document aims to detail some of the API, implementation choices, and internal mechanisms.

## Objective

Provide a user-friendly framework for distributed Particle-Mesh N-body simulations.

## Related Work

This project would be the latest iteration of a number of past libraries that have provided differentiable N-body models.

- [FlowPM](https://github.com/DifferentiableUniverseInitiative/flowpm): TensorFlow
- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD
- Borg

In addition, a number of fast N-body simulation projects exist out there:
- [FastPM](https://github.com/fastpm/fastpm)
- ...

## Design Overview

### Coding principles

Following recent trends and JAX philosophy, the library should have a functional programming type of interface.

### Illustration of API

Here is a potential illustration of what the user interface could be for the simulation code:
```python
import jaxpm as jpm
import jax_cosmo as jc

# Instantiate differentiable cosmology object
cosmo = jc.Planck()

# Create initial conditions
initial_conditions = jpm.generate_ic(cosmo, boxsize, nmesh, dtype='float32')

# Create a particular solver
solver = jpm.solvers.fastpm(cosmo, B=1)

# Initialize and run the simulation
state = solver.init(initial_conditions)
state = solver.nbody(state)

# Painting the results
density = jpm.zeros(boxsize, nmesh)
density = jpm.paint(density, state.positions)
```
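Since the stated objective is a differentiable simulation, a natural companion to the illustration above is taking gradients through the whole pipeline. The snippet below is written in the same illustrative spirit (the `jpm` API sketched above does not exist in this form, and `boxsize`/`nmesh` are placeholders); it only assumes that every step is a JAX-traceable operation.

```python
import jax
import jax_cosmo as jc
import jaxpm as jpm  # illustrative API, as in the example above

def field_summary(omega_c):
    # Re-run the illustrative pipeline for a given value of Omega_c
    cosmo = jc.Planck15(Omega_c=omega_c)
    initial_conditions = jpm.generate_ic(cosmo, boxsize, nmesh, dtype='float32')
    solver = jpm.solvers.fastpm(cosmo, B=1)
    state = solver.nbody(solver.init(initial_conditions))
    density = jpm.paint(jpm.zeros(boxsize, nmesh), state.positions)
    return (density**2).mean()

# Gradient of a summary statistic with respect to a cosmological parameter
dsummary_domega_c = jax.grad(field_summary)(0.25)
```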
14  dev/job_pfft.sh (new file)
@@ -0,0 +1,14 @@
#!/bin/bash
#SBATCH -A m1727
#SBATCH -C gpu
#SBATCH -q debug
#SBATCH -t 0:05:00
#SBATCH -N 2
#SBATCH --ntasks-per-node=4
#SBATCH -c 32
#SBATCH --gpus-per-task=1
#SBATCH --gpu-bind=none

module load python cudnn/8.2.0 nccl/2.11.4 cudatoolkit
export SLURM_CPU_BIND="cores"
srun python test_pfft.py

96  dev/test_pfft.py (new file)
@@ -0,0 +1,96 @@
# Can be executed with:
# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py
from functools import partial

import jax
import jax.lax as lax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import Mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit

jax.distributed.initialize()

cube_size = 2048


@partial(xmap,
         in_axes=[...],
         out_axes=['x', 'y', ...],
         axis_sizes={
             'x': cube_size,
             'y': cube_size
         },
         axis_resources={
             'x': 'nx',
             'y': 'ny',
             'key_x': 'nx',
             'key_y': 'ny'
         })
def pnormal(key):
    return jax.random.normal(key, shape=[cube_size])


@partial(xmap,
         in_axes={
             0: 'x',
             1: 'y'
         },
         out_axes=['x', 'y', ...],
         axis_resources={
             'x': 'nx',
             'y': 'ny'
         })
@jax.jit
def pfft3d(mesh):
    # [x, y, z]
    mesh = jnp.fft.fft(mesh)  # Transform on z
    mesh = lax.all_to_all(mesh, 'x', 0, 0)  # Now x is exposed, [z,y,x]
    mesh = jnp.fft.fft(mesh)  # Transform on x
    mesh = lax.all_to_all(mesh, 'y', 0, 0)  # Now y is exposed, [z,x,y]
    mesh = jnp.fft.fft(mesh)  # Transform on y
    # [z, x, y]
    return mesh


@partial(xmap,
         in_axes={
             0: 'x',
             1: 'y'
         },
         out_axes=['x', 'y', ...],
         axis_resources={
             'x': 'nx',
             'y': 'ny'
         })
@jax.jit
def pifft3d(mesh):
    # [z, x, y]
    mesh = jnp.fft.ifft(mesh)  # Transform on y
    mesh = lax.all_to_all(mesh, 'y', 0, 0)  # Now x is exposed, [z,y,x]
    mesh = jnp.fft.ifft(mesh)  # Transform on x
    mesh = lax.all_to_all(mesh, 'x', 0, 0)  # Now z is exposed, [x,y,z]
    mesh = jnp.fft.ifft(mesh)  # Transform on z
    # [x, y, z]
    return mesh


key = jax.random.PRNGKey(42)
# keys = jax.random.split(key, 4).reshape((2,2,2))

# We reshape all our devices to the mesh shape we want
devices = np.array(jax.devices()).reshape((2, 4))

with Mesh(devices, ('nx', 'ny')):
    mesh = pnormal(key)
    kmesh = pfft3d(mesh)
    kmesh.block_until_ready()

# jax.profiler.start_trace("tensorboard")
# with Mesh(devices, ('nx', 'ny')):
#     mesh = pnormal(key)
#     kmesh = pfft3d(mesh)
#     kmesh.block_until_ready()
# jax.profiler.stop_trace()

print('Done')

68  dev/test_script.py (new file)
@@ -0,0 +1,68 @@
# Start this script with:
# mpirun -np 4 python test_script.py
import os

os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
import jax
import jax.lax as lax
import jax.numpy as jnp
import matplotlib.pylab as plt
import numpy as np
import tensorflow_probability as tfp
from jax.experimental.maps import mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit

tfp = tfp.substrates.jax
tfd = tfp.distributions


def cic_paint(mesh, positions):
    """ Paints positions onto mesh
    mesh: [nx, ny, nz]
    positions: [npart, 3]
    """
    positions = jnp.expand_dims(positions, 1)
    floor = jnp.floor(positions)
    connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
                             [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])

    neighboor_coords = floor + connection
    kernel = 1. - jnp.abs(positions - neighboor_coords)
    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]

    dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
                                            inserted_window_dims=(0, 1, 2),
                                            scatter_dims_to_operand_dims=(0, 1, 2))
    mesh = lax.scatter_add(
        mesh,
        neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
        kernel.reshape([-1, 8]), dnums)
    return mesh


# And let's draw some points from some 3D distribution
dist = tfd.MultivariateNormalDiag(loc=[16., 16., 16.],
                                  scale_identity_multiplier=3.)
pos = dist.sample(1e4, seed=jax.random.PRNGKey(0))

f = pjit(lambda x: cic_paint(x, pos),
         in_axis_resources=PartitionSpec('x', 'y', 'z'),
         out_axis_resources=None)

devices = np.array(jax.devices()).reshape((2, 2, 1))

# Let's import the mesh
m = jnp.zeros([32, 32, 32])

with mesh(devices, ('x', 'y', 'z')):
    # Shard the mesh, I'm not sure this is absolutely necessary
    m = pjit(lambda x: x,
             in_axis_resources=None,
             out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)

    # Apply the sharded CiC function
    res = f(m)

plt.imshow(res.sum(axis=2))
plt.show()

@@ -1,198 +0,0 @@
from typing import Any, Callable, Hashable

Specs = Any
AxisName = Hashable

from functools import partial

import jax
import jax.numpy as jnp
import jaxdecomp
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import AbstractMesh, Mesh
from jax.sharding import PartitionSpec as P


def autoshmap(
        f: Callable,
        gpu_mesh: Mesh | AbstractMesh | None,
        in_specs: Specs,
        out_specs: Specs,
        check_rep: bool = False,
        auto: frozenset[AxisName] = frozenset()) -> Callable:
    """Helper function to wrap the provided function in a shard map if
    the code is being executed in a mesh context."""
    if gpu_mesh is None or gpu_mesh.empty:
        return f
    else:
        return shard_map(f, gpu_mesh, in_specs, out_specs, check_rep, auto)


def fft3d(x):
    return jaxdecomp.pfft3d(x)


def ifft3d(x):
    return jaxdecomp.pifft3d(x).real


def get_halo_size(halo_size, sharding):
    gpu_mesh = sharding.mesh if sharding is not None else None
    if gpu_mesh is None or gpu_mesh.empty:
        zero_ext = (0, 0)
        zero_tuple = (0, 0)
        return (zero_tuple, zero_tuple, zero_tuple), zero_ext
    else:
        pdims = gpu_mesh.devices.shape
        halo_x = (0, 0) if pdims[0] == 1 else (halo_size, halo_size)
        halo_y = (0, 0) if pdims[1] == 1 else (halo_size, halo_size)

        halo_x_ext = 0 if pdims[0] == 1 else halo_size // 2
        halo_y_ext = 0 if pdims[1] == 1 else halo_size // 2
        return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext))


def halo_exchange(x, halo_extents, halo_periods=(True, True)):
    if (halo_extents[0] > 0 or halo_extents[1] > 0):
        return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
    else:
        return x


def slice_unpad_impl(x, pad_width):

    halo_x, _ = pad_width[0]
    halo_y, _ = pad_width[1]
    # Apply corrections along x
    x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2])
    x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:])
    # Apply corrections along y
    x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
    x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])

    unpad_slice = [slice(None)] * 3
    if halo_x > 0:
        unpad_slice[0] = slice(halo_x, -halo_x)
    if halo_y > 0:
        unpad_slice[1] = slice(halo_y, -halo_y)

    return x[tuple(unpad_slice)]


def slice_pad(x, pad_width, sharding):
    gpu_mesh = sharding.mesh if sharding is not None else None
    if gpu_mesh is not None and not (gpu_mesh.empty) and (
            pad_width[0][0] > 0 or pad_width[1][0] > 0):
        assert sharding is not None
        spec = sharding.spec
        return shard_map((partial(jnp.pad, pad_width=pad_width)),
                         mesh=gpu_mesh,
                         in_specs=spec,
                         out_specs=spec)(x)
    else:
        return x


def slice_unpad(x, pad_width, sharding):
    mesh = sharding.mesh if sharding is not None else None
    if mesh is not None and not (mesh.empty) and (pad_width[0][0] > 0
                                                  or pad_width[1][0] > 0):
        assert sharding is not None
        spec = sharding.spec
        return shard_map(partial(slice_unpad_impl, pad_width=pad_width),
                         mesh=mesh,
                         in_specs=spec,
                         out_specs=spec)(x)
    else:
        return x


def get_local_shape(mesh_shape, sharding=None):
    """ Helper function to get the local size of a mesh given the global size.
    """
    gpu_mesh = sharding.mesh if sharding is not None else None
    if gpu_mesh is None or gpu_mesh.empty:
        return mesh_shape
    else:
        pdims = gpu_mesh.devices.shape
        return [
            mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1],
            *mesh_shape[2:]
        ]


def _axis_names(spec):
    if len(spec) == 1:
        x_axis, = spec
        y_axis = None
        single_axis = True
    elif len(spec) == 2:
        x_axis, y_axis = spec
        if y_axis == None:
            single_axis = True
        elif x_axis == None:
            x_axis = y_axis
            single_axis = True
        else:
            single_axis = False
    else:
        raise ValueError("Only 1 or 2 axis sharding is supported")
    return x_axis, y_axis, single_axis


def uniform_particles(mesh_shape, sharding=None):

    gpu_mesh = sharding.mesh if sharding is not None else None
    if gpu_mesh is not None and not (gpu_mesh.empty):
        local_mesh_shape = get_local_shape(mesh_shape, sharding)
        spec = sharding.spec
        x_axis, y_axis, single_axis = _axis_names(spec)

        def particles():
            x_indx = lax.axis_index(x_axis)
            y_indx = 0 if single_axis else lax.axis_index(y_axis)

            x = jnp.arange(local_mesh_shape[0]) + x_indx * local_mesh_shape[0]
            y = jnp.arange(local_mesh_shape[1]) + y_indx * local_mesh_shape[1]
            z = jnp.arange(local_mesh_shape[2])
            return jnp.stack(jnp.meshgrid(x, y, z, indexing='ij'), axis=-1)

        return shard_map(particles, mesh=gpu_mesh, in_specs=(),
                         out_specs=spec)()
    else:
        return jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape],
                                      indexing='ij'),
                         axis=-1)


def normal_field(mesh_shape, seed, sharding=None):
    """Generate a Gaussian random field with the given power spectrum."""
    gpu_mesh = sharding.mesh if sharding is not None else None
    if gpu_mesh is not None and not (gpu_mesh.empty):
        local_mesh_shape = get_local_shape(mesh_shape, sharding)

        size = jax.device_count()
        # rank = jax.process_index()
        # process_index is multi_host only
        # to make the code work both in multi host and single controller we can do this trick
        keys = jax.random.split(seed, size)
        spec = sharding.spec
        x_axis, y_axis, single_axis = _axis_names(spec)

        def normal(keys, shape, dtype):
            idx = lax.axis_index(x_axis)
            if not single_axis:
                y_index = lax.axis_index(y_axis)
                x_size = lax.psum(1, axis_name=x_axis)
                idx += y_index * x_size

            return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)

        return shard_map(
            partial(normal, shape=local_mesh_shape, dtype='float32'),
            mesh=gpu_mesh,
            in_specs=P(None),
            out_specs=spec)(keys)  # yapf: disable
    else:
        return jax.random.normal(shape=mesh_shape, key=seed)

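A minimal usage sketch of the `autoshmap` helper above (not from the repo; the device-mesh layout and the toy function are placeholders). With no mesh the wrapped function is returned unchanged, otherwise it is dispatched through `shard_map`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

# Hypothetical 2D process mesh over the available devices (e.g. 8 GPUs -> 2 x 4)
devices = np.array(jax.devices()).reshape(2, -1)
gpu_mesh = Mesh(devices, axis_names=('x', 'y'))
spec = P('x', 'y', None)

# A non-empty mesh is provided, so the lambda is wrapped in a shard_map
shift = autoshmap(lambda a: a + 1.0, gpu_mesh, spec, spec)
out = shift(jnp.zeros((64, 64, 64)))

# With gpu_mesh=None the same call degrades gracefully to the unwrapped function
shift_single = autoshmap(lambda a: a + 1.0, None, spec, spec)
```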
@@ -1,6 +1,6 @@
import jax.numpy as np
from jax.numpy import interp
from jax_cosmo.background import *
from jax_cosmo.scipy.interpolate import interp
from jax_cosmo.scipy.ode import odeint


@@ -587,6 +587,5 @@ def dGf2a(cosmo, a):
    cache = cosmo._workspace['background.growth_factor']
    f2p = cache['h2'] / cache['a'] * cache['g2']
    f2p = interp(np.log(a), np.log(cache['a']), f2p)
    E_a = E(cosmo, a)
    return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) +
            3 * a**2 * E_a * D2f)
    E = E(cosmo, a)
    return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f)

@@ -1,46 +1,30 @@
import jax.numpy as jnp
import numpy as np
from jax.lax import FftType
from jax.sharding import PartitionSpec as P
from jaxdecomp import fftfreq3d, get_output_specs

from jaxpm.distributed import autoshmap


def fftk(k_array):
def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
    """
    Return wave-vectors for a given shape
    """
    Generate Fourier transform wave numbers for a given mesh.
    k = []
    for d in range(len(shape)):
        kd = np.fft.fftfreq(shape[d])
        kd *= 2 * np.pi
        kdshape = np.ones(len(shape), dtype='int')
        if symmetric and d == len(shape) - 1:
            kd = kd[:shape[d] // 2 + 1]
            kdshape[d] = len(kd)
        kd = kd.reshape(kdshape)

    Args:
        nc (int): Shape of the mesh grid.

    Returns:
        list: List of wave number arrays for each dimension in
        the order [kx, ky, kz].
    """
    kx, ky, kz = fftfreq3d(k_array)
    # to the order of dimensions in the transposed FFT
    return kx, ky, kz


def interpolate_power_spectrum(input, k, pk, sharding=None):

    pk_fn = lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape)

    gpu_mesh = sharding.mesh if sharding is not None else None
    specs = sharding.spec if sharding is not None else P()
    out_specs = P(*get_output_specs(
        FftType.FFT, specs, mesh=gpu_mesh)) if gpu_mesh is not None else P()

    return autoshmap(pk_fn,
                     gpu_mesh=gpu_mesh,
                     in_specs=out_specs,
                     out_specs=out_specs)(input)
        k.append(kd.astype(dtype))
        del kd, kdshape
    return k


def gradient_kernel(kvec, direction, order=1):
    """
    Computes the gradient kernel in the requested direction

    Parameters
    -----------
    kvec: list

@@ -66,30 +50,23 @@ def gradient_kernel(kvec, direction, order=1):
    return wts


def invlaplace_kernel(kvec, fd=False):
def invlaplace_kernel(kvec):
    """
    Compute the inverse Laplace kernel.

    cf. [Feng+2016](https://arxiv.org/pdf/1603.00476)
    Compute the inverse Laplace kernel

    Parameters
    -----------
    kvec: list
        List of wave-vectors
    fd: bool
        Finite difference kernel

    Returns
    --------
    wts: array
        Complex kernel values
    """
    if fd:
        kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
    else:
        kk = sum(ki**2 for ki in kvec)
    kk_nozeros = jnp.where(kk == 0, 1, kk)
    return -jnp.where(kk == 0, 0, 1 / kk_nozeros)
    kk = sum(ki**2 for ki in kvec)
    kk_nozeros = jnp.where(kk==0, 1, kk)
    return - jnp.where(kk==0, 0, 1 / kk_nozeros)


def longrange_kernel(kvec, r_split):

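For reference, a compact statement of what the two variants of `invlaplace_kernel` above compute (inferred from the code, not quoted from the repository's documentation):

$$
\mathcal{K}_{\nabla^{-2}}(\mathbf{k}) \;=\; -\frac{1}{k^2},
\qquad
k^2 \;=\;
\begin{cases}
\sum_i k_i^2 & \text{(spectral, \texttt{fd=False})}\\[4pt]
\sum_i \bigl[k_i\,\operatorname{sinc}(k_i/2\pi)\bigr]^2 \;=\; \sum_i 4\sin^2(k_i/2) & \text{(finite difference, \texttt{fd=True}, cf. Feng+2016)}
\end{cases}
$$

with the $\mathbf{k}=0$ mode explicitly set to zero; here $\operatorname{sinc}(x)=\sin(\pi x)/(\pi x)$, the NumPy convention used by `jnp.sinc`.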
@@ -102,10 +79,12 @@ def longrange_kernel(kvec, r_split):
        List of wave-vectors
    r_split: float
        Splitting radius

    Returns
    --------
    wts: array
        Complex kernel values

    TODO: @modichirag add documentation
    """
    if r_split != 0:

@@ -126,12 +105,13 @@ def cic_compensation(kvec):
    -----------
    kvec: list
        List of wave-vectors

    Returns:
    --------
    wts: array
        Complex kernel values
    """
    kwts = [jnp.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
    kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
    wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
    return wts

@@ -1,24 +1,15 @@
from functools import partial

import jax
import jax.lax as lax
import jax.numpy as jnp
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P

from jaxpm.distributed import (autoshmap, fft3d, get_halo_size, halo_exchange,
                               ifft3d, slice_pad, slice_unpad)
from jaxpm.kernels import cic_compensation, fftk
from jaxpm.painting_utils import gather, scatter


def _cic_paint_impl(grid_mesh, positions, weight=None):
def cic_paint(mesh, positions, weight=None):
    """ Paints positions onto mesh
    mesh: [nx, ny, nz]
    displacement field: [nx, ny, nz, 3]
    """

    positions = positions.reshape([-1, 3])
    mesh: [nx, ny, nz]
    positions: [npart, 3]
    """
    positions = jnp.expand_dims(positions, 1)
    floor = jnp.floor(positions)
    connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
@@ -28,106 +19,48 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
    kernel = 1. - jnp.abs(positions - neighboor_coords)
    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
    if weight is not None:
        if jnp.isscalar(weight):
            kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
        else:
            kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
                                  kernel)
        kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)

    neighboor_coords = jnp.mod(
        neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
        jnp.array(grid_mesh.shape))
        jnp.array(mesh.shape))

    dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
                                            inserted_window_dims=(0, 1, 2),
                                            scatter_dims_to_operand_dims=(0, 1, 2))
    mesh = lax.scatter_add(grid_mesh, neighboor_coords,
                           kernel.reshape([-1, 8]), dnums)
    mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 8]),
                           dnums)
    return mesh


@partial(jax.jit, static_argnums=(3, 4))
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):

    positions = positions.reshape((*grid_mesh.shape, 3))

    halo_size, halo_extents = get_halo_size(halo_size, sharding)
    grid_mesh = slice_pad(grid_mesh, halo_size, sharding)

    gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
    spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
    grid_mesh = autoshmap(_cic_paint_impl,
                          gpu_mesh=gpu_mesh,
                          in_specs=(spec, spec, P()),
                          out_specs=spec)(grid_mesh, positions, weight)
    grid_mesh = halo_exchange(grid_mesh,
                              halo_extents=halo_extents,
                              halo_periods=(True, True))
    grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)

    return grid_mesh


def _cic_read_impl(grid_mesh, positions):
def cic_read(mesh, positions):
    """ Paints positions onto mesh
    mesh: [nx, ny, nz]
    positions: [nx,ny,nz, 3]
    """
    # Save original shape for reshaping output later
    original_shape = positions.shape
    # Reshape positions to a flat list of 3D coordinates
    positions = positions.reshape([-1, 3])
    # Expand dimensions to calculate neighbor coordinates
    mesh: [nx, ny, nz]
    positions: [npart, 3]
    """
    positions = jnp.expand_dims(positions, 1)
    # Floor the positions to get the base grid cell for each particle
    floor = jnp.floor(positions)
    # Define connections to calculate all neighbor coordinates
    connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
                             [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
    # Calculate the 8 neighboring coordinates

    neighboor_coords = floor + connection
    # Calculate kernel weights based on distance from each neighboring coordinate
    kernel = 1. - jnp.abs(positions - neighboor_coords)
    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
    # Modulo operation to wrap around edges if necessary

    neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
                               jnp.array(grid_mesh.shape))
    # Ensure grid_mesh shape is as expected
    # Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel
    return (grid_mesh[neighboor_coords[..., 0],
                      neighboor_coords[..., 1],
                      neighboor_coords[..., 2]] * kernel).sum(axis=-1).reshape(original_shape[:-1])  # yapf: disable
                               jnp.array(mesh.shape))


@partial(jax.jit, static_argnums=(2, 3))
def cic_read(grid_mesh, positions, halo_size=0, sharding=None):

    original_shape = positions.shape
    positions = positions.reshape((*grid_mesh.shape, 3))

    halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
    grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
    grid_mesh = halo_exchange(grid_mesh,
                              halo_extents=halo_extents,
                              halo_periods=(True, True))
    gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
    spec = sharding.spec if isinstance(sharding, NamedSharding) else P()

    displacement = autoshmap(_cic_read_impl,
                             gpu_mesh=gpu_mesh,
                             in_specs=(spec, spec),
                             out_specs=spec)(grid_mesh, positions)

    return displacement.reshape(original_shape[:-1])
    return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
                 neighboor_coords[..., 3]] * kernel).sum(axis=-1)


def cic_paint_2d(mesh, positions, weight):
    """ Paints positions onto a 2d mesh
    mesh: [nx, ny]
    positions: [npart, 2]
    weight: [npart]
    """
    mesh: [nx, ny]
    positions: [npart, 2]
    weight: [npart]
    """
    positions = jnp.expand_dims(positions, 1)
    floor = jnp.floor(positions)
    connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
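For reference, the cloud-in-cell weights used by `_cic_paint_impl` and `_cic_read_impl` above can be summarized as follows (a restatement of the code, not text from the repository): each particle at position $\mathbf{x}$ (in mesh units) deposits onto, or reads from, its $2^3 = 8$ neighbouring cells with trilinear weights

$$
w(\mathbf{o}) \;=\; \prod_{d=1}^{3}\bigl(1 - \lvert x_d - (\lfloor x_d\rfloor + o_d)\rvert\bigr),
\qquad \mathbf{o} \in \{0,1\}^3 ,
$$

which are non-negative and sum to one over the 8 neighbours; the neighbour indices are wrapped with a modulo to enforce periodic boundaries.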
@@ -151,109 +84,17 @@ def cic_paint_2d(mesh, positions, weight):
    return mesh


def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):

    halo_x, _ = halo_size[0]
    halo_y, _ = halo_size[1]

    original_shape = displacements.shape
    particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
    if not jnp.isscalar(weight):
        if weight.shape != original_shape[:-1]:
            raise ValueError("Weight shape must match particle shape")
        else:
            weight = weight.flatten()
    # Padding is forced to be zero in a single gpu run

    a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
                           jnp.arange(particle_mesh.shape[1]),
                           jnp.arange(particle_mesh.shape[2]),
                           indexing='ij')

    particle_mesh = jnp.pad(particle_mesh, halo_size)
    pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
    return scatter(pmid.reshape([-1, 3]),
                   displacements.reshape([-1, 3]),
                   particle_mesh,
                   chunk_size=2**24,
                   val=weight)


@partial(jax.jit, static_argnums=(1, 2, 4))
def cic_paint_dx(displacements,
                 halo_size=0,
                 sharding=None,
                 weight=1.0,
                 chunk_size=2**24):

    halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)

    gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
    spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
    grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
                                  halo_size=halo_size,
                                  weight=weight,
                                  chunk_size=chunk_size),
                          gpu_mesh=gpu_mesh,
                          in_specs=spec,
                          out_specs=spec)(displacements)

    grid_mesh = halo_exchange(grid_mesh,
                              halo_extents=halo_extents,
                              halo_periods=(True, True))
    grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)
    return grid_mesh


def _cic_read_dx_impl(grid_mesh, disp, halo_size):

    halo_x, _ = halo_size[0]
    halo_y, _ = halo_size[1]

    original_shape = [
        dim - 2 * halo[0] for dim, halo in zip(grid_mesh.shape, halo_size)
    ]
    a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]),
                           jnp.arange(original_shape[1]),
                           jnp.arange(original_shape[2]),
                           indexing='ij')

    pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)

    pmid = pmid.reshape([-1, 3])
    disp = disp.reshape([-1, 3])

    return gather(pmid, disp, grid_mesh).reshape(original_shape)


@partial(jax.jit, static_argnums=(2, 3))
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):

    halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
    grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
    grid_mesh = halo_exchange(grid_mesh,
                              halo_extents=halo_extents,
                              halo_periods=(True, True))
    gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
    spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
    displacements = autoshmap(partial(_cic_read_dx_impl, halo_size=halo_size),
                              gpu_mesh=gpu_mesh,
                              in_specs=(spec),
                              out_specs=spec)(grid_mesh, disp)

    return displacements


def compensate_cic(field):
    """
    Compensate for CiC painting
    Args:
      field: input 3D cic-painted field
    Returns:
      compensated_field
    """
    delta_k = fft3d(field)
    Compensate for CiC painting
    Args:
      field: input 3D cic-painted field
    Returns:
      compensated_field
    """
    nc = field.shape
    kvec = fftk(nc)

    kvec = fftk(delta_k)
    delta_k = jnp.fft.rfftn(field)
    delta_k = cic_compensation(kvec) * delta_k
    return ifft3d(delta_k)
    return jnp.fft.irfftn(delta_k)

@@ -1,190 +0,0 @@
import jax
import jax.numpy as jnp
from jax.lax import scan


def _chunk_split(ptcl_num, chunk_size, *arrays):
    """Split and reshape particle arrays into chunks and remainders, with the remainders
    preceding the chunks. 0D ones are duplicated as full arrays in the chunks."""
    chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
    remainder_size = ptcl_num % chunk_size
    chunk_num = ptcl_num // chunk_size

    remainder = None
    chunks = arrays
    if remainder_size:
        remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays]
        chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays]

    # `scan` triggers errors in scatter and gather without the `full`
    chunks = [
        x.reshape(chunk_num, chunk_size, *x.shape[1:])
        if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks
    ]

    return remainder, chunks

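A tiny worked example (not in the repo) of the chunking scheme implemented by `_chunk_split` above, to make the remainder-first layout concrete:

```python
import jax.numpy as jnp

# Hypothetical standalone check: 10 particles, chunk_size 4
x = jnp.arange(10)
remainder, chunks = _chunk_split(10, 4, x)

# remainder[0] is the 10 % 4 = 2 leftover particles, processed before the scan:
#   [0, 1]
# chunks[0] has shape (2, 4): two chunks of four, consumed by lax.scan:
#   [[2, 3, 4, 5],
#    [6, 7, 8, 9]]
```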
def enmesh(base_indices, displacements, cell_size, base_shape, offset,
           new_cell_size, new_shape):
    """Multilinear enmeshing."""
    base_indices = jnp.asarray(base_indices)
    displacements = jnp.asarray(displacements)
    with jax.experimental.enable_x64():
        cell_size = jnp.float64(
            cell_size) if new_cell_size is not None else jnp.array(
                cell_size, dtype=displacements.dtype)
        if base_shape is not None:
            base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
        offset = jnp.float64(offset)
        if new_cell_size is not None:
            new_cell_size = jnp.float64(new_cell_size)
        if new_shape is not None:
            new_shape = jnp.array(new_shape, dtype=base_indices.dtype)

    spatial_dim = base_indices.shape[1]
    neighbor_offsets = (
        jnp.arange(2**spatial_dim, dtype=base_indices.dtype)[:, jnp.newaxis] >>
        jnp.arange(spatial_dim, dtype=base_indices.dtype)) & 1

    if new_cell_size is not None:
        particle_positions = base_indices * cell_size + displacements - offset
        particle_positions = particle_positions[:, jnp.
                                                newaxis]  # insert neighbor axis
        new_indices = particle_positions + neighbor_offsets * new_cell_size  # multilinear

        if base_shape is not None:
            grid_length = base_shape * cell_size
            new_indices %= grid_length

        new_indices //= new_cell_size
        new_displacements = particle_positions - new_indices * new_cell_size

        if base_shape is not None:
            new_displacements -= jnp.rint(
                new_displacements / grid_length
            ) * grid_length  # also abs(new_displacements) < new_cell_size is expected

        new_indices = new_indices.astype(base_indices.dtype)
        new_displacements = new_displacements.astype(displacements.dtype)
        new_cell_size = new_cell_size.astype(displacements.dtype)

        new_displacements /= new_cell_size
    else:
        offset_indices, offset_displacements = jnp.divmod(offset, cell_size)
        base_indices -= offset_indices.astype(base_indices.dtype)
        displacements -= offset_displacements.astype(displacements.dtype)

        # insert neighbor axis
        base_indices = base_indices[:, jnp.newaxis]
        displacements = displacements[:, jnp.newaxis]

        # multilinear
        displacements /= cell_size
        new_indices = jnp.floor(displacements).astype(base_indices.dtype)
        new_indices += neighbor_offsets
        new_displacements = displacements - new_indices
        new_indices += base_indices

        if base_shape is not None:
            new_indices %= base_shape

    weights = 1 - jnp.abs(new_displacements)

    if base_shape is None and new_shape is not None:  # all new_indices >= 0 if base_shape is not None
        new_indices = jnp.where(new_indices < 0, new_shape, new_indices)

    weights = weights.prod(axis=-1)

    return new_indices, weights


def _scatter_chunk(carry, chunk):
    mesh, offset, cell_size, mesh_shape = carry
    pmid, disp, val = chunk
    spatial_ndim = pmid.shape[1]
    spatial_shape = mesh.shape

    # multilinear mesh indices and fractions
    ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
                       spatial_shape)
    # scatter
    ind = tuple(ind[..., i] for i in range(spatial_ndim))
    mesh = mesh.at[ind].add(jnp.multiply(jnp.expand_dims(val, axis=-1), frac))
    carry = mesh, offset, cell_size, mesh_shape
    return carry, None


def scatter(pmid,
            disp,
            mesh,
            chunk_size=2**24,
            val=1.,
            offset=0,
            cell_size=1.):
    ptcl_num, spatial_ndim = pmid.shape
    val = jnp.asarray(val)
    mesh = jnp.asarray(mesh)
    remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
    carry = mesh, offset, cell_size, mesh.shape
    if remainder is not None:
        carry = _scatter_chunk(carry, remainder)[0]
    carry = scan(_scatter_chunk, carry, chunks)[0]
    mesh = carry[0]
    return mesh


def _chunk_cat(remainder_array, chunked_array):
    """Reshape and concatenate one remainder and one chunked particle arrays."""
    array = chunked_array.reshape(-1, *chunked_array.shape[2:])

    if remainder_array is not None:
        array = jnp.concatenate((remainder_array, array), axis=0)

    return array


def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
    ptcl_num, spatial_ndim = pmid.shape

    mesh = jnp.asarray(mesh)

    val = jnp.asarray(val)

    if mesh.shape[spatial_ndim:] != val.shape[1:]:
        raise ValueError('channel shape mismatch: '
                         f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}')

    remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)

    carry = mesh, offset, cell_size, mesh.shape
    val_0 = None
    if remainder is not None:
        val_0 = _gather_chunk(carry, remainder)[1]
    val = scan(_gather_chunk, carry, chunks)[1]

    val = _chunk_cat(val_0, val)

    return val


def _gather_chunk(carry, chunk):
    mesh, offset, cell_size, mesh_shape = carry
    pmid, disp, val = chunk

    spatial_ndim = pmid.shape[1]

    spatial_shape = mesh.shape[:spatial_ndim]
    chan_ndim = mesh.ndim - spatial_ndim
    chan_axis = tuple(range(-chan_ndim, 0))

    # multilinear mesh indices and fractions
    ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
                       spatial_shape)

    # gather
    ind = tuple(ind[..., i] for i in range(spatial_ndim))
    frac = jnp.expand_dims(frac, chan_axis)
    val += (mesh.at[ind].get(mode='drop', fill_value=0) * frac).sum(axis=1)

    return carry, val

@@ -1,129 +0,0 @@
import matplotlib.pyplot as plt
import numpy as np


def plot_fields(fields_dict, sum_over=None):
    """
    Plots sum projections of 3D fields along different axes,
    slicing only the first `sum_over` elements along each axis.

    Args:
    - fields: list of 3D arrays representing fields to plot
    - names: list of names for each field, used in titles
    - sum_over: number of slices to sum along each axis (default: fields[0].shape[0] // 8)
    """
    sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
    nb_rows = len(fields_dict)
    nb_cols = 3
    fig, axes = plt.subplots(nb_rows, nb_cols, figsize=(15, 5 * nb_rows))

    def plot_subplots(proj_axis, field, row, title):
        slicing = [slice(None)] * field.ndim
        slicing[proj_axis] = slice(None, sum_over)
        slicing = tuple(slicing)

        # Sum projection over the specified axis and plot
        axes[row, proj_axis].imshow(
            field[slicing].sum(axis=proj_axis) + 1,
            cmap='magma',
            extent=[0, field.shape[proj_axis], 0, field.shape[proj_axis]])
        axes[row, proj_axis].set_xlabel('Mpc/h')
        axes[row, proj_axis].set_ylabel('Mpc/h')
        axes[row, proj_axis].set_title(title)

    # Plot each field across the three axes
    for i, (name, field) in enumerate(fields_dict.items()):
        for proj_axis in range(3):
            plot_subplots(proj_axis, field, i,
                          f"{name} projection {proj_axis}")

    plt.tight_layout()
    plt.show()


def plot_fields_single_projection(fields_dict,
                                  sum_over=None,
                                  project_axis=0,
                                  vmin=None,
                                  vmax=None,
                                  colorbar=False):
    """
    Plots a single projection (along axis 0) of 3D fields in a grid,
    summing over the first `sum_over` elements along the 0-axis, with 4 images per row.

    Args:
    - fields_dict: dictionary where keys are field names and values are 3D arrays
    - sum_over: number of slices to sum along the projection axis (default: fields[0].shape[0] // 8)
    """
    sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
    nb_fields = len(fields_dict)
    nb_cols = 4  # Set number of images per row
    nb_rows = (nb_fields + nb_cols - 1) // nb_cols  # Calculate required rows

    fig, axes = plt.subplots(nb_rows,
                             nb_cols,
                             figsize=(5 * nb_cols, 5 * nb_rows))
    axes = np.atleast_2d(axes)  # Ensure axes is always a 2D array

    for i, (name, field) in enumerate(fields_dict.items()):
        row, col = divmod(i, nb_cols)

        # Define the slice for the 0-axis projection
        slicing = [slice(None)] * field.ndim
        slicing[project_axis] = slice(None, sum_over)
        slicing = tuple(slicing)

        # Sum projection over axis 0 and plot
        a = axes[row,
                 col].imshow(field[slicing].sum(axis=project_axis) + 1,
                             cmap='magma',
                             extent=[0, field.shape[1], 0, field.shape[2]],
                             vmin=vmin,
                             vmax=vmax)
        axes[row, col].set_xlabel('Mpc/h')
        axes[row, col].set_ylabel('Mpc/h')
        axes[row, col].set_title(f"{name} projection 0")
        if colorbar:
            fig.colorbar(a, ax=axes[row, col], shrink=0.7)

    # Remove any empty subplots
    for j in range(i + 1, nb_rows * nb_cols):
        fig.delaxes(axes.flatten()[j])

    plt.tight_layout()
    plt.show()


def stack_slices(array):
    """
    Stacks 2D slices of an array into a single array based on provided partition dimensions.

    Args:
    - array_slices: a 2D list of array slices (list of lists format) where
      array_slices[i][j] is the slice located at row i, column j in the grid.
    - pdims: a tuple representing the grid dimensions (rows, columns).

    Returns:
    - A single array constructed by stacking the slices.
    """
    # Initialize an empty list to store the vertically stacked rows
    pdims = array.sharding.mesh.devices.shape

    field_slices = []

    # Iterate over rows in pdims[0]
    for i in range(pdims[0]):
        row_slices = []

        # Iterate over columns in pdims[1]
        for j in range(pdims[1]):
            slice_index = i * pdims[0] + j
            row_slices.append(array.addressable_data(slice_index))
        # Stack the current row of slices vertically
        stacked_row = np.hstack(row_slices)
        field_slices.append(stacked_row)

    # Stack all rows horizontally to form the full array
    full_array = np.vstack(field_slices)

    return full_array

206  jaxpm/pm.py
@@ -1,92 +1,50 @@
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jax_cosmo import Cosmology

from jaxpm.distributed import fft3d, ifft3d, normal_field
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
                          growth_rate, growth_rate_second)
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
                           invlaplace_kernel, longrange_kernel)
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
from jaxpm.growth import growth_factor, growth_rate, dGfa, growth_factor_second, growth_rate_second, dGf2a
from jaxpm.kernels import PGD_kernel, fftk, gradient_kernel, invlaplace_kernel, longrange_kernel
from jaxpm.painting import cic_paint, cic_read


def pm_forces(positions,
              mesh_shape=None,
              delta=None,
              r_split=0,
              paint_absolute_pos=True,
              halo_size=0,
              sharding=None):

def pm_forces(positions, mesh_shape, delta=None, r_split=0):
    """
    Computes gravitational forces on particles using a PM scheme
    """
    if mesh_shape is None:
        assert (delta is not None),\
            "If mesh_shape is not provided, delta should be provided"
        mesh_shape = delta.shape

    if paint_absolute_pos:
        paint_fn = lambda pos: cic_paint(jnp.zeros(shape=mesh_shape,
                                                   device=sharding),
                                         pos,
                                         halo_size=halo_size,
                                         sharding=sharding)
        read_fn = lambda grid_mesh, pos: cic_read(
            grid_mesh, pos, halo_size=halo_size, sharding=sharding)
    else:
        paint_fn = lambda disp: cic_paint_dx(
            disp, halo_size=halo_size, sharding=sharding)
        read_fn = lambda grid_mesh, disp: cic_read_dx(
            grid_mesh, disp, halo_size=halo_size, sharding=sharding)

    if delta is None:
        field = paint_fn(positions)
        delta_k = fft3d(field)
        delta_k = jnp.fft.rfftn(cic_paint(jnp.zeros(mesh_shape), positions))
    elif jnp.isrealobj(delta):
        delta_k = fft3d(delta)
        delta_k = jnp.fft.rfftn(delta)
    else:
        delta_k = delta

    kvec = fftk(delta_k)
    # Computes gravitational potential
    pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(
        kvec, r_split=r_split)
    kvec = fftk(mesh_shape)
    pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split)
    # Computes gravitational forces
    forces = jnp.stack([
        read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k), positions
                ) for i in range(3)], axis=-1)  # yapf: disable

    return forces
    return jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i) * pot_k), positions)
                      for i in range(3)], axis=-1)


def lpt(cosmo,
        initial_conditions,
        particles=None,
        a=0.1,
        halo_size=0,
        sharding=None,
        order=1):
def lpt(cosmo: Cosmology, init_mesh, positions, a, order=1):
    """
    Computes first and second order LPT displacement and momentum,
    Computes first and second order LPT displacement and momentum,
    e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
    """
    paint_absolute_pos = particles is not None
    if particles is None:
        particles = jnp.zeros_like(initial_conditions,
                                   shape=(*initial_conditions.shape, 3))

    a = jnp.atleast_1d(a)
    E = jnp.sqrt(jc.background.Esqr(cosmo, a))
    delta_k = fft3d(initial_conditions)
    initial_force = pm_forces(particles,
                              delta=delta_k,
                              paint_absolute_pos=paint_absolute_pos,
                              halo_size=halo_size,
                              sharding=sharding)
    dx = growth_factor(cosmo, a) * initial_force
    E = jnp.sqrt(jc.background.Esqr(cosmo, a))
    delta_k = jnp.fft.rfftn(init_mesh)  # TODO: pass the modes directly to save one or two fft?
    mesh_shape = init_mesh.shape

    init_force = pm_forces(positions, mesh_shape, delta=delta_k)
    dx = growth_factor(cosmo, a) * init_force
    p = a**2 * growth_rate(cosmo, a) * E * dx
    f = a**2 * E * dGfa(cosmo, a) * initial_force
    f = a**2 * E * dGfa(cosmo, a) * init_force

    if order == 2:
        kvec = fftk(delta_k)
        kvec = fftk(mesh_shape)
        pot_k = delta_k * invlaplace_kernel(kvec)

        delta2 = 0
|
|||
# Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
|
||||
# shear_ii = jnp.fft.irfftn(- ki**2 * pot_k)
|
||||
nabla_i_nabla_i = gradient_kernel(kvec, i)**2
|
||||
shear_ii = ifft3d(nabla_i_nabla_i * pot_k)
|
||||
delta2 += shear_ii * shear_acc
|
||||
shear_ii = jnp.fft.irfftn(nabla_i_nabla_i * pot_k)
|
||||
delta2 += shear_ii * shear_acc
|
||||
shear_acc += shear_ii
|
||||
|
||||
# for kj in kvec[i+1:]:
|
||||
for j in range(i + 1, 3):
|
||||
for j in range(i+1, 3):
|
||||
# Substract squared strict-up-triangle terms
|
||||
# delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
|
||||
nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(
|
||||
kvec, j)
|
||||
delta2 -= ifft3d(nabla_i_nabla_j * pot_k)**2
|
||||
nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(kvec, j)
|
||||
delta2 -= jnp.fft.irfftn(nabla_i_nabla_j * pot_k)**2
|
||||
|
||||
delta_k2 = fft3d(delta2)
|
||||
init_force2 = pm_forces(particles,
|
||||
delta=delta_k2,
|
||||
paint_absolute_pos=paint_absolute_pos,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
init_force2 = pm_forces(positions, mesh_shape, delta=jnp.fft.rfftn(delta2))
|
||||
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
|
||||
dx2 = 3 / 7 * growth_factor_second(cosmo, a) * init_force2
|
||||
dx2 = 3/7 * growth_factor_second(cosmo, a) * init_force2
|
||||
p2 = a**2 * growth_rate_second(cosmo, a) * E * dx2
|
||||
f2 = a**2 * E * dGf2a(cosmo, a) * init_force2
|
||||
|
||||
dx += dx2
|
||||
p += p2
|
||||
f += f2
|
||||
p += p2
|
||||
f += f2
|
||||
|
||||
return dx, p, f
|
||||
|
||||
|
||||
def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
|
||||
def linear_field(mesh_shape, box_size, pk, seed):
|
||||
"""
|
||||
Generate initial conditions.
|
||||
"""
|
||||
# Initialize a random field with one slice on each gpu
|
||||
field = normal_field(mesh_shape, seed=seed, sharding=sharding)
|
||||
field = fft3d(field)
|
||||
kvec = fftk(field)
|
||||
kvec = fftk(mesh_shape)
|
||||
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
|
||||
for i, kk in enumerate(kvec))**0.5
|
||||
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
|
||||
box_size[0] * box_size[1] * box_size[2])
|
||||
|
||||
field = field * (pkmesh)**0.5
|
||||
field = ifft3d(field)
|
||||
field = jax.random.normal(seed, mesh_shape)
|
||||
field = jnp.fft.rfftn(field) * pkmesh**0.5
|
||||
field = jnp.fft.irfftn(field)
|
||||
return field
|
||||
|
||||
|
||||
def make_ode_fn(mesh_shape,
|
||||
paint_absolute_pos=True,
|
||||
halo_size=0,
|
||||
sharding=None):
|
||||
def make_ode_fn(mesh_shape):
|
||||
|
||||
def nbody_ode(state, a, cosmo):
|
||||
"""
|
||||
|
@ -155,11 +102,7 @@ def make_ode_fn(mesh_shape,
|
|||
"""
|
||||
pos, vel = state
|
||||
|
||||
forces = pm_forces(pos,
|
||||
mesh_shape=mesh_shape,
|
||||
paint_absolute_pos=paint_absolute_pos,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding) * 1.5 * cosmo.Omega_m
|
||||
forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m
|
||||
|
||||
# Computes the update of position (drift)
|
||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||
|
@ -171,28 +114,20 @@ def make_ode_fn(mesh_shape,
|
|||
|
||||
return nbody_ode
|
||||
|
||||
|
||||
def make_diffrax_ode(cosmo,
|
||||
mesh_shape,
|
||||
paint_absolute_pos=True,
|
||||
halo_size=0,
|
||||
sharding=None):
|
||||
def get_ode_fn(cosmo:Cosmology, mesh_shape):
|
||||
|
||||
def nbody_ode(a, state, args):
|
||||
"""
|
||||
state is a tuple (position, velocities)
|
||||
State is an array [position, velocities]
|
||||
|
||||
Compatible with [Diffrax API](https://docs.kidger.site/diffrax/)
|
||||
"""
|
||||
pos, vel = state
|
||||
|
||||
forces = pm_forces(pos,
|
||||
mesh_shape=mesh_shape,
|
||||
paint_absolute_pos=paint_absolute_pos,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding) * 1.5 * cosmo.Omega_m
|
||||
forces = pm_forces(pos, mesh_shape) * 1.5 * cosmo.Omega_m
|
||||
|
||||
# Computes the update of position (drift)
|
||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||
|
||||
|
||||
# Computes the update of velocity (kick)
|
||||
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||
|
||||
|
@@ -203,57 +138,51 @@ def make_diffrax_ode(cosmo,

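A minimal usage sketch of the Diffrax term defined above, condensed from the single-device tests later in this diff (not part of pm.py); `initial_conditions`, `particles`, `cosmo` and `mesh_shape` are assumed to be built as in those tests:

    import jax.numpy as jnp
    from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
    from jaxpm.painting import cic_paint

    dx, p, _ = lpt(cosmo, initial_conditions, particles, a=0.1, order=1)
    term = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
    sol = diffeqsolve(term, Dopri5(), t0=0.1, t1=1.0, dt0=None,
                      y0=jnp.stack([particles + dx, p]),
                      stepsize_controller=PIDController(rtol=1e-8, atol=1e-8),
                      saveat=SaveAt(t1=True))
    # Paint the final particle positions back onto the mesh
    final_field = cic_paint(jnp.zeros(mesh_shape), sol.ys[-1, 0])
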
def pgd_correction(pos, mesh_shape, params):
    """
    improve the short-range interactions of PM-Nbody simulations with potential gradient descent method,
    based on https://arxiv.org/abs/1804.00671

    args:
        pos: particle positions [npart, 3]
        params: [alpha, kl, ks] pgd parameters
    """
    kvec = fftk(mesh_shape)
    delta = cic_paint(jnp.zeros(mesh_shape), pos)
    delta_k = fft3d(delta)
    kvec = fftk(delta_k)
    alpha, kl, ks = params
    PGD_range = PGD_kernel(kvec, kl, ks)

    pot_k_pgd = (delta_k * invlaplace_kernel(kvec)) * PGD_range

    forces_pgd = jnp.stack([
        cic_read(fft3d(-gradient_kernel(kvec, i) * pot_k_pgd), pos)
        for i in range(3)
    ],
                           axis=-1)

    dpos_pgd = forces_pgd * alpha
    delta_k = jnp.fft.rfftn(delta)
    PGD_range = PGD_kernel(kvec, kl, ks)

    pot_k_pgd = (delta_k * invlaplace_kernel(kvec)) * PGD_range

    forces_pgd = jnp.stack([cic_read(jnp.fft.irfftn(-gradient_kernel(kvec, i) * pot_k_pgd), pos)
                            for i in range(3)], axis=-1)

    dpos_pgd = forces_pgd * alpha

    return dpos_pgd

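A short usage sketch for the PGD correction (illustrative only; the `[alpha, kl, ks]` values below are hypothetical placeholders, not fitted PGD parameters):

    import jax.numpy as jnp
    from jaxpm.painting import cic_paint

    params = jnp.array([0.3, 1.0, 10.0])   # hypothetical [alpha, kl, ks]
    positions = particles + dx             # absolute particle positions, as above
    dpos_pgd = pgd_correction(positions, mesh_shape, params)
    corrected_field = cic_paint(jnp.zeros(mesh_shape), positions + dpos_pgd)
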
def make_neural_ode_fn(model, mesh_shape):

    def neural_nbody_ode(state, a, cosmo: Cosmology, params):
    def neural_nbody_ode(state, a, cosmo:Cosmology, params):
        """
        state is a tuple (position, velocities)
        """
        pos, vel = state
        kvec = fftk(mesh_shape)

        delta = cic_paint(jnp.zeros(mesh_shape), pos)
        delta_k = fft3d(delta)
        kvec = fftk(delta_k)

        delta_k = jnp.fft.rfftn(delta)

        # Computes gravitational potential
        pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec,
                                                                     r_split=0)
        pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, r_split=0)

        # Apply a correction filter
        kk = jnp.sqrt(sum((ki / jnp.pi)**2 for ki in kvec))
        pot_k = pot_k * (1. + model.apply(params, kk, jnp.atleast_1d(a)))
        kk = jnp.sqrt(sum((ki/jnp.pi)**2 for ki in kvec))
        pot_k = pot_k * (1. + model.apply(params, kk, jnp.atleast_1d(a)))

        # Computes gravitational forces
        forces = jnp.stack([
            cic_read(fft3d(-gradient_kernel(kvec, i) * pot_k), pos)
            for i in range(3)
        ],
                           axis=-1)
        forces = jnp.stack([cic_read(jnp.fft.irfftn(-gradient_kernel(kvec, i) * pot_k), pos)
                            for i in range(3)], axis=-1)

        forces = forces * 1.5 * cosmo.Omega_m

@@ -264,5 +193,4 @@ def make_neural_ode_fn(model, mesh_shape):
        dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces

        return dpos, dvel

    return neural_nbody_ode

227 jaxpm/utils.py

@@ -1,168 +1,47 @@
from functools import partial

import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm
from scipy.special import legendre

__all__ = [
    'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
    'cross_correlation_coefficients', 'gaussian_smoothing'
]
__all__ = ['power_spectrum']


def _initialize_pk(mesh_shape, box_shape, kedges, los):
def _initialize_pk(shape, boxsize, kmin, dk):
    """
    Parameters
    ----------
    mesh_shape : tuple of int
        Shape of the mesh grid.
    box_shape : tuple of float
        Physical dimensions of the box.
    kedges : None, int, float, or list
        If None, set dk to twice the minimum.
        If int, specifies number of edges.
        If float, specifies dk.
    los : array_like
        Line-of-sight vector.

    Returns
    -------
    dig : ndarray
        Indices of the bins to which each value in input array belongs.
    kcount : ndarray
        Count of values in each bin.
    kedges : ndarray
        Edges of the bins.
    mumesh : ndarray
        Mu values for the mesh grid.
    Helper function to initialize various (fixed) values for powerspectra... not differentiable!
    """
    kmax = np.pi * np.min(mesh_shape / box_shape)  # = knyquist
    I = np.eye(len(shape), dtype='int') * -2 + 1

    if isinstance(kedges, None | int | float):
        if kedges is None:
            dk = 2 * np.pi / np.min(
                box_shape) * 2  # twice the minimum wavenumber
        if isinstance(kedges, int):
            dk = kmax / (kedges + 1)  # final number of bins will be kedges-1
        elif isinstance(kedges, float):
            dk = kedges
        kedges = np.arange(dk, kmax, dk) + dk / 2  # from dk/2 to kmax-dk/2
    W = np.empty(shape, dtype='f4')
    W[...] = 2.0
    W[..., 0] = 1.0
    W[..., -1] = 1.0

    kshapes = np.eye(len(mesh_shape), dtype=np.int32) * -2 + 1
    kvec = [(2 * np.pi * m / l) * np.fft.fftfreq(m).reshape(kshape)
            for m, l, kshape in zip(mesh_shape, box_shape, kshapes)]
    kmesh = sum(ki**2 for ki in kvec)**0.5
    kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
    kedges = np.arange(kmin, kmax, dk)

    dig = np.digitize(kmesh.reshape(-1), kedges)
    kcount = np.bincount(dig, minlength=len(kedges) + 1)
    k = [
        np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
        for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
    ]
    kmag = sum(ki**2 for ki in k)**0.5

    # Central value of each bin
    # kavg = (kedges[1:] + kedges[:-1]) / 2
    kavg = np.bincount(
        dig, weights=kmesh.reshape(-1), minlength=len(kedges) + 1) / kcount
    kavg = kavg[1:-1]
    xsum = np.zeros(len(kedges) + 1)
    Nsum = np.zeros(len(kedges) + 1)

    if los is None:
        mumesh = 1.
    else:
        mumesh = sum(ki * losi for ki, losi in zip(kvec, los))
        kmesh_nozeros = np.where(kmesh == 0, 1, kmesh)
        mumesh = np.where(kmesh == 0, 0, mumesh / kmesh_nozeros)
    dig = np.digitize(kmag.flat, kedges)

    return dig, kcount, kavg, mumesh
    xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
    Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
    return dig, Nsum, xsum, W, k, kedges

def power_spectrum(mesh,
                   mesh2=None,
                   box_shape=None,
                   kedges: int | float | list = None,
                   multipoles=0,
                   los=[0., 0., 1.]):
def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
    """
    Compute the auto and cross spectrum of 3D fields, with multipoles.
    """
    # Initialize
    mesh_shape = np.array(mesh.shape)
    if box_shape is None:
        box_shape = mesh_shape
    else:
        box_shape = np.asarray(box_shape)

    if multipoles == 0:
        los = None
    else:
        los = np.asarray(los)
        los = los / np.linalg.norm(los)
    poles = np.atleast_1d(multipoles)
    dig, kcount, kavg, mumesh = _initialize_pk(mesh_shape, box_shape, kedges,
                                               los)
    n_bins = len(kavg) + 2

    # FFTs
    meshk = jnp.fft.fftn(mesh, norm='ortho')
    if mesh2 is None:
        mmk = meshk.real**2 + meshk.imag**2
    else:
        mmk = meshk * jnp.fft.fftn(mesh2, norm='ortho').conj()

    # Sum powers
    pk = jnp.empty((len(poles), n_bins))
    for i_ell, ell in enumerate(poles):
        weights = (mmk * (2 * ell + 1) * legendre(ell)(mumesh)).reshape(-1)
        if mesh2 is None:
            psum = jnp.bincount(dig, weights=weights, length=n_bins)
        else:  # XXX: bincount is really slow with complex numbers
            psum_real = jnp.bincount(dig, weights=weights.real, length=n_bins)
            psum_imag = jnp.bincount(dig, weights=weights.imag, length=n_bins)
            psum = (psum_real**2 + psum_imag**2)**.5
        pk = pk.at[i_ell].set(psum)

    # Normalization and conversion from cell units to [Mpc/h]^3
    pk = (pk / kcount)[:, 1:-1] * (box_shape / mesh_shape).prod()

    # pk = jnp.concatenate([kavg[None], pk])
    if np.ndim(multipoles) == 0:
        return kavg, pk[0]
    else:
        return kavg, pk


def transfer(mesh0, mesh1, box_shape, kedges: int | float | list = None):
    pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
    ks, pk0 = pk_fn(mesh0)
    ks, pk1 = pk_fn(mesh1)
    return ks, (pk1 / pk0)**.5


def coherence(mesh0, mesh1, box_shape, kedges: int | float | list = None):
    pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
    ks, pk01 = pk_fn(mesh0, mesh1)
    ks, pk0 = pk_fn(mesh0)
    ks, pk1 = pk_fn(mesh1)
    return ks, pk01 / (pk0 * pk1)**.5


def pktranscoh(mesh0, mesh1, box_shape, kedges: int | float | list = None):
    pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
    ks, pk01 = pk_fn(mesh0, mesh1)
    ks, pk0 = pk_fn(mesh0)
    ks, pk1 = pk_fn(mesh1)
    return ks, pk0, pk1, (pk1 / pk0)**.5, pk01 / (pk0 * pk1)**.5

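A brief usage sketch of the binning utilities above (illustrative; `field_a` and `field_b` stand for any two real-valued 3D density meshes of the same shape, and the box size below is a placeholder):

    box = (256., 256., 256.)
    k, pk_a = power_spectrum(field_a, box_shape=box, kedges=50)          # auto spectrum
    k, pk_ab = power_spectrum(field_a, field_b, box_shape=box, kedges=50)  # cross spectrum
    k, tk = transfer(field_a, field_b, box_shape=box)                    # transfer function
    k, rk = coherence(field_a, field_b, box_shape=box)                   # coherence
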
|
||||
def cross_correlation_coefficients(field_a,
|
||||
field_b,
|
||||
kmin=5,
|
||||
dk=0.5,
|
||||
boxsize=False):
|
||||
"""
|
||||
Calculate the cross correlation coefficients given two real space field
|
||||
Calculate the powerspectra given real space field
|
||||
|
||||
Args:
|
||||
|
||||
field_a: real valued field
|
||||
field_b: real valued field
|
||||
field: real valued field
|
||||
kmin: minimum k-value for binned powerspectra
|
||||
dk: differential in each kbin
|
||||
boxsize: length of each boxlength (can be strangly shaped?)
|
||||
|
@ -170,21 +49,20 @@ def cross_correlation_coefficients(field_a,
|
|||
Returns:
|
||||
|
||||
kbins: the central value of the bins for plotting
|
||||
P / norm: normalized cross correlation coefficient between two field a and b
|
||||
power: real valued array of power in each bin
|
||||
|
||||
"""
|
||||
shape = field_a.shape
|
||||
shape = field.shape
|
||||
nx, ny, nz = shape
|
||||
|
||||
#initialze values related to powerspectra (mode bins and weights)
|
||||
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
|
||||
|
||||
#fast fourier transform
|
||||
fft_image_a = jnp.fft.fftn(field_a)
|
||||
fft_image_b = jnp.fft.fftn(field_b)
|
||||
fft_image = jnp.fft.fftn(field)
|
||||
|
||||
#absolute value of fast fourier transform
|
||||
pk = fft_image_a * jnp.conj(fft_image_b)
|
||||
pk = jnp.real(fft_image * jnp.conj(fft_image))
|
||||
|
||||
#calculating powerspectra
|
||||
real = jnp.real(pk).reshape([-1])
|
||||
|
@ -205,6 +83,55 @@ def cross_correlation_coefficients(field_a,
|
|||
return kbins, P / norm
|
||||
|
||||
|
||||
def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False):
|
||||
"""
|
||||
Calculate the cross correlation coefficients given two real space field
|
||||
|
||||
Args:
|
||||
|
||||
field_a: real valued field
|
||||
field_b: real valued field
|
||||
kmin: minimum k-value for binned powerspectra
|
||||
dk: differential in each kbin
|
||||
boxsize: length of each boxlength (can be strangly shaped?)
|
||||
|
||||
Returns:
|
||||
|
||||
kbins: the central value of the bins for plotting
|
||||
P / norm: normalized cross correlation coefficient between two field a and b
|
||||
|
||||
"""
|
||||
shape = field_a.shape
|
||||
nx, ny, nz = shape
|
||||
|
||||
#initialze values related to powerspectra (mode bins and weights)
|
||||
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
|
||||
|
||||
#fast fourier transform
|
||||
fft_image_a = jnp.fft.fftn(field_a)
|
||||
fft_image_b = jnp.fft.fftn(field_b)
|
||||
|
||||
#absolute value of fast fourier transform
|
||||
pk = fft_image_a * jnp.conj(fft_image_b)
|
||||
|
||||
#calculating powerspectra
|
||||
real = jnp.real(pk).reshape([-1])
|
||||
imag = jnp.imag(pk).reshape([-1])
|
||||
|
||||
Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j
|
||||
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
|
||||
|
||||
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
|
||||
|
||||
#normalization for powerspectra
|
||||
norm = np.prod(np.array(shape[:])).astype('float32')**2
|
||||
|
||||
#find central values of each bin
|
||||
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
|
||||
|
||||
return kbins, P / norm
|
||||
|
||||
|
||||
def gaussian_smoothing(im, sigma):
|
||||
"""
|
||||
im: 2d image
|
||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -1,179 +0,0 @@
|
|||
import os
|
||||
|
||||
os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
|
||||
import jax
|
||||
|
||||
jax.distributed.initialize()
|
||||
rank = jax.process_index()
|
||||
size = jax.process_count()
|
||||
if rank == 0:
|
||||
print(f"SIZE is {jax.device_count()}")
|
||||
|
||||
import argparse
|
||||
from functools import partial
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
import numpy as np
|
||||
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
||||
PIDController, SaveAt, diffeqsolve)
|
||||
from jax.experimental.mesh_utils import create_device_mesh
|
||||
from jax.experimental.multihost_utils import process_allgather
|
||||
from jax.sharding import Mesh, NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jaxpm.kernels import interpolate_power_spectrum
|
||||
from jaxpm.painting import cic_paint_dx
|
||||
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
|
||||
|
||||
all_gather = partial(process_allgather, tiled=True)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run a cosmological simulation with JAX.")
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--pdims",
|
||||
type=int,
|
||||
nargs=2,
|
||||
default=[1, jax.devices()],
|
||||
help="Processor grid dimensions as two integers (e.g., 2 4).")
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--mesh_shape",
|
||||
type=int,
|
||||
nargs=3,
|
||||
default=[512, 512, 512],
|
||||
help="Shape of the simulation mesh as three values (e.g., 512 512 512)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--box_size",
|
||||
type=float,
|
||||
nargs=3,
|
||||
default=[500.0, 500.0, 500.0],
|
||||
help=
|
||||
"Box size of the simulation as three values (e.g., 500.0 500.0 1000.0)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-st",
|
||||
"--snapshots",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Number of snapshots to save during the simulation.")
|
||||
parser.add_argument("-H",
|
||||
"--halo_size",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Halo size for the simulation.")
|
||||
parser.add_argument("-s",
|
||||
"--solver",
|
||||
type=str,
|
||||
choices=['leapfrog', 'dopri8'],
|
||||
default='leapfrog',
|
||||
help="ODE solver choice: 'leapfrog' or 'dopri8'.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def create_mesh_and_sharding(pdims):
|
||||
devices = create_device_mesh(pdims)
|
||||
mesh = Mesh(devices, axis_names=('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
return mesh, sharding
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
|
||||
def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
||||
solver_choice, nb_snapshots, sharding):
|
||||
k = jnp.logspace(-4, 1, 128)
|
||||
pk = jc.power.linear_matter_power(
|
||||
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
||||
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
|
||||
|
||||
initial_conditions = linear_field(mesh_shape,
|
||||
box_size,
|
||||
pk_fn,
|
||||
seed=jax.random.PRNGKey(0),
|
||||
sharding=sharding)
|
||||
|
||||
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
|
||||
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
a=0.1,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
|
||||
# Choose solver
|
||||
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
|
||||
stepsize_controller = ConstantStepSize(
|
||||
) if solver_choice == "leapfrog" else PIDController(rtol=1e-5, atol=1e-5)
|
||||
res = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
t0=0.1,
|
||||
t1=1.,
|
||||
dt0=0.01,
|
||||
y0=jnp.stack([dx, p], axis=0),
|
||||
args=cosmo,
|
||||
saveat=SaveAt(ts=jnp.linspace(0.2, 1., nb_snapshots)),
|
||||
stepsize_controller=stepsize_controller)
|
||||
|
||||
ode_fields = [
|
||||
cic_paint_dx(sol[0], halo_size=halo_size, sharding=sharding)
|
||||
for sol in res.ys
|
||||
]
|
||||
lpt_field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding)
|
||||
return initial_conditions, lpt_field, ode_fields, res.stats
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
mesh_shape = args.mesh_shape
|
||||
box_size = args.box_size
|
||||
halo_size = args.halo_size
|
||||
solver_choice = args.solver
|
||||
nb_snapshots = args.snapshots
|
||||
|
||||
sharding = create_mesh_and_sharding(args.pdims)
|
||||
|
||||
initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(
|
||||
0.25, 0.8, tuple(mesh_shape), tuple(box_size), halo_size,
|
||||
solver_choice, nb_snapshots, sharding)
|
||||
|
||||
if rank == 0:
|
||||
os.makedirs("fields", exist_ok=True)
|
||||
print(f"[{rank}] Simulation done")
|
||||
print(f"Solver stats: {solver_stats}")
|
||||
|
||||
# Save initial conditions
|
||||
initial_conditions_g = all_gather(initial_conditions)
|
||||
if rank == 0:
|
||||
print(f"[{rank}] Saving initial_conditions")
|
||||
np.save("fields/initial_conditions.npy", initial_conditions_g)
|
||||
print(f"[{rank}] initial_conditions saved")
|
||||
del initial_conditions_g, initial_conditions
|
||||
|
||||
# Save LPT displacements
|
||||
lpt_displacements_g = all_gather(lpt_displacements)
|
||||
if rank == 0:
|
||||
print(f"[{rank}] Saving lpt_displacements")
|
||||
np.save("fields/lpt_displacements.npy", lpt_displacements_g)
|
||||
print(f"[{rank}] lpt_displacements saved")
|
||||
del lpt_displacements_g, lpt_displacements
|
||||
|
||||
# Save each ODE solution separately
|
||||
for i, sol in enumerate(ode_solutions):
|
||||
sol_g = all_gather(sol)
|
||||
if rank == 0:
|
||||
print(f"[{rank}] Saving ode_solution_{i}")
|
||||
np.save(f"fields/ode_solution_{i}.npy", sol_g)
|
||||
print(f"[{rank}] ode_solution_{i} saved")
|
||||
del sol_g
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,320 +0,0 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# **Animating Particle Mesh density fields**\n",
|
||||
"\n",
|
||||
"In this tutorial, we will animate the density field of a particle mesh simulation. We will use the `manim` library to create the animation. \n",
|
||||
"\n",
|
||||
"The density fields are created exactly like in the notebook [**05-MultiHost_PM.ipynb**](05-MultiHost_PM.ipynb) using the same script [**05-MultiHost_PM.py**](05-MultiHost_PM.py)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To run a multi-host simulation, you first need to **allocate a job** with `salloc`. This command requests resources on an HPC cluster.\n",
|
||||
"\n",
|
||||
"just like in notebook [**05-MultiHost_PM.ipynb**]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!salloc --account=XXX@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:40:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=4 & "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**A few hours later**\n",
|
||||
"\n",
|
||||
"Use `!squeue -u $USER -o \"%i %D %b\"` to **check the JOB ID** and verify your resource allocation.\n",
|
||||
"\n",
|
||||
"In this example, we’ve been allocated **32 GPUs split across 4 nodes**.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!squeue -u $USER -o \"%i %D %b\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Unset the following environment variables, as they can cause issues when using JAX in a distributed setting:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"del os.environ['VSCODE_PROXY_URI']\n",
|
||||
"del os.environ['NO_PROXY']\n",
|
||||
"del os.environ['no_proxy']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Checking Available Compute Resources\n",
|
||||
"\n",
|
||||
"Run the following command to initialize JAX distributed computing and display the devices available for this job:\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!srun --jobid=467745 -n 32 python -c \"import jax; jax.distributed.initialize(); print(jax.devices()) if jax.process_index() == 0 else None\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Multi-Host Simulation Script with Arguments (reminder)\n",
|
||||
"\n",
|
||||
"This script is nearly identical to the single-host version, with the main addition being the call to `jax.distributed.initialize()` at the start, enabling multi-host parallelism. Here’s a breakdown of the key arguments:\n",
|
||||
"\n",
|
||||
"- **`--pdims`** (`-p`): Specifies processor grid dimensions as two integers, like `16 2` for 16 x 2 device mesh (default is `[1, jax.devices()]`).\n",
|
||||
"- **`--mesh_shape`** (`-m`): Defines the simulation mesh shape as three integers (default is `[512, 512, 512]`).\n",
|
||||
"- **`--box_size`** (`-b`): Sets the physical box size of the simulation as three floating-point values, e.g., `1000. 1000. 1000.` (default is `[500.0, 500.0, 500.0]`).\n",
|
||||
"- **`--halo_size`** (`-H`): Specifies the halo size for boundary overlap across nodes (default is `64`).\n",
|
||||
"- **`--solver`** (`-s`): Chooses the ODE solver (`leapfrog` or `dopri8`). The `leapfrog` solver uses a fixed step size, while `dopri8` is an adaptive Runge-Kutta solver with a PID controller (default is `leapfrog`).\n",
|
||||
"- **`--snapthots`** (`-st`) : Number of snapshots to save (warning, increases memory usage)\n",
|
||||
"\n",
|
||||
"### Running the Multi-Host Simulation Script\n",
|
||||
"\n",
|
||||
"To create a smooth animation, we need a series of closely spaced snapshots to capture the evolution of the density field over time. In this example, we set the number of snapshots to **10** to ensure smooth transitions in the animation.\n",
|
||||
"\n",
|
||||
"Using a larger number of GPUs helps process these snapshots efficiently, especially with a large simulation mesh or high-resolution data. This allows us to achieve both the desired snapshot frequency and the necessary simulation detail without excessive runtime.\n",
|
||||
"\n",
|
||||
"The command to run the multi-host simulation with these settings will look something like this:\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import subprocess\n",
|
||||
"\n",
|
||||
"# Define parameters as variables\n",
|
||||
"jobid = \"467745\"\n",
|
||||
"num_processes = 32\n",
|
||||
"script_name = \"05-MultiHost_PM.py\"\n",
|
||||
"mesh_shape = (1024, 1024, 1024)\n",
|
||||
"box_size = (1000., 1000., 1000.)\n",
|
||||
"halo_size = 128\n",
|
||||
"solver = \"leapfrog\"\n",
|
||||
"pdims = (16, 2)\n",
|
||||
"snapshots = 8\n",
|
||||
"\n",
|
||||
"# Build the command as a list, incorporating variables\n",
|
||||
"command = [\n",
|
||||
" \"srun\",\n",
|
||||
" f\"--jobid={jobid}\",\n",
|
||||
" \"-n\", str(num_processes),\n",
|
||||
" \"python\", script_name,\n",
|
||||
" \"--mesh_shape\", str(mesh_shape[0]), str(mesh_shape[1]), str(mesh_shape[2]),\n",
|
||||
" \"--box_size\", str(box_size[0]), str(box_size[1]), str(box_size[2]),\n",
|
||||
" \"--halo_size\", str(halo_size),\n",
|
||||
" \"-s\", solver,\n",
|
||||
" \"--pdims\", str(pdims[0]), str(pdims[1]),\n",
|
||||
" \"--snapshots\", str(snapshots)\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Execute the command as a subprocess\n",
|
||||
"subprocess.run(command)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Projecting the 3D Density Fields to 2D\n",
|
||||
"\n",
|
||||
"To visualize the 3D density fields in 2D, we need to create a projection:\n",
|
||||
"\n",
|
||||
"- **`project_to_2d` Function**: This function reduces the 3D array to 2D by summing over a portion of one axis.\n",
|
||||
" - We sum the top one-eighth of the data along the first axis to capture a slice of the density field.\n",
|
||||
"\n",
|
||||
"- **Creating 2D Projections**: Apply `project_to_2d` to each 3D field (`initial_conditions`, `lpt_displacements`, `ode_solution_0`, and `ode_solution_1`) to get 2D arrays that represent the density fields.\n",
|
||||
"\n",
|
||||
"### Applying the Magma Colormap\n",
|
||||
"\n",
|
||||
"To improve visualization, apply the \"magma\" colormap to each 2D projection:\n",
|
||||
"\n",
|
||||
"- **`apply_colormap` Function**: This function maps values in the 2D array to colors using the \"magma\" colormap.\n",
|
||||
" - First, normalize the array to the `[0, 1]` range.\n",
|
||||
" - Apply the colormap to create RGB images, which will be used for the animation.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from matplotlib import colormaps\n",
|
||||
"\n",
|
||||
"# Define a function to project the 3D field to 2D\n",
|
||||
"def project_to_2d(field):\n",
|
||||
" sum_over = field.shape[0] // 8\n",
|
||||
" slicing = [slice(None)] * field.ndim\n",
|
||||
" slicing[0] = slice(None, sum_over)\n",
|
||||
" slicing = tuple(slicing)\n",
|
||||
"\n",
|
||||
" return field[slicing].sum(axis=0)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def apply_colormap(array, cmap_name=\"magma\"):\n",
|
||||
" cmap = colormaps[cmap_name]\n",
|
||||
" normalized_array = (array - array.min()) / (array.max() - array.min())\n",
|
||||
" colored_image = cmap(normalized_array)[:, :, :3] # Drop alpha channel for RGB\n",
|
||||
" return (colored_image * 255).astype(np.uint8)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Loading and Visualizing Results\n",
|
||||
"\n",
|
||||
"After running the multi-host simulation, we load the saved results from disk:\n",
|
||||
"\n",
|
||||
"- **`initial_conditions.npy`**: Initial conditions for the simulation.\n",
|
||||
"- **`lpt_displacements.npy`**: Linear perturbation displacements.\n",
|
||||
"- **`ode_solution_*.npy`** : Solutions from the ODE solver at each snapshot.\n",
|
||||
"\n",
|
||||
"We will now project the fields to 2D maps and apply the color map\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"initial_conditions = apply_colormap(project_to_2d(np.load('fields/initial_conditions.npy')))\n",
|
||||
"lpt_displacements = apply_colormap(project_to_2d(np.load('fields/lpt_displacements.npy')))\n",
|
||||
"ode_solutions = []\n",
|
||||
"for i in range(8):\n",
|
||||
" ode_solutions.append(apply_colormap(project_to_2d(np.load(f'fields/ode_solution_{i}.npy'))))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Animating with Manim\n",
|
||||
"\n",
|
||||
"To create animations with `manim` in a Jupyter notebook, we start by configuring some settings to ensure the output displays correctly and without a background.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from manim import *\n",
|
||||
"config.media_width = \"100%\"\n",
|
||||
"config.verbosity = \"WARNING\"\n",
|
||||
"config.background_color = \"#00000000\" # Transparent background"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Defining the Animation in Manim\n",
|
||||
"\n",
|
||||
"This animation class, `FieldTransition`, smoothly transitions through the stages of the particle mesh density field evolution.\n",
|
||||
"\n",
|
||||
"- **Setup**: Each density field snapshot is loaded as an image and aligned for smooth transitions.\n",
|
||||
"- **Animation Sequence**:\n",
|
||||
" - The animation begins with a fade-in of the initial conditions.\n",
|
||||
" - It then transitions through the stages in sequence, showing each snapshot of the density field evolution with brief pauses in between.\n",
|
||||
"\n",
|
||||
"To run the animation, execute `%manim -v WARNING -qm FieldTransition` to render it in the Jupyter Notebook.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Define the animation in Manim\n",
|
||||
"class FieldTransition(Scene):\n",
|
||||
" def construct(self):\n",
|
||||
" init_conditions_img = ImageMobject(initial_conditions).scale(4)\n",
|
||||
" lpt_img = ImageMobject(lpt_displacements).scale(4)\n",
|
||||
" snapshots_imgs = [ImageMobject(sol).scale(4) for sol in ode_solutions]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" # Place the images on top of each other initially\n",
|
||||
" lpt_img.move_to(init_conditions_img)\n",
|
||||
" for img in snapshots_imgs:\n",
|
||||
" img.move_to(init_conditions_img)\n",
|
||||
"\n",
|
||||
" # Show initial field and then transform between fields\n",
|
||||
" self.play(FadeIn(init_conditions_img))\n",
|
||||
" self.wait(0.2)\n",
|
||||
" self.play(Transform(init_conditions_img, lpt_img))\n",
|
||||
" self.wait(0.2)\n",
|
||||
" self.play(Transform(lpt_img, snapshots_imgs[0]))\n",
|
||||
" self.wait(0.2)\n",
|
||||
" for img1, img2 in zip(snapshots_imgs, snapshots_imgs[1:]):\n",
|
||||
" self.play(Transform(img1, img2))\n",
|
||||
" self.wait(0.2)\n",
|
||||
"\n",
|
||||
"%manim -v WARNING -qm -o anim.gif --format=gif FieldTransition "
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@@ -1,39 +0,0 @@
# Particle Mesh Simulation with JAXPM on Multi-GPU and Multi-Host Systems

This collection of notebooks demonstrates how to perform Particle Mesh (PM) simulations using **JAXPM**, leveraging JAX for efficient computation on multi-GPU and multi-host systems. Each notebook progressively covers different setups, from single-GPU simulations to advanced, distributed, multi-host simulations across multiple nodes.

## Table of Contents

1. **[Single-GPU Particle Mesh Simulation](01-Introduction.ipynb)**
   - Introduction to basic PM simulations on a single GPU.
   - Uses JAXPM to run simulations with absolute particle positions and Cloud-in-Cell (CIC) painting.

2. **[Advanced Particle Mesh Simulation on a Single GPU](02-Advanced_usage.ipynb)**
   - Explores using Diffrax solvers in the ODE step.
   - Explores second-order Lagrangian Perturbation Theory (LPT) simulations.
   - Introduces weighted density field projections.

3. **[Multi-GPU Particle Mesh Simulation with Halo Exchange](03-MultiGPU_PM_Halo.ipynb)**
   - Extends PM simulation to multi-GPU setups with halo exchange.
   - Uses sharding and device mesh configurations to manage distributed data across GPUs.

4. **[Multi-GPU Particle Mesh Simulation with Advanced Solvers](04-MultiGPU_PM_Solvers.ipynb)**
   - Compares different ODE solvers (Leapfrog and Dopri5) in multi-GPU simulations.
   - Highlights performance, memory considerations, and solver impact on simulation quality.

5. **[Multi-Host Particle Mesh Simulation](05-MultiHost_PM.ipynb)**
   - Extends PM simulations to multi-host, multi-GPU setups for large-scale simulations.
   - Guides through job submission, device initialization, and retrieving results across nodes.

## Getting Started

Each notebook includes installation instructions and guidelines for configuring JAXPM and required dependencies. Follow the setup instructions in each notebook to ensure an optimal environment.

## Requirements

- **JAXPM** (included in the installation commands within notebooks)
- **Diffrax** for ODE solvers
- **JAX** with CUDA support for multi-GPU or TPU setups
- **SLURM** for job scheduling on clusters (if running multi-host setups)

> **Note**: These notebooks are tested on the **Jean Zay** supercomputer and may require configuration changes for different HPC clusters.

@@ -1,20 +0,0 @@
[build-system]
requires = ["setuptools", "wheel", "setuptools-scm"]
build-backend = "setuptools.build_meta"

[project]
name = "jaxpm"
dynamic = ["version"]
description = "A simple Particle-Mesh implementation in JAX"
authors = [{ name = "JaxPM developers" }]
readme = "README.md"
requires-python = ">=3.9"
license = { file = "LICENSE" }
urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/JaxPM" }
dependencies = ["jax_cosmo", "jax>=0.4.35", "jaxdecomp>=0.2.3"]

[tool.setuptools]
packages = ["jaxpm"]

[tool.setuptools_scm]
version_file = "jaxpm/_version.py"

@@ -1,4 +0,0 @@
[pytest]
markers =
    distributed: mark a test as distributed
    single_device: mark a test as single_device

@@ -1,5 +0,0 @@
pytest>=8.0.0
diffrax
pfft-python @ git+https://github.com/MP-Gadget/pfft-python
pmesh @ git+https://github.com/MP-Gadget/pmesh
fastpm @ git+https://github.com/ASKabalan/fastpm-python

11 setup.py Normal file

@@ -0,0 +1,11 @@
from setuptools import find_packages, setup

setup(
    name='JaxPM',
    version='0.0.1',
    url='https://github.com/DifferentiableUniverseInitiative/JaxPM',
    author='JaxPM developers',
    description='A dead simple FastPM implementation in JAX',
    packages=find_packages(),
    install_requires=['jax', 'jax_cosmo'],
)

@ -1,175 +0,0 @@
|
|||
# Parameterized fixture for mesh_shape
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ["EQX_ON_ERROR"] = "nan"
|
||||
setup_done = False
|
||||
on_cluster = False
|
||||
|
||||
|
||||
def is_on_cluster():
|
||||
global on_cluster
|
||||
return on_cluster
|
||||
|
||||
|
||||
def initialize_distributed():
|
||||
global setup_done
|
||||
global on_cluster
|
||||
if not setup_done:
|
||||
if "SLURM_JOB_ID" in os.environ:
|
||||
on_cluster = True
|
||||
print("Running on cluster")
|
||||
import jax
|
||||
jax.distributed.initialize()
|
||||
setup_done = True
|
||||
on_cluster = True
|
||||
else:
|
||||
print("Running locally")
|
||||
setup_done = True
|
||||
on_cluster = False
|
||||
os.environ["JAX_PLATFORM_NAME"] = "cpu"
|
||||
os.environ[
|
||||
"XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
|
||||
import jax
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
scope="session",
|
||||
params=[
|
||||
((32, 32, 32), (256., 256., 256.)), # BOX
|
||||
((32, 32, 64), (256., 256., 512.)), # RECTANGULAR
|
||||
])
|
||||
def simulation_config(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", params=[0.1, 0.5, 0.8])
|
||||
def lpt_scale_factor(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def cosmo():
|
||||
from functools import partial
|
||||
|
||||
from jax_cosmo import Cosmology
|
||||
Planck18 = partial(
|
||||
Cosmology,
|
||||
# Omega_m = 0.3111
|
||||
Omega_c=0.2607,
|
||||
Omega_b=0.0490,
|
||||
Omega_k=0.0,
|
||||
h=0.6766,
|
||||
n_s=0.9665,
|
||||
sigma8=0.8102,
|
||||
w0=-1.0,
|
||||
wa=0.0,
|
||||
)
|
||||
|
||||
return Planck18()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def particle_mesh(simulation_config):
|
||||
from pmesh.pm import ParticleMesh
|
||||
mesh_shape, box_shape = simulation_config
|
||||
return ParticleMesh(BoxSize=box_shape, Nmesh=mesh_shape, dtype='f4')
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def fpm_initial_conditions(cosmo, particle_mesh):
|
||||
import jax_cosmo as jc
|
||||
import numpy as np
|
||||
from jax import numpy as jnp
|
||||
|
||||
# Generate initial particle positions
|
||||
grid = particle_mesh.generate_uniform_particle_grid(shift=0).astype(
|
||||
np.float32)
|
||||
# Interpolate with linear_matter spectrum to get initial density field
|
||||
k = jnp.logspace(-4, 1, 128)
|
||||
pk = jc.power.linear_matter_power(cosmo, k)
|
||||
|
||||
def pk_fn(x):
|
||||
return jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)
|
||||
|
||||
whitec = particle_mesh.generate_whitenoise(42,
|
||||
type='complex',
|
||||
unitary=False)
|
||||
lineark = whitec.apply(lambda k, v: pk_fn(sum(ki**2 for ki in k)**0.5)**0.5
|
||||
* v * (1 / v.BoxSize).prod()**0.5)
|
||||
init_mesh = lineark.c2r().value # XXX
|
||||
|
||||
return lineark, grid, init_mesh
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def initial_conditions(fpm_initial_conditions):
|
||||
_, _, init_mesh = fpm_initial_conditions
|
||||
return init_mesh
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def solver(cosmo, particle_mesh):
|
||||
from fastpm.core import Cosmology as FastPMCosmology
|
||||
from fastpm.core import Solver
|
||||
ref_cosmo = FastPMCosmology(cosmo)
|
||||
return Solver(particle_mesh, ref_cosmo, B=1)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def fpm_lpt1(solver, fpm_initial_conditions, lpt_scale_factor):
|
||||
|
||||
lineark, grid, _ = fpm_initial_conditions
|
||||
statelpt = solver.lpt(lineark, grid, lpt_scale_factor, order=1)
|
||||
return statelpt
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def fpm_lpt1_field(fpm_lpt1, particle_mesh):
|
||||
return particle_mesh.paint(fpm_lpt1.X).value
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def fpm_lpt2(solver, fpm_initial_conditions, lpt_scale_factor):
|
||||
|
||||
lineark, grid, _ = fpm_initial_conditions
|
||||
statelpt = solver.lpt(lineark, grid, lpt_scale_factor, order=2)
|
||||
return statelpt
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def fpm_lpt2_field(fpm_lpt2, particle_mesh):
|
||||
return particle_mesh.paint(fpm_lpt2.X).value
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def nbody_from_lpt1(solver, fpm_lpt1, particle_mesh, lpt_scale_factor):
|
||||
import numpy as np
|
||||
from fastpm.core import leapfrog
|
||||
|
||||
if lpt_scale_factor == 0.8:
|
||||
pytest.skip("Do not run nbody simulation from scale factor 0.8")
|
||||
|
||||
stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)
|
||||
|
||||
finalstate = solver.nbody(fpm_lpt1, leapfrog(stages))
|
||||
fpm_mesh = particle_mesh.paint(finalstate.X).value
|
||||
|
||||
return fpm_mesh
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
|
||||
import numpy as np
|
||||
from fastpm.core import leapfrog
|
||||
|
||||
if lpt_scale_factor == 0.8:
|
||||
pytest.skip("Do not run nbody simulation from scale factor 0.8")
|
||||
|
||||
stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)
|
||||
|
||||
finalstate = solver.nbody(fpm_lpt2, leapfrog(stages))
|
||||
fpm_mesh = particle_mesh.paint(finalstate.X).value
|
||||
|
||||
return fpm_mesh
|
|
@@ -1,13 +0,0 @@
import jax.numpy as jnp


def MSE(x, y):
    return jnp.mean((x - y)**2)


def MSE_3D(x, y):
    return ((x - y)**2).mean(axis=0)


def MSRE(x, y):
    return jnp.mean(((x - y) / y)**2)

@ -1,155 +0,0 @@
|
|||
import pytest
|
||||
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
|
||||
from helpers import MSE, MSRE
|
||||
from jax import numpy as jnp
|
||||
|
||||
from jaxpm.distributed import uniform_particles
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx
|
||||
from jaxpm.pm import lpt, make_diffrax_ode
|
||||
from jaxpm.utils import power_spectrum
|
||||
|
||||
_TOLERANCE = 1e-4
|
||||
_PM_TOLERANCE = 1e-3
|
||||
|
||||
|
||||
@pytest.mark.single_device
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
|
||||
fpm_lpt1_field, fpm_lpt2_field, cosmo, order):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
cosmo._workspace = {}
|
||||
particles = uniform_particles(mesh_shape)
|
||||
|
||||
# Initial displacement
|
||||
dx, _, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
particles,
|
||||
a=lpt_scale_factor,
|
||||
order=order)
|
||||
|
||||
fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field
|
||||
|
||||
lpt_field = cic_paint(jnp.zeros(mesh_shape), particles + dx)
|
||||
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
|
||||
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
||||
|
||||
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
|
||||
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
|
||||
|
||||
|
||||
@pytest.mark.single_device
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
|
||||
fpm_lpt1_field, fpm_lpt2_field, cosmo, order):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
cosmo._workspace = {}
|
||||
# Initial displacement
|
||||
dx, _, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
|
||||
|
||||
lpt_field = cic_paint_dx(dx)
|
||||
|
||||
fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field
|
||||
|
||||
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
|
||||
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
||||
|
||||
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
|
||||
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
|
||||
|
||||
|
||||
@pytest.mark.single_device
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
def test_nbody_absolute(simulation_config, initial_conditions,
|
||||
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
|
||||
cosmo, order):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
cosmo._workspace = {}
|
||||
particles = uniform_particles(mesh_shape)
|
||||
|
||||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
particles,
|
||||
a=lpt_scale_factor,
|
||||
order=order)
|
||||
|
||||
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-8,
|
||||
atol=1e-8,
|
||||
pcoeff=0.4,
|
||||
icoeff=1,
|
||||
dcoeff=0)
|
||||
|
||||
saveat = SaveAt(t1=True)
|
||||
|
||||
y0 = jnp.stack([particles + dx, p])
|
||||
|
||||
solutions = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
t0=lpt_scale_factor,
|
||||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0])
|
||||
|
||||
fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2
|
||||
|
||||
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
|
||||
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
||||
|
||||
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
|
||||
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
|
||||
|
||||
|
||||
@pytest.mark.single_device
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
def test_nbody_relative(simulation_config, initial_conditions,
|
||||
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
|
||||
cosmo, order):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
cosmo._workspace = {}
|
||||
|
||||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-9,
|
||||
atol=1e-9,
|
||||
pcoeff=0.4,
|
||||
icoeff=1,
|
||||
dcoeff=0)
|
||||
|
||||
saveat = SaveAt(t1=True)
|
||||
|
||||
y0 = jnp.stack([dx, p])
|
||||
|
||||
solutions = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
t0=lpt_scale_factor,
|
||||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
final_field = cic_paint_dx(solutions.ys[-1, 0])
|
||||
|
||||
fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2
|
||||
|
||||
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
|
||||
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
||||
|
||||
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
|
||||
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
|
|
@ -1,152 +0,0 @@
|
|||
from conftest import initialize_distributed
|
||||
|
||||
initialize_distributed() # ignore : E402
|
||||
|
||||
import jax # noqa : E402
|
||||
import jax.numpy as jnp # noqa : E402
|
||||
import pytest # noqa : E402
|
||||
from diffrax import SaveAt # noqa : E402
|
||||
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
|
||||
from helpers import MSE # noqa : E402
|
||||
from jax import lax # noqa : E402
|
||||
from jax.experimental.multihost_utils import process_allgather # noqa : E402
|
||||
from jax.sharding import NamedSharding
|
||||
from jax.sharding import PartitionSpec as P # noqa : E402
|
||||
|
||||
from jaxpm.distributed import uniform_particles # noqa : E402
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
|
||||
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
|
||||
|
||||
_TOLERANCE = 3.0 # 🙃🙃
|
||||
|
||||
|
||||
@pytest.mark.distributed
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
@pytest.mark.parametrize("absolute_painting", [True, False])
|
||||
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
||||
absolute_painting):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
# SINGLE DEVICE RUN
|
||||
cosmo._workspace = {}
|
||||
if absolute_painting:
|
||||
particles = uniform_particles(mesh_shape)
|
||||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
particles,
|
||||
a=0.1,
|
||||
order=order)
|
||||
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
|
||||
y0 = jnp.stack([particles + dx, p])
|
||||
else:
|
||||
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
y0 = jnp.stack([dx, p])
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-8,
|
||||
atol=1e-8,
|
||||
pcoeff=0.4,
|
||||
icoeff=1,
|
||||
dcoeff=0)
|
||||
|
||||
saveat = SaveAt(t1=True)
|
||||
|
||||
solutions = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
t0=0.1,
|
||||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
if absolute_painting:
|
||||
single_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
|
||||
solutions.ys[-1, 0])
|
||||
else:
|
||||
single_device_final_field = cic_paint_dx(solutions.ys[-1, 0])
|
||||
|
||||
print("Done with single device run")
|
||||
# MULTI DEVICE RUN
|
||||
|
||||
mesh = jax.make_mesh((1, 8), ('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
halo_size = mesh_shape[0] // 2
|
||||
|
||||
initial_conditions = lax.with_sharding_constraint(initial_conditions,
|
||||
sharding)
|
||||
|
||||
print(f"sharded initial conditions {initial_conditions.sharding}")
|
||||
|
||||
cosmo._workspace = {}
|
||||
if absolute_painting:
|
||||
particles = uniform_particles(mesh_shape, sharding=sharding)
|
||||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
particles,
|
||||
a=0.1,
|
||||
order=order,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo,
|
||||
mesh_shape,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
|
||||
y0 = jnp.stack([particles + dx, p])
|
||||
else:
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
a=0.1,
|
||||
order=order,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo,
|
||||
mesh_shape,
|
||||
paint_absolute_pos=False,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
y0 = jnp.stack([dx, p])
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-8,
|
||||
atol=1e-8,
|
||||
pcoeff=0.4,
|
||||
icoeff=1,
|
||||
dcoeff=0)
|
||||
|
||||
saveat = SaveAt(t1=True)
|
||||
|
||||
solutions = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
t0=0.1,
|
||||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
if absolute_painting:
|
||||
multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
|
||||
solutions.ys[-1, 0],
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
else:
|
||||
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
multi_device_final_field = process_allgather(multi_device_final_field,
|
||||
tiled=True)
|
||||
|
||||
mse = MSE(single_device_final_field, multi_device_final_field)
|
||||
print(f"MSE is {mse}")
|
||||
|
||||
assert mse < _TOLERANCE
|
|
@ -1,87 +0,0 @@
|
|||
import jax
|
||||
import pytest
|
||||
from diffrax import (BacksolveAdjoint, Dopri5, ODETerm, PIDController,
|
||||
RecursiveCheckpointAdjoint, SaveAt, diffeqsolve)
|
||||
from helpers import MSE
|
||||
from jax import numpy as jnp
|
||||
|
||||
from jaxpm.distributed import uniform_particles
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx
|
||||
from jaxpm.pm import lpt, make_diffrax_ode
|
||||
|
||||
|
||||
@pytest.mark.single_device
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
@pytest.mark.parametrize("absolute_painting", [True, False])
|
||||
@pytest.mark.parametrize("adjoint", ['DTO', 'OTD'])
|
||||
def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
|
||||
nbody_from_lpt1, nbody_from_lpt2, cosmo, order,
|
||||
absolute_painting, adjoint):
|
||||
|
||||
mesh_shape, _ = simulation_config
|
||||
cosmo._workspace = {}
|
||||
|
||||
if adjoint == 'OTD':
|
||||
pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)")
|
||||
|
||||
adjoint = RecursiveCheckpointAdjoint(
|
||||
) if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5())
|
||||
|
||||
@jax.jit
|
||||
@jax.grad
|
||||
def forward_model(initial_conditions, cosmo):
|
||||
|
||||
# Initial displacement
|
||||
if absolute_painting:
|
||||
particles = uniform_particles(mesh_shape)
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
particles,
|
||||
a=lpt_scale_factor,
|
||||
order=order)
|
||||
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
|
||||
y0 = jnp.stack([particles + dx, p])
|
||||
|
||||
else:
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
a=lpt_scale_factor,
|
||||
order=order)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
y0 = jnp.stack([dx, p])
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-7,
|
||||
atol=1e-7,
|
||||
pcoeff=0.4,
|
||||
icoeff=1,
|
||||
dcoeff=0)
|
||||
|
||||
saveat = SaveAt(t1=True)
|
||||
|
||||
solutions = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
t0=lpt_scale_factor,
|
||||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
adjoint=adjoint,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
if absolute_painting:
|
||||
final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0])
|
||||
else:
|
||||
final_field = cic_paint_dx(solutions.ys[-1, 0])
|
||||
|
||||
return MSE(final_field,
|
||||
nbody_from_lpt1 if order == 1 else nbody_from_lpt2)
|
||||
|
||||
bad_initial_conditions = initial_conditions + jax.random.normal(
|
||||
jax.random.PRNGKey(0), initial_conditions.shape) * 0.5
|
||||
best_ic = forward_model(initial_conditions, cosmo)
|
||||
bad_ic = forward_model(bad_initial_conditions, cosmo)
|
||||
|
||||
assert jnp.max(best_ic) < 1e-5
|
||||
assert jnp.max(bad_ic) > 1e-5
|