This commit is contained in:
Wassim KABALAN 2025-06-28 00:57:23 +00:00 committed by GitHub
commit 81dcf3ef10
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 732 additions and 344 deletions

View file

@ -7,15 +7,37 @@ on:
branches: [ "main" ]
jobs:
build:
formatting:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
- name: Checkout Source
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Cache pip dependencies
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-formatting-pip-${{ hashFiles('.pre-commit-config.yaml') }}
restore-keys: |
${{ runner.os }}-formatting-pip-
- name: Cache pre-commit
uses: actions/cache@v4
with:
path: ~/.cache/pre-commit
key: ${{ runner.os }}-pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
restore-keys: |
${{ runner.os }}-pre-commit-
- name: Install dependencies
run: |
python -m pip install --upgrade pip isort
python -m pip install pre-commit
python -m pip install --upgrade pip
python -m pip install pre-commit isort
- name: Run pre-commit
run: python -m pre_commit run --all-files

View file

@ -10,37 +10,63 @@ on:
jobs:
run_tests:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10" , "3.11" , "3.12"]
python-version: ["3.10", "3.11", "3.12"]
steps:
- name: Checkout Source
uses: actions/checkout@v2.3.1
uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
- name: Cache pip dependencies
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements-test.txt', '**/pyproject.toml') }}
restore-keys: |
${{ runner.os }}-pip-${{ matrix.python-version }}-
${{ runner.os }}-pip-
- name: Cache system dependencies
uses: actions/cache@v4
with:
path: /var/cache/apt
key: ${{ runner.os }}-apt-${{ hashFiles('.github/workflows/tests.yml') }}
restore-keys: |
${{ runner.os }}-apt-
- name: Install system dependencies
run: |
sudo apt-get update
sudo apt-get install -y libopenmpi-dev
python -m pip install --upgrade pip
pip install jax==0.4.35
pip install numpy setuptools cython wheel
pip install git+https://github.com/MP-Gadget/pfft-python
pip install git+https://github.com/MP-Gadget/pmesh
pip install git+https://github.com/ASKabalan/fastpm-python --no-build-isolation
pip install -r requirements-test.txt
pip install .
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip setuptools wheel
# Install JAX first as it's a key dependency
pip install jax
# Install build dependencies
pip install setuptools cython mpi4py
# Install test requirements with no-build-isolation for faster builds
pip install -r requirements-test.txt --no-build-isolation
# Install additional test dependencies
pip install pytest diffrax
# Install package in development mode
pip install -e .
echo "numpy version installed:"
python -c "import numpy; print(numpy.__version__)"
- name: Run Single Device Tests
run: |
cd tests
pytest -v -m "not distributed"
- name: Run Distributed tests
run: |
pytest -v -m distributed
pytest -v tests/test_distributed_pm.py

View file

@ -166,11 +166,11 @@ def uniform_particles(mesh_shape, sharding=None):
axis=-1)
def normal_field(mesh_shape, seed, sharding=None):
def normal_field(seed, shape, sharding=None, dtype=float):
"""Generate a Gaussian random field with the given power spectrum."""
gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is not None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(mesh_shape, sharding)
local_mesh_shape = get_local_shape(shape, sharding)
size = jax.device_count()
# rank = jax.process_index()
@ -190,9 +190,9 @@ def normal_field(mesh_shape, seed, sharding=None):
return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)
return shard_map(
partial(normal, shape=local_mesh_shape, dtype='float32'),
partial(normal, shape=local_mesh_shape, dtype=dtype),
mesh=gpu_mesh,
in_specs=P(None),
out_specs=spec)(keys) # yapf: disable
else:
return jax.random.normal(shape=mesh_shape, key=seed)
return jax.random.normal(shape=shape, key=seed, dtype=dtype)

View file

@ -1,3 +1,5 @@
import os
import jax.numpy as np
from jax.numpy import interp
from jax_cosmo.background import *
@ -119,7 +121,7 @@ def growth_factor(cosmo, a):
if cosmo._flags["gamma_growth"]:
return _growth_factor_gamma(cosmo, a)
else:
return _growth_factor_ODE(cosmo, a)
return _growth_factor_ODE(cosmo, a)[0]
def growth_factor_second(cosmo, a):
@ -225,7 +227,7 @@ def growth_rate_second(cosmo, a):
return _growth_rate_second_ODE(cosmo, a)
def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=256, eps=1e-4):
"""Compute linear growth factor D(a) at a given scale factor,
normalised such that D(a=1) = 1.
@ -243,7 +245,11 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
Growth factor computed at requested scale factor
"""
# Check if growth has already been computed
if not "background.growth_factor" in cosmo._workspace.keys():
CACHING_ACTIVATED = os.environ.get("JC_CACHE", "1") == "1"
if CACHING_ACTIVATED and "background.growth_factor" in cosmo._workspace.keys(
):
cache = cosmo._workspace["background.growth_factor"]
else:
# Compute tabulated array
atab = np.logspace(log10_amin, 0.0, steps)
@ -290,10 +296,10 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
"f2": f2tab,
"h2": h2tab,
}
cosmo._workspace["background.growth_factor"] = cache
else:
cache = cosmo._workspace["background.growth_factor"]
return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0)
if CACHING_ACTIVATED:
cosmo._workspace["background.growth_factor"] = cache
return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0), cache
def _growth_rate_ODE(cosmo, a):
@ -314,9 +320,8 @@ def _growth_rate_ODE(cosmo, a):
Growth rate computed at requested scale factor
"""
# Check if growth has already been computed, if not, compute it
if not "background.growth_factor" in cosmo._workspace.keys():
_growth_factor_ODE(cosmo, np.atleast_1d(1.0))
cache = cosmo._workspace["background.growth_factor"]
cache = _growth_factor_ODE(cosmo, np.atleast_1d(1.0))[1]
return interp(a, cache["a"], cache["f"])
@ -338,36 +343,12 @@ def _growth_factor_second_ODE(cosmo, a):
Second order growth factor computed at requested scale factor
"""
# Check if growth has already been computed, if not, compute it
if not "background.growth_factor" in cosmo._workspace.keys():
_growth_factor_ODE(cosmo, np.atleast_1d(1.0))
cache = cosmo._workspace["background.growth_factor"]
#if not "background.growth_factor" in cosmo._workspace.keys():
# _growth_factor_ODE(cosmo, np.atleast_1d(1.0))
cache = _growth_factor_ODE(cosmo, a)[1]
return interp(a, cache["a"], cache["g2"])
def _growth_rate_ODE(cosmo, a):
"""Compute growth rate dD/dlna at a given scale factor by solving the linear
growth ODE.
Parameters
----------
cosmo: `Cosmology`
Cosmology object
a: array_like
Scale factor
Returns
-------
f: ndarray, or float if input scalar
Second order growth rate computed at requested scale factor
"""
# Check if growth has already been computed, if not, compute it
if not "background.growth_factor" in cosmo._workspace.keys():
_growth_factor_ODE(cosmo, np.atleast_1d(1.0))
cache = cosmo._workspace["background.growth_factor"]
return interp(a, cache["a"], cache["f"])
def _growth_rate_second_ODE(cosmo, a):
"""Compute second order growth rate dD2/dlna at a given scale factor by solving the linear
growth ODE.
@ -386,9 +367,9 @@ def _growth_rate_second_ODE(cosmo, a):
Second order growth rate computed at requested scale factor
"""
# Check if growth has already been computed, if not, compute it
if not "background.growth_factor" in cosmo._workspace.keys():
_growth_factor_ODE(cosmo, np.atleast_1d(1.0))
cache = cosmo._workspace["background.growth_factor"]
#if not "background.growth_factor" in cosmo._workspace.keys():
# _growth_factor_ODE(cosmo, np.atleast_1d(1.0))
cache = _growth_factor_ODE(cosmo, a)[1]
return interp(a, cache["a"], cache["f2"])
@ -411,7 +392,11 @@ def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128):
"""
# Check if growth has already been computed, if not, compute it
if not "background.growth_factor" in cosmo._workspace.keys():
CACHING_ACTIVATED = os.environ.get("JC_CACHE", "1") == "1"
if CACHING_ACTIVATED and "background.growth_factor" in cosmo._workspace.keys(
):
cache = cosmo._workspace["background.growth_factor"]
else:
# Compute tabulated array
atab = np.logspace(log10_amin, 0.0, steps)
@ -422,9 +407,8 @@ def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128):
gtab = np.exp(odeint(integrand, np.log(atab[0]), np.log(atab)))
gtab = gtab / gtab[-1] # Normalize to a=1.
cache = {"a": atab, "g": gtab}
cosmo._workspace["background.growth_factor"] = cache
else:
cache = cosmo._workspace["background.growth_factor"]
if CACHING_ACTIVATED:
cosmo._workspace["background.growth_factor"] = cache
return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0)
@ -521,6 +505,35 @@ def Gf2(cosmo, a):
return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
def gp(cosmo, a):
r""" Derivative of D1 against a
Parameters
----------
cosmo: dict
Cosmology dictionary.
a : array_like
Scale factor.
Returns
-------
Scalar float Tensor : the derivative of D1 against a.
Notes
-----
The expression for :math:`gp(a)` is:
.. math::
gp(a)=\frac{dD1}{da}= D'_{1norm}/a
"""
f1 = growth_rate(cosmo, a)
g1 = growth_factor(cosmo, a)
D1f = f1 * g1 / a
return D1f
def dGfa(cosmo, a):
r""" Derivative of Gf against a
@ -549,7 +562,8 @@ def dGfa(cosmo, a):
f1 = growth_rate(cosmo, a)
g1 = growth_factor(cosmo, a)
D1f = f1 * g1 / a
cache = cosmo._workspace['background.growth_factor']
#cache = cosmo._workspace['background.growth_factor']
cache = _growth_factor_ODE(cosmo, a)[1]
f1p = cache['h'] / cache['a'] * cache['g']
f1p = interp(np.log(a), np.log(cache['a']), f1p)
Ea = E(cosmo, a)
@ -584,7 +598,8 @@ def dGf2a(cosmo, a):
f2 = growth_rate_second(cosmo, a)
g2 = growth_factor_second(cosmo, a)
D2f = f2 * g2 / a
cache = cosmo._workspace['background.growth_factor']
#cache = cosmo._workspace['background.growth_factor']
cache = _growth_factor_ODE(cosmo, a)[1]
f2p = cache['h2'] / cache['a'] * cache['g2']
f2p = interp(np.log(a), np.log(cache['a']), f2p)
E_a = E(cosmo, a)

View file

@ -12,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk
from jaxpm.painting_utils import gather, scatter
def _cic_paint_impl(grid_mesh, positions, weight=None):
def _cic_paint_impl(grid_mesh, positions, weight=1.):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
displacement field: [nx, ny, nz, 3]
@ -27,12 +27,10 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
if weight is not None:
if jnp.isscalar(weight):
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
else:
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
kernel)
if jnp.isscalar(weight):
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
else:
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), kernel)
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
@ -48,7 +46,13 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
@partial(jax.jit, static_argnums=(3, 4))
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
def cic_paint(grid_mesh, positions, weight=1., halo_size=0, sharding=None):
if sharding is not None:
print("""
WARNING : absolute painting is not recommended in multi-device mode.
Please use relative painting instead.
""")
positions = positions.reshape((*grid_mesh.shape, 3))
@ -57,9 +61,11 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
weight_spec = P() if jnp.isscalar(weight) else spec
grid_mesh = autoshmap(_cic_paint_impl,
gpu_mesh=gpu_mesh,
in_specs=(spec, spec, P()),
in_specs=(spec, spec, weight_spec),
out_specs=spec)(grid_mesh, positions, weight)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,
@ -128,6 +134,7 @@ def cic_paint_2d(mesh, positions, weight):
positions: [npart, 2]
weight: [npart]
"""
positions = positions.reshape([-1, 2])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
@ -136,7 +143,7 @@ def cic_paint_2d(mesh, positions, weight):
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1]
if weight is not None:
kernel = kernel * weight[..., jnp.newaxis]
kernel = kernel * weight.reshape(*positions.shape[:-1])
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
@ -151,13 +158,16 @@ def cic_paint_2d(mesh, positions, weight):
return mesh
def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
def _cic_paint_dx_impl(displacements,
weight=1.,
halo_size=0,
chunk_size=2**24):
halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1]
original_shape = displacements.shape
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
particle_mesh = jnp.zeros(original_shape[:-1], dtype=displacements.dtype)
if not jnp.isscalar(weight):
if weight.shape != original_shape[:-1]:
raise ValueError("Weight shape must match particle shape")
@ -175,7 +185,7 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
return scatter(pmid.reshape([-1, 3]),
displacements.reshape([-1, 3]),
particle_mesh,
chunk_size=2**24,
chunk_size=chunk_size,
val=weight)
@ -190,13 +200,13 @@ def cic_paint_dx(displacements,
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
weight_spec = P() if jnp.isscalar(weight) else spec
grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
halo_size=halo_size,
weight=weight,
chunk_size=chunk_size),
gpu_mesh=gpu_mesh,
in_specs=spec,
out_specs=spec)(displacements)
in_specs=(spec, weight_spec),
out_specs=spec)(displacements, weight)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,
@ -230,6 +240,12 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
# Halo size is halved for the read operation
# We only need to read the density field
# while in the painting operation we need to exchange and reduce the halo
# We chose to do that since it is much easier to write a custom jvp rule for exchange
# while it is a bit harder if there is a reduction involved
halo_size = jax.tree.map(lambda x: x // 2, halo_size)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,

View file

@ -30,17 +30,14 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
"""Multilinear enmeshing."""
base_indices = jnp.asarray(base_indices)
displacements = jnp.asarray(displacements)
with jax.experimental.enable_x64():
cell_size = jnp.float64(
cell_size) if new_cell_size is not None else jnp.array(
cell_size, dtype=displacements.dtype)
if base_shape is not None:
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
offset = jnp.float64(offset)
if new_cell_size is not None:
new_cell_size = jnp.float64(new_cell_size)
if new_shape is not None:
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
cell_size = jnp.array(cell_size, dtype=displacements.dtype)
if base_shape is not None:
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
offset = offset.astype(base_indices.dtype)
if new_cell_size is not None:
new_cell_size = jnp.array(new_cell_size, dtype=displacements.dtype)
if new_shape is not None:
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
spatial_dim = base_indices.shape[1]
neighbor_offsets = (

View file

@ -131,7 +131,7 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
Generate initial conditions.
"""
# Initialize a random field with one slice on each gpu
field = normal_field(mesh_shape, seed=seed, sharding=sharding)
field = normal_field(seed=seed, shape=mesh_shape, sharding=sharding)
field = fft3d(field)
kvec = fftk(field)
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
@ -172,8 +172,7 @@ def make_ode_fn(mesh_shape,
return nbody_ode
def make_diffrax_ode(cosmo,
mesh_shape,
def make_diffrax_ode(mesh_shape,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
@ -183,6 +182,7 @@ def make_diffrax_ode(cosmo,
state is a tuple (position, velocities)
"""
pos, vel = state
cosmo = args
forces = pm_forces(pos,
mesh_shape=mesh_shape,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -62,7 +62,7 @@
"\n",
"This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n",
"\n",
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. `create_device_mesh(pdims)` initializes this layout across available GPUs.\n",
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n",
"- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n",
"- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n",
"\n",
@ -71,7 +71,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -80,11 +80,10 @@
"from jax.sharding import Mesh, NamedSharding\n",
"from jax.sharding import PartitionSpec as P\n",
"\n",
"all_gather = partial(process_allgather, tiled=False)\n",
"all_gather = partial(process_allgather, tiled=True)\n",
"\n",
"pdims = (2, 4)\n",
"devices = create_device_mesh(pdims)\n",
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
"mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
"sharding = NamedSharding(mesh, P('x', 'y'))"
]
},
@ -124,7 +123,7 @@
"\n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
" solver = LeapfrogMidpoint()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
@ -288,7 +287,7 @@
"\n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
" solver = Dopri5()\n",
"\n",
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",

View file

@ -17,9 +17,8 @@ import jax_cosmo as jc
import numpy as np
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
PIDController, SaveAt, diffeqsolve)
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
from jaxpm.kernels import interpolate_power_spectrum
@ -78,7 +77,7 @@ def parse_arguments():
def create_mesh_and_sharding(pdims):
devices = create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('x', 'y'))
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
return mesh, sharding
@ -106,7 +105,10 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
make_diffrax_ode(mesh_shape,
paint_absolute_pos=False,
sharding=sharding,
halo_size=halo_size))
# Choose solver
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()

View file

@ -37,3 +37,50 @@ Each notebook includes installation instructions and guidelines for configuring
- **SLURM** for job scheduling on clusters (if running multi-host setups)
> **Note**: These notebooks are tested on the **Jean Zay** supercomputer and may require configuration changes for different HPC clusters.
## Caveats
### Cloud-in-Cell (CIC) Painting (Single Device)
There is two ways to perform the CIC painting in JAXPM. The first one is to use the `cic_paint` which paints absolute particle positions to the mesh. The second one is to use the `cic_paint_dx` which paints relative particle positions to the mesh (using uniform particles). The absolute version is faster at the cost of more memory usage.
inorder to use relative painting you need to :
- Set the `particles` argument in `lpt` function from `jaxpm.pm` to `None`
- Set `paint_absolute_pos` to `False` in `make_ode_fn` or `make_diffrax_ode` function from `jaxpm.pm` (it is True by default)
Otherwise you set `particles` to the starting particles of your choice and leave `paint_absolute_pos` to `True` (default value).
### Cloud-in-Cell (CIC) Painting (Multi Device)
Both `cic_paint` and `cic_paint_dx` functions are available in multi-device mode.
You need to set the arguments `sharding` and `halo_size` which is explained in the notebook [03-MultiGPU_PM_Halo.ipynb](03-MultiGPU_PM_Halo.ipynb).
One thing to note that `cic_paint` is not as accurate as `cic_paint_dx` in multi-device mode and therefor is not recommended.
Using relative painting in multi-device mode is just like in single device mode.\
You need to set the `particles` argument in `lpt` function from `jaxpm.pm` to `None` and set `paint_absolute_pos` to `False`
### Distributed PM
To run a distributed PM follow the examples in notebooks [03](03-MultiGPU_PM_Halo.ipynb) and [05](05-MultiHost_PM.ipynb) for multi-host.
In short you need to set the arguments `sharding` and `halo_size` in `lpt` , `linear_field` the `make_ode` functions and `pm_forces` if you use it.
Missmatching the shardings will give you errors and unexpected results.
You can also use `normal_field` and `uniform_particles` from `jaxpm.pm.distributed` to create the fields and particles with a sharding.
### Choosing the right pdims
pdims are processor dimensions.\
Explained more in the jaxdecomp paper [here](https://github.com/DifferentiableUniverseInitiative/jaxDecomp).
For 8 devices there are three decompositions that are possible:
- (1 , 8)
- (2 , 4) , (4 , 2)
- (8 , 1)
(1 , X) should be the fastest (2 , X) or (X , 2) is more accurate but slightly slower.\
and (X , 1) is giving the least accurate results for some reason so it is not recommended.

View file

@ -1,5 +1,5 @@
pytest>=8.0.0
diffrax
pfft-python @ git+https://github.com/MP-Gadget/pfft-python
pmesh @ git+https://github.com/MP-Gadget/pmesh
fastpm @ git+https://github.com/ASKabalan/fastpm-python
numpy==2.2.6
diffrax

View file

@ -44,7 +44,7 @@ def simulation_config(request):
return request.param
@pytest.fixture(scope="session", params=[0.1, 0.5, 0.8])
@pytest.fixture(scope="session", params=[0.1, 0.2])
def lpt_scale_factor(request):
return request.param
@ -151,7 +151,7 @@ def nbody_from_lpt1(solver, fpm_lpt1, particle_mesh, lpt_scale_factor):
if lpt_scale_factor == 0.8:
pytest.skip("Do not run nbody simulation from scale factor 0.8")
stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)
stages = np.linspace(lpt_scale_factor, 1.0, 100, endpoint=True)
finalstate = solver.nbody(fpm_lpt1, leapfrog(stages))
fpm_mesh = particle_mesh.paint(finalstate.X).value
@ -167,7 +167,7 @@ def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
if lpt_scale_factor == 0.8:
pytest.skip("Do not run nbody simulation from scale factor 0.8")
stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)
stages = np.linspace(lpt_scale_factor, 1.0, 100, endpoint=True)
finalstate = solver.nbody(fpm_lpt2, leapfrog(stages))
fpm_mesh = particle_mesh.paint(finalstate.X).value

View file

@ -2,6 +2,7 @@ import pytest
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
from helpers import MSE, MSRE
from jax import numpy as jnp
from numpy.testing import assert_allclose
from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_dx
@ -10,6 +11,8 @@ from jaxpm.utils import power_spectrum
_TOLERANCE = 1e-4
_PM_TOLERANCE = 1e-3
_FIELD_RTOL = 1e-4
_FIELD_ATOL = 1e-3
@pytest.mark.single_device
@ -34,7 +37,10 @@ def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
assert_allclose(lpt_field,
fpm_ref_field,
rtol=_FIELD_RTOL,
atol=_FIELD_ATOL)
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
@ -55,7 +61,10 @@ def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
assert_allclose(lpt_field,
fpm_ref_field,
rtol=_FIELD_RTOL,
atol=_FIELD_ATOL)
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
@ -76,7 +85,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
solver = Dopri5()
controller = PIDController(rtol=1e-8,
@ -95,6 +104,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
@ -105,7 +115,10 @@ def test_nbody_absolute(simulation_config, initial_conditions,
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
assert_allclose(final_field,
fpm_ref_field,
rtol=_FIELD_RTOL,
atol=_FIELD_ATOL)
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
@ -121,8 +134,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
# Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
solver = Dopri5()
controller = PIDController(rtol=1e-9,
@ -141,6 +153,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
@ -151,5 +164,8 @@ def test_nbody_relative(simulation_config, initial_conditions,
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
assert_allclose(final_field,
fpm_ref_field,
rtol=_FIELD_RTOL,
atol=_FIELD_ATOL)
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE

View file

@ -2,8 +2,11 @@ from conftest import initialize_distributed
initialize_distributed() # ignore : E402
from functools import partial # noqa : E402
import jax # noqa : E402
import jax.numpy as jnp # noqa : E402
import jax_cosmo as jc # noqa : E402
import pytest # noqa : E402
from diffrax import SaveAt # noqa : E402
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
@ -12,19 +15,31 @@ from jax import lax # noqa : E402
from jax.experimental.multihost_utils import process_allgather # noqa : E402
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P # noqa : E402
from jaxdecomp import get_fft_output_sharding
from jaxpm.distributed import uniform_particles # noqa : E402
from jaxpm.distributed import fft3d, ifft3d
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode, pm_forces # noqa : E402
_TOLERANCE = 3.0 # 🙃🙃
_TOLERANCE = 1e-6 # 🎉🎉🎉
pdims = [(1, 8), (8, 1), (4, 2), (2, 4)]
@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("pdims", pdims)
@pytest.mark.parametrize("absolute_painting", [True, False])
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
absolute_painting):
pdims, absolute_painting):
if absolute_painting:
pytest.skip("Absolute painting is not recommended in distributed mode")
painting_str = "absolute" if absolute_painting else "relative"
print("=" * 50)
print(f"Running with {painting_str} painting and pdims {pdims} ...")
mesh_shape, box_shape = simulation_config
# SINGLE DEVICE RUN
@ -37,12 +52,12 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
particles,
a=0.1,
order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
y0 = jnp.stack([particles + dx, p])
else:
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
ode_fn = ODETerm(make_diffrax_ode(mesh_shape,
paint_absolute_pos=False))
y0 = jnp.stack([dx, p])
solver = Dopri5()
@ -60,6 +75,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
@ -72,7 +88,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
print("Done with single device run")
# MULTI DEVICE RUN
mesh = jax.make_mesh((1, 8), ('x', 'y'))
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
@ -94,8 +110,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(cosmo,
mesh_shape,
make_diffrax_ode(mesh_shape,
halo_size=halo_size,
sharding=sharding))
@ -108,8 +123,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(cosmo,
mesh_shape,
make_diffrax_ode(mesh_shape,
paint_absolute_pos=False,
halo_size=halo_size,
sharding=sharding))
@ -130,16 +144,23 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
final_field = solutions.ys[-1, 0]
print(f"Final field sharding is {final_field.sharding}")
assert final_field.sharding.is_equivalent_to(sharding , ndim=3) \
, f"Final field sharding is not correct .. should be {sharding} it is instead {final_field.sharding}"
if absolute_painting:
multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
solutions.ys[-1, 0],
final_field,
halo_size=halo_size,
sharding=sharding)
else:
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
multi_device_final_field = cic_paint_dx(final_field,
halo_size=halo_size,
sharding=sharding)
@ -150,3 +171,230 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
print(f"MSE is {mse}")
assert mse < _TOLERANCE
@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("pdims", pdims)
def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
order, nbody_from_lpt1, nbody_from_lpt2, pdims):
mesh_shape, box_shape = simulation_config
# SINGLE DEVICE RUN
cosmo._workspace = {}
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
initial_conditions = lax.with_sharding_constraint(initial_conditions,
sharding)
print(f"sharded initial conditions {initial_conditions.sharding}")
cosmo._workspace = {}
@jax.jit
def forward_model(initial_conditions, cosmo):
dx, p, _ = lpt(cosmo,
initial_conditions,
a=0.1,
order=order,
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(mesh_shape,
paint_absolute_pos=False,
halo_size=halo_size,
sharding=sharding))
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
solver = Dopri5()
controller = PIDController(rtol=1e-8,
atol=1e-8,
pcoeff=0.4,
icoeff=1,
dcoeff=0)
saveat = SaveAt(t1=True)
solutions = diffeqsolve(ode_fn,
solver,
t0=0.1,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
halo_size=halo_size,
sharding=sharding)
return multi_device_final_field
@jax.jit
def model(initial_conditions, cosmo):
final_field = forward_model(initial_conditions, cosmo)
return MSE(final_field,
nbody_from_lpt1 if order == 1 else nbody_from_lpt2)
obs_val = model(initial_conditions, cosmo)
shifted_initial_conditions = initial_conditions + jax.random.normal(
jax.random.key(42), initial_conditions.shape) * 5
good_grads = jax.grad(model)(initial_conditions, cosmo)
off_grads = jax.grad(model)(shifted_initial_conditions, cosmo)
assert good_grads.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
assert off_grads.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
@pytest.mark.distributed
@pytest.mark.parametrize("pdims", pdims)
def test_fwd_rev_gradients(cosmo, pdims):
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
cosmo._workspace = {}
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
initial_conditions = jax.random.normal(jax.random.PRNGKey(42), mesh_shape)
initial_conditions = lax.with_sharding_constraint(initial_conditions,
sharding)
print(f"sharded initial conditions {initial_conditions.sharding}")
cosmo._workspace = {}
@partial(jax.jit, static_argnums=(2, 3, 4))
def compute_forces(initial_conditions,
cosmo,
a=0.5,
halo_size=0,
sharding=None):
paint_absolute_pos = False
particles = jnp.zeros_like(initial_conditions,
shape=(*initial_conditions.shape, 3))
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
initial_conditions = jax.lax.with_sharding_constraint(
initial_conditions, sharding)
delta_k = fft3d(initial_conditions)
out_sharding = get_fft_output_sharding(sharding)
delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
initial_force = pm_forces(particles,
delta=delta_k,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding)
return initial_force[..., 0]
forces = compute_forces(initial_conditions,
cosmo,
halo_size=halo_size,
sharding=sharding)
back_gradient = jax.jacrev(compute_forces)(initial_conditions,
cosmo,
halo_size=halo_size,
sharding=sharding)
fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions,
cosmo,
halo_size=halo_size,
sharding=sharding)
print(f"Forces sharding is {forces.sharding}")
print(f"Backward gradient sharding is {back_gradient.sharding}")
print(f"Forward gradient sharding is {fwd_gradient.sharding}")
assert forces.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
assert back_gradient[0, 0, 0, ...].sharding.is_equivalent_to(
initial_conditions.sharding, ndim=3)
assert fwd_gradient.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
@pytest.mark.distributed
@pytest.mark.parametrize("pdims", pdims)
def test_vmap(cosmo, pdims):
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
cosmo._workspace = {}
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
single_dev_initial_conditions = jax.random.normal(jax.random.PRNGKey(42),
mesh_shape)
initial_conditions = lax.with_sharding_constraint(
single_dev_initial_conditions, sharding)
single_ics = jnp.stack([
single_dev_initial_conditions, single_dev_initial_conditions,
single_dev_initial_conditions
])
sharded_ics = jnp.stack(
[initial_conditions, initial_conditions, initial_conditions])
print(f"unsharded initial conditions batch {single_ics.sharding}")
print(f"sharded initial conditions batch {sharded_ics.sharding}")
cosmo._workspace = {}
@partial(jax.jit, static_argnums=(2, 3, 4))
def compute_forces(initial_conditions,
cosmo,
a=0.5,
halo_size=0,
sharding=None):
paint_absolute_pos = False
particles = jnp.zeros_like(initial_conditions,
shape=(*initial_conditions.shape, 3))
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
initial_conditions = jax.lax.with_sharding_constraint(
initial_conditions, sharding)
delta_k = fft3d(initial_conditions)
out_sharding = get_fft_output_sharding(sharding)
delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
initial_force = pm_forces(particles,
delta=delta_k,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding)
return initial_force[..., 0]
def fn(ic):
return compute_forces(ic,
cosmo,
halo_size=halo_size,
sharding=sharding)
v_compute_forces = jax.vmap(fn)
print(f"single_ics shape {single_ics.shape}")
print(f"sharded_ics shape {sharded_ics.shape}")
single_dev_forces = v_compute_forces(single_ics)
sharded_forces = v_compute_forces(sharded_ics)
assert single_dev_forces.ndim == 4
assert sharded_forces.ndim == 4
print(f"Sharded forces {sharded_forces.sharding}")
assert sharded_forces[0].sharding.is_equivalent_to(
initial_conditions.sharding, ndim=3)
assert sharded_forces.sharding.spec[0] == None

View file

@ -39,7 +39,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
particles,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
y0 = jnp.stack([particles + dx, p])
else:
@ -48,7 +48,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
y0 = jnp.stack([dx, p])
solver = Dopri5()
@ -66,6 +66,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
adjoint=adjoint,
stepsize_controller=controller,
saveat=saveat)