This commit is contained in:
Wassim KABALAN 2025-06-28 00:57:23 +00:00 committed by GitHub
commit 81dcf3ef10
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 732 additions and 344 deletions

View file

@ -7,15 +7,37 @@ on:
branches: [ "main" ] branches: [ "main" ]
jobs: jobs:
build: formatting:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - name: Checkout Source
- name: Set up Python ${{ matrix.python-version }} uses: actions/checkout@v4
uses: actions/setup-python@v3
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Cache pip dependencies
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-formatting-pip-${{ hashFiles('.pre-commit-config.yaml') }}
restore-keys: |
${{ runner.os }}-formatting-pip-
- name: Cache pre-commit
uses: actions/cache@v4
with:
path: ~/.cache/pre-commit
key: ${{ runner.os }}-pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
restore-keys: |
${{ runner.os }}-pre-commit-
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip isort python -m pip install --upgrade pip
python -m pip install pre-commit python -m pip install pre-commit isort
- name: Run pre-commit - name: Run pre-commit
run: python -m pre_commit run --all-files run: python -m pre_commit run --all-files

View file

@ -10,37 +10,63 @@ on:
jobs: jobs:
run_tests: run_tests:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: ["3.10" , "3.11" , "3.12"] python-version: ["3.10", "3.11", "3.12"]
steps: steps:
- name: Checkout Source - name: Checkout Source
uses: actions/checkout@v2.3.1 uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2 uses: actions/setup-python@v5
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Install dependencies - name: Cache pip dependencies
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements-test.txt', '**/pyproject.toml') }}
restore-keys: |
${{ runner.os }}-pip-${{ matrix.python-version }}-
${{ runner.os }}-pip-
- name: Cache system dependencies
uses: actions/cache@v4
with:
path: /var/cache/apt
key: ${{ runner.os }}-apt-${{ hashFiles('.github/workflows/tests.yml') }}
restore-keys: |
${{ runner.os }}-apt-
- name: Install system dependencies
run: | run: |
sudo apt-get update
sudo apt-get install -y libopenmpi-dev sudo apt-get install -y libopenmpi-dev
python -m pip install --upgrade pip
pip install jax==0.4.35 - name: Install Python dependencies
pip install numpy setuptools cython wheel run: |
pip install git+https://github.com/MP-Gadget/pfft-python python -m pip install --upgrade pip setuptools wheel
pip install git+https://github.com/MP-Gadget/pmesh # Install JAX first as it's a key dependency
pip install git+https://github.com/ASKabalan/fastpm-python --no-build-isolation pip install jax
pip install -r requirements-test.txt # Install build dependencies
pip install . pip install setuptools cython mpi4py
# Install test requirements with no-build-isolation for faster builds
pip install -r requirements-test.txt --no-build-isolation
# Install additional test dependencies
pip install pytest diffrax
# Install package in development mode
pip install -e .
echo "numpy version installed:"
python -c "import numpy; print(numpy.__version__)"
- name: Run Single Device Tests - name: Run Single Device Tests
run: | run: |
cd tests cd tests
pytest -v -m "not distributed" pytest -v -m "not distributed"
- name: Run Distributed tests - name: Run Distributed tests
run: | run: |
pytest -v -m distributed pytest -v tests/test_distributed_pm.py

View file

@ -166,11 +166,11 @@ def uniform_particles(mesh_shape, sharding=None):
axis=-1) axis=-1)
def normal_field(mesh_shape, seed, sharding=None): def normal_field(seed, shape, sharding=None, dtype=float):
"""Generate a Gaussian random field with the given power spectrum.""" """Generate a Gaussian random field with the given power spectrum."""
gpu_mesh = sharding.mesh if sharding is not None else None gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is not None and not (gpu_mesh.empty): if gpu_mesh is not None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(mesh_shape, sharding) local_mesh_shape = get_local_shape(shape, sharding)
size = jax.device_count() size = jax.device_count()
# rank = jax.process_index() # rank = jax.process_index()
@ -190,9 +190,9 @@ def normal_field(mesh_shape, seed, sharding=None):
return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype) return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)
return shard_map( return shard_map(
partial(normal, shape=local_mesh_shape, dtype='float32'), partial(normal, shape=local_mesh_shape, dtype=dtype),
mesh=gpu_mesh, mesh=gpu_mesh,
in_specs=P(None), in_specs=P(None),
out_specs=spec)(keys) # yapf: disable out_specs=spec)(keys) # yapf: disable
else: else:
return jax.random.normal(shape=mesh_shape, key=seed) return jax.random.normal(shape=shape, key=seed, dtype=dtype)

View file

@ -1,3 +1,5 @@
import os
import jax.numpy as np import jax.numpy as np
from jax.numpy import interp from jax.numpy import interp
from jax_cosmo.background import * from jax_cosmo.background import *
@ -119,7 +121,7 @@ def growth_factor(cosmo, a):
if cosmo._flags["gamma_growth"]: if cosmo._flags["gamma_growth"]:
return _growth_factor_gamma(cosmo, a) return _growth_factor_gamma(cosmo, a)
else: else:
return _growth_factor_ODE(cosmo, a) return _growth_factor_ODE(cosmo, a)[0]
def growth_factor_second(cosmo, a): def growth_factor_second(cosmo, a):
@ -225,7 +227,7 @@ def growth_rate_second(cosmo, a):
return _growth_rate_second_ODE(cosmo, a) return _growth_rate_second_ODE(cosmo, a)
def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4): def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=256, eps=1e-4):
"""Compute linear growth factor D(a) at a given scale factor, """Compute linear growth factor D(a) at a given scale factor,
normalised such that D(a=1) = 1. normalised such that D(a=1) = 1.
@ -243,7 +245,11 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
Growth factor computed at requested scale factor Growth factor computed at requested scale factor
""" """
# Check if growth has already been computed # Check if growth has already been computed
if not "background.growth_factor" in cosmo._workspace.keys(): CACHING_ACTIVATED = os.environ.get("JC_CACHE", "1") == "1"
if CACHING_ACTIVATED and "background.growth_factor" in cosmo._workspace.keys(
):
cache = cosmo._workspace["background.growth_factor"]
else:
# Compute tabulated array # Compute tabulated array
atab = np.logspace(log10_amin, 0.0, steps) atab = np.logspace(log10_amin, 0.0, steps)
@ -290,10 +296,10 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
"f2": f2tab, "f2": f2tab,
"h2": h2tab, "h2": h2tab,
} }
cosmo._workspace["background.growth_factor"] = cache if CACHING_ACTIVATED:
else: cosmo._workspace["background.growth_factor"] = cache
cache = cosmo._workspace["background.growth_factor"]
return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0) return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0), cache
def _growth_rate_ODE(cosmo, a): def _growth_rate_ODE(cosmo, a):
@ -314,9 +320,8 @@ def _growth_rate_ODE(cosmo, a):
Growth rate computed at requested scale factor Growth rate computed at requested scale factor
""" """
# Check if growth has already been computed, if not, compute it # Check if growth has already been computed, if not, compute it
if not "background.growth_factor" in cosmo._workspace.keys():
_growth_factor_ODE(cosmo, np.atleast_1d(1.0)) cache = _growth_factor_ODE(cosmo, np.atleast_1d(1.0))[1]
cache = cosmo._workspace["background.growth_factor"]
return interp(a, cache["a"], cache["f"]) return interp(a, cache["a"], cache["f"])
@ -338,36 +343,12 @@ def _growth_factor_second_ODE(cosmo, a):
Second order growth factor computed at requested scale factor Second order growth factor computed at requested scale factor
""" """
# Check if growth has already been computed, if not, compute it # Check if growth has already been computed, if not, compute it
if not "background.growth_factor" in cosmo._workspace.keys(): #if not "background.growth_factor" in cosmo._workspace.keys():
_growth_factor_ODE(cosmo, np.atleast_1d(1.0)) # _growth_factor_ODE(cosmo, np.atleast_1d(1.0))
cache = cosmo._workspace["background.growth_factor"] cache = _growth_factor_ODE(cosmo, a)[1]
return interp(a, cache["a"], cache["g2"]) return interp(a, cache["a"], cache["g2"])
def _growth_rate_ODE(cosmo, a):
"""Compute growth rate dD/dlna at a given scale factor by solving the linear
growth ODE.
Parameters
----------
cosmo: `Cosmology`
Cosmology object
a: array_like
Scale factor
Returns
-------
f: ndarray, or float if input scalar
Second order growth rate computed at requested scale factor
"""
# Check if growth has already been computed, if not, compute it
if not "background.growth_factor" in cosmo._workspace.keys():
_growth_factor_ODE(cosmo, np.atleast_1d(1.0))
cache = cosmo._workspace["background.growth_factor"]
return interp(a, cache["a"], cache["f"])
def _growth_rate_second_ODE(cosmo, a): def _growth_rate_second_ODE(cosmo, a):
"""Compute second order growth rate dD2/dlna at a given scale factor by solving the linear """Compute second order growth rate dD2/dlna at a given scale factor by solving the linear
growth ODE. growth ODE.
@ -386,9 +367,9 @@ def _growth_rate_second_ODE(cosmo, a):
Second order growth rate computed at requested scale factor Second order growth rate computed at requested scale factor
""" """
# Check if growth has already been computed, if not, compute it # Check if growth has already been computed, if not, compute it
if not "background.growth_factor" in cosmo._workspace.keys(): #if not "background.growth_factor" in cosmo._workspace.keys():
_growth_factor_ODE(cosmo, np.atleast_1d(1.0)) # _growth_factor_ODE(cosmo, np.atleast_1d(1.0))
cache = cosmo._workspace["background.growth_factor"] cache = _growth_factor_ODE(cosmo, a)[1]
return interp(a, cache["a"], cache["f2"]) return interp(a, cache["a"], cache["f2"])
@ -411,7 +392,11 @@ def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128):
""" """
# Check if growth has already been computed, if not, compute it # Check if growth has already been computed, if not, compute it
if not "background.growth_factor" in cosmo._workspace.keys(): CACHING_ACTIVATED = os.environ.get("JC_CACHE", "1") == "1"
if CACHING_ACTIVATED and "background.growth_factor" in cosmo._workspace.keys(
):
cache = cosmo._workspace["background.growth_factor"]
else:
# Compute tabulated array # Compute tabulated array
atab = np.logspace(log10_amin, 0.0, steps) atab = np.logspace(log10_amin, 0.0, steps)
@ -422,9 +407,8 @@ def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128):
gtab = np.exp(odeint(integrand, np.log(atab[0]), np.log(atab))) gtab = np.exp(odeint(integrand, np.log(atab[0]), np.log(atab)))
gtab = gtab / gtab[-1] # Normalize to a=1. gtab = gtab / gtab[-1] # Normalize to a=1.
cache = {"a": atab, "g": gtab} cache = {"a": atab, "g": gtab}
cosmo._workspace["background.growth_factor"] = cache if CACHING_ACTIVATED:
else: cosmo._workspace["background.growth_factor"] = cache
cache = cosmo._workspace["background.growth_factor"]
return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0) return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0)
@ -521,6 +505,35 @@ def Gf2(cosmo, a):
return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5) return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
def gp(cosmo, a):
r""" Derivative of D1 against a
Parameters
----------
cosmo: dict
Cosmology dictionary.
a : array_like
Scale factor.
Returns
-------
Scalar float Tensor : the derivative of D1 against a.
Notes
-----
The expression for :math:`gp(a)` is:
.. math::
gp(a)=\frac{dD1}{da}= D'_{1norm}/a
"""
f1 = growth_rate(cosmo, a)
g1 = growth_factor(cosmo, a)
D1f = f1 * g1 / a
return D1f
def dGfa(cosmo, a): def dGfa(cosmo, a):
r""" Derivative of Gf against a r""" Derivative of Gf against a
@ -549,7 +562,8 @@ def dGfa(cosmo, a):
f1 = growth_rate(cosmo, a) f1 = growth_rate(cosmo, a)
g1 = growth_factor(cosmo, a) g1 = growth_factor(cosmo, a)
D1f = f1 * g1 / a D1f = f1 * g1 / a
cache = cosmo._workspace['background.growth_factor'] #cache = cosmo._workspace['background.growth_factor']
cache = _growth_factor_ODE(cosmo, a)[1]
f1p = cache['h'] / cache['a'] * cache['g'] f1p = cache['h'] / cache['a'] * cache['g']
f1p = interp(np.log(a), np.log(cache['a']), f1p) f1p = interp(np.log(a), np.log(cache['a']), f1p)
Ea = E(cosmo, a) Ea = E(cosmo, a)
@ -584,7 +598,8 @@ def dGf2a(cosmo, a):
f2 = growth_rate_second(cosmo, a) f2 = growth_rate_second(cosmo, a)
g2 = growth_factor_second(cosmo, a) g2 = growth_factor_second(cosmo, a)
D2f = f2 * g2 / a D2f = f2 * g2 / a
cache = cosmo._workspace['background.growth_factor'] #cache = cosmo._workspace['background.growth_factor']
cache = _growth_factor_ODE(cosmo, a)[1]
f2p = cache['h2'] / cache['a'] * cache['g2'] f2p = cache['h2'] / cache['a'] * cache['g2']
f2p = interp(np.log(a), np.log(cache['a']), f2p) f2p = interp(np.log(a), np.log(cache['a']), f2p)
E_a = E(cosmo, a) E_a = E(cosmo, a)

View file

@ -12,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk
from jaxpm.painting_utils import gather, scatter from jaxpm.painting_utils import gather, scatter
def _cic_paint_impl(grid_mesh, positions, weight=None): def _cic_paint_impl(grid_mesh, positions, weight=1.):
""" Paints positions onto mesh """ Paints positions onto mesh
mesh: [nx, ny, nz] mesh: [nx, ny, nz]
displacement field: [nx, ny, nz, 3] displacement field: [nx, ny, nz, 3]
@ -27,12 +27,10 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
neighboor_coords = floor + connection neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
if weight is not None: if jnp.isscalar(weight):
if jnp.isscalar(weight): kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel) else:
else: kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), kernel)
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
kernel)
neighboor_coords = jnp.mod( neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 8, 3]).astype('int32'), neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
@ -48,7 +46,13 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
@partial(jax.jit, static_argnums=(3, 4)) @partial(jax.jit, static_argnums=(3, 4))
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None): def cic_paint(grid_mesh, positions, weight=1., halo_size=0, sharding=None):
if sharding is not None:
print("""
WARNING : absolute painting is not recommended in multi-device mode.
Please use relative painting instead.
""")
positions = positions.reshape((*grid_mesh.shape, 3)) positions = positions.reshape((*grid_mesh.shape, 3))
@ -57,9 +61,11 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P() spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
weight_spec = P() if jnp.isscalar(weight) else spec
grid_mesh = autoshmap(_cic_paint_impl, grid_mesh = autoshmap(_cic_paint_impl,
gpu_mesh=gpu_mesh, gpu_mesh=gpu_mesh,
in_specs=(spec, spec, P()), in_specs=(spec, spec, weight_spec),
out_specs=spec)(grid_mesh, positions, weight) out_specs=spec)(grid_mesh, positions, weight)
grid_mesh = halo_exchange(grid_mesh, grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents, halo_extents=halo_extents,
@ -128,6 +134,7 @@ def cic_paint_2d(mesh, positions, weight):
positions: [npart, 2] positions: [npart, 2]
weight: [npart] weight: [npart]
""" """
positions = positions.reshape([-1, 2])
positions = jnp.expand_dims(positions, 1) positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions) floor = jnp.floor(positions)
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
@ -136,7 +143,7 @@ def cic_paint_2d(mesh, positions, weight):
kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] kernel = kernel[..., 0] * kernel[..., 1]
if weight is not None: if weight is not None:
kernel = kernel * weight[..., jnp.newaxis] kernel = kernel * weight.reshape(*positions.shape[:-1])
neighboor_coords = jnp.mod( neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 4, 2]).astype('int32'), neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
@ -151,13 +158,16 @@ def cic_paint_2d(mesh, positions, weight):
return mesh return mesh
def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24): def _cic_paint_dx_impl(displacements,
weight=1.,
halo_size=0,
chunk_size=2**24):
halo_x, _ = halo_size[0] halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1] halo_y, _ = halo_size[1]
original_shape = displacements.shape original_shape = displacements.shape
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32') particle_mesh = jnp.zeros(original_shape[:-1], dtype=displacements.dtype)
if not jnp.isscalar(weight): if not jnp.isscalar(weight):
if weight.shape != original_shape[:-1]: if weight.shape != original_shape[:-1]:
raise ValueError("Weight shape must match particle shape") raise ValueError("Weight shape must match particle shape")
@ -175,7 +185,7 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
return scatter(pmid.reshape([-1, 3]), return scatter(pmid.reshape([-1, 3]),
displacements.reshape([-1, 3]), displacements.reshape([-1, 3]),
particle_mesh, particle_mesh,
chunk_size=2**24, chunk_size=chunk_size,
val=weight) val=weight)
@ -190,13 +200,13 @@ def cic_paint_dx(displacements,
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P() spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
weight_spec = P() if jnp.isscalar(weight) else spec
grid_mesh = autoshmap(partial(_cic_paint_dx_impl, grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
halo_size=halo_size, halo_size=halo_size,
weight=weight,
chunk_size=chunk_size), chunk_size=chunk_size),
gpu_mesh=gpu_mesh, gpu_mesh=gpu_mesh,
in_specs=spec, in_specs=(spec, weight_spec),
out_specs=spec)(displacements) out_specs=spec)(displacements, weight)
grid_mesh = halo_exchange(grid_mesh, grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents, halo_extents=halo_extents,
@ -230,6 +240,12 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None): def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding) halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
# Halo size is halved for the read operation
# We only need to read the density field
# while in the painting operation we need to exchange and reduce the halo
# We chose to do that since it is much easier to write a custom jvp rule for exchange
# while it is a bit harder if there is a reduction involved
halo_size = jax.tree.map(lambda x: x // 2, halo_size)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding) grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
grid_mesh = halo_exchange(grid_mesh, grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents, halo_extents=halo_extents,

View file

@ -30,17 +30,14 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
"""Multilinear enmeshing.""" """Multilinear enmeshing."""
base_indices = jnp.asarray(base_indices) base_indices = jnp.asarray(base_indices)
displacements = jnp.asarray(displacements) displacements = jnp.asarray(displacements)
with jax.experimental.enable_x64(): cell_size = jnp.array(cell_size, dtype=displacements.dtype)
cell_size = jnp.float64( if base_shape is not None:
cell_size) if new_cell_size is not None else jnp.array( base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
cell_size, dtype=displacements.dtype) offset = offset.astype(base_indices.dtype)
if base_shape is not None: if new_cell_size is not None:
base_shape = jnp.array(base_shape, dtype=base_indices.dtype) new_cell_size = jnp.array(new_cell_size, dtype=displacements.dtype)
offset = jnp.float64(offset) if new_shape is not None:
if new_cell_size is not None: new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
new_cell_size = jnp.float64(new_cell_size)
if new_shape is not None:
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
spatial_dim = base_indices.shape[1] spatial_dim = base_indices.shape[1]
neighbor_offsets = ( neighbor_offsets = (

View file

@ -131,7 +131,7 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
Generate initial conditions. Generate initial conditions.
""" """
# Initialize a random field with one slice on each gpu # Initialize a random field with one slice on each gpu
field = normal_field(mesh_shape, seed=seed, sharding=sharding) field = normal_field(seed=seed, shape=mesh_shape, sharding=sharding)
field = fft3d(field) field = fft3d(field)
kvec = fftk(field) kvec = fftk(field)
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
@ -172,8 +172,7 @@ def make_ode_fn(mesh_shape,
return nbody_ode return nbody_ode
def make_diffrax_ode(cosmo, def make_diffrax_ode(mesh_shape,
mesh_shape,
paint_absolute_pos=True, paint_absolute_pos=True,
halo_size=0, halo_size=0,
sharding=None): sharding=None):
@ -183,6 +182,7 @@ def make_diffrax_ode(cosmo,
state is a tuple (position, velocities) state is a tuple (position, velocities)
""" """
pos, vel = state pos, vel = state
cosmo = args
forces = pm_forces(pos, forces = pm_forces(pos,
mesh_shape=mesh_shape, mesh_shape=mesh_shape,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -62,7 +62,7 @@
"\n", "\n",
"This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n", "This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n",
"\n", "\n",
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. `create_device_mesh(pdims)` initializes this layout across available GPUs.\n", "- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n",
"- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n", "- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n",
"- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n", "- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n",
"\n", "\n",
@ -71,7 +71,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -80,11 +80,10 @@
"from jax.sharding import Mesh, NamedSharding\n", "from jax.sharding import Mesh, NamedSharding\n",
"from jax.sharding import PartitionSpec as P\n", "from jax.sharding import PartitionSpec as P\n",
"\n", "\n",
"all_gather = partial(process_allgather, tiled=False)\n", "all_gather = partial(process_allgather, tiled=True)\n",
"\n", "\n",
"pdims = (2, 4)\n", "pdims = (2, 4)\n",
"devices = create_device_mesh(pdims)\n", "mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
"sharding = NamedSharding(mesh, P('x', 'y'))" "sharding = NamedSharding(mesh, P('x', 'y'))"
] ]
}, },
@ -124,7 +123,7 @@
"\n", "\n",
" # Evolve the simulation forward\n", " # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n", " ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", " make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
" solver = LeapfrogMidpoint()\n", " solver = LeapfrogMidpoint()\n",
"\n", "\n",
" stepsize_controller = ConstantStepSize()\n", " stepsize_controller = ConstantStepSize()\n",
@ -288,7 +287,7 @@
"\n", "\n",
" # Evolve the simulation forward\n", " # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n", " ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", " make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
" solver = Dopri5()\n", " solver = Dopri5()\n",
"\n", "\n",
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n", " stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",

View file

@ -17,9 +17,8 @@ import jax_cosmo as jc
import numpy as np import numpy as np
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm, from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
PIDController, SaveAt, diffeqsolve) PIDController, SaveAt, diffeqsolve)
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
from jaxpm.kernels import interpolate_power_spectrum from jaxpm.kernels import interpolate_power_spectrum
@ -78,7 +77,7 @@ def parse_arguments():
def create_mesh_and_sharding(pdims): def create_mesh_and_sharding(pdims):
devices = create_device_mesh(pdims) devices = create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('x', 'y')) mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y')) sharding = NamedSharding(mesh, P('x', 'y'))
return mesh, sharding return mesh, sharding
@ -106,7 +105,10 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
sharding=sharding) sharding=sharding)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) make_diffrax_ode(mesh_shape,
paint_absolute_pos=False,
sharding=sharding,
halo_size=halo_size))
# Choose solver # Choose solver
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5() solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()

View file

@ -37,3 +37,50 @@ Each notebook includes installation instructions and guidelines for configuring
- **SLURM** for job scheduling on clusters (if running multi-host setups) - **SLURM** for job scheduling on clusters (if running multi-host setups)
> **Note**: These notebooks are tested on the **Jean Zay** supercomputer and may require configuration changes for different HPC clusters. > **Note**: These notebooks are tested on the **Jean Zay** supercomputer and may require configuration changes for different HPC clusters.
## Caveats
### Cloud-in-Cell (CIC) Painting (Single Device)
There is two ways to perform the CIC painting in JAXPM. The first one is to use the `cic_paint` which paints absolute particle positions to the mesh. The second one is to use the `cic_paint_dx` which paints relative particle positions to the mesh (using uniform particles). The absolute version is faster at the cost of more memory usage.
inorder to use relative painting you need to :
- Set the `particles` argument in `lpt` function from `jaxpm.pm` to `None`
- Set `paint_absolute_pos` to `False` in `make_ode_fn` or `make_diffrax_ode` function from `jaxpm.pm` (it is True by default)
Otherwise you set `particles` to the starting particles of your choice and leave `paint_absolute_pos` to `True` (default value).
### Cloud-in-Cell (CIC) Painting (Multi Device)
Both `cic_paint` and `cic_paint_dx` functions are available in multi-device mode.
You need to set the arguments `sharding` and `halo_size` which is explained in the notebook [03-MultiGPU_PM_Halo.ipynb](03-MultiGPU_PM_Halo.ipynb).
One thing to note that `cic_paint` is not as accurate as `cic_paint_dx` in multi-device mode and therefor is not recommended.
Using relative painting in multi-device mode is just like in single device mode.\
You need to set the `particles` argument in `lpt` function from `jaxpm.pm` to `None` and set `paint_absolute_pos` to `False`
### Distributed PM
To run a distributed PM follow the examples in notebooks [03](03-MultiGPU_PM_Halo.ipynb) and [05](05-MultiHost_PM.ipynb) for multi-host.
In short you need to set the arguments `sharding` and `halo_size` in `lpt` , `linear_field` the `make_ode` functions and `pm_forces` if you use it.
Missmatching the shardings will give you errors and unexpected results.
You can also use `normal_field` and `uniform_particles` from `jaxpm.pm.distributed` to create the fields and particles with a sharding.
### Choosing the right pdims
pdims are processor dimensions.\
Explained more in the jaxdecomp paper [here](https://github.com/DifferentiableUniverseInitiative/jaxDecomp).
For 8 devices there are three decompositions that are possible:
- (1 , 8)
- (2 , 4) , (4 , 2)
- (8 , 1)
(1 , X) should be the fastest (2 , X) or (X , 2) is more accurate but slightly slower.\
and (X , 1) is giving the least accurate results for some reason so it is not recommended.

View file

@ -1,5 +1,5 @@
pytest>=8.0.0
diffrax
pfft-python @ git+https://github.com/MP-Gadget/pfft-python pfft-python @ git+https://github.com/MP-Gadget/pfft-python
pmesh @ git+https://github.com/MP-Gadget/pmesh pmesh @ git+https://github.com/MP-Gadget/pmesh
fastpm @ git+https://github.com/ASKabalan/fastpm-python fastpm @ git+https://github.com/ASKabalan/fastpm-python
numpy==2.2.6
diffrax

View file

@ -44,7 +44,7 @@ def simulation_config(request):
return request.param return request.param
@pytest.fixture(scope="session", params=[0.1, 0.5, 0.8]) @pytest.fixture(scope="session", params=[0.1, 0.2])
def lpt_scale_factor(request): def lpt_scale_factor(request):
return request.param return request.param
@ -151,7 +151,7 @@ def nbody_from_lpt1(solver, fpm_lpt1, particle_mesh, lpt_scale_factor):
if lpt_scale_factor == 0.8: if lpt_scale_factor == 0.8:
pytest.skip("Do not run nbody simulation from scale factor 0.8") pytest.skip("Do not run nbody simulation from scale factor 0.8")
stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True) stages = np.linspace(lpt_scale_factor, 1.0, 100, endpoint=True)
finalstate = solver.nbody(fpm_lpt1, leapfrog(stages)) finalstate = solver.nbody(fpm_lpt1, leapfrog(stages))
fpm_mesh = particle_mesh.paint(finalstate.X).value fpm_mesh = particle_mesh.paint(finalstate.X).value
@ -167,7 +167,7 @@ def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
if lpt_scale_factor == 0.8: if lpt_scale_factor == 0.8:
pytest.skip("Do not run nbody simulation from scale factor 0.8") pytest.skip("Do not run nbody simulation from scale factor 0.8")
stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True) stages = np.linspace(lpt_scale_factor, 1.0, 100, endpoint=True)
finalstate = solver.nbody(fpm_lpt2, leapfrog(stages)) finalstate = solver.nbody(fpm_lpt2, leapfrog(stages))
fpm_mesh = particle_mesh.paint(finalstate.X).value fpm_mesh = particle_mesh.paint(finalstate.X).value

View file

@ -2,6 +2,7 @@ import pytest
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
from helpers import MSE, MSRE from helpers import MSE, MSRE
from jax import numpy as jnp from jax import numpy as jnp
from numpy.testing import assert_allclose
from jaxpm.distributed import uniform_particles from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_dx from jaxpm.painting import cic_paint, cic_paint_dx
@ -10,6 +11,8 @@ from jaxpm.utils import power_spectrum
_TOLERANCE = 1e-4 _TOLERANCE = 1e-4
_PM_TOLERANCE = 1e-3 _PM_TOLERANCE = 1e-3
_FIELD_RTOL = 1e-4
_FIELD_ATOL = 1e-3
@pytest.mark.single_device @pytest.mark.single_device
@ -34,7 +37,10 @@ def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape) _, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE assert_allclose(lpt_field,
fpm_ref_field,
rtol=_FIELD_RTOL,
atol=_FIELD_ATOL)
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
@ -55,7 +61,10 @@ def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape) _, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE assert_allclose(lpt_field,
fpm_ref_field,
rtol=_FIELD_RTOL,
atol=_FIELD_ATOL)
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
@ -76,7 +85,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
a=lpt_scale_factor, a=lpt_scale_factor,
order=order) order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
solver = Dopri5() solver = Dopri5()
controller = PIDController(rtol=1e-8, controller = PIDController(rtol=1e-8,
@ -95,6 +104,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
t1=1.0, t1=1.0,
dt0=None, dt0=None,
y0=y0, y0=y0,
args=cosmo,
stepsize_controller=controller, stepsize_controller=controller,
saveat=saveat) saveat=saveat)
@ -105,7 +115,10 @@ def test_nbody_absolute(simulation_config, initial_conditions,
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape) _, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE assert_allclose(final_field,
fpm_ref_field,
rtol=_FIELD_RTOL,
atol=_FIELD_ATOL)
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
@ -121,8 +134,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
# Initial displacement # Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order) dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
ode_fn = ODETerm( ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
solver = Dopri5() solver = Dopri5()
controller = PIDController(rtol=1e-9, controller = PIDController(rtol=1e-9,
@ -141,6 +153,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
t1=1.0, t1=1.0,
dt0=None, dt0=None,
y0=y0, y0=y0,
args=cosmo,
stepsize_controller=controller, stepsize_controller=controller,
saveat=saveat) saveat=saveat)
@ -151,5 +164,8 @@ def test_nbody_relative(simulation_config, initial_conditions,
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape) _, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE assert_allclose(final_field,
fpm_ref_field,
rtol=_FIELD_RTOL,
atol=_FIELD_ATOL)
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE

View file

@ -2,8 +2,11 @@ from conftest import initialize_distributed
initialize_distributed() # ignore : E402 initialize_distributed() # ignore : E402
from functools import partial # noqa : E402
import jax # noqa : E402 import jax # noqa : E402
import jax.numpy as jnp # noqa : E402 import jax.numpy as jnp # noqa : E402
import jax_cosmo as jc # noqa : E402
import pytest # noqa : E402 import pytest # noqa : E402
from diffrax import SaveAt # noqa : E402 from diffrax import SaveAt # noqa : E402
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
@ -12,19 +15,31 @@ from jax import lax # noqa : E402
from jax.experimental.multihost_utils import process_allgather # noqa : E402 from jax.experimental.multihost_utils import process_allgather # noqa : E402
from jax.sharding import NamedSharding from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P # noqa : E402 from jax.sharding import PartitionSpec as P # noqa : E402
from jaxdecomp import get_fft_output_sharding
from jaxpm.distributed import uniform_particles # noqa : E402 from jaxpm.distributed import uniform_particles # noqa : E402
from jaxpm.distributed import fft3d, ifft3d
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402 from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402 from jaxpm.pm import lpt, make_diffrax_ode, pm_forces # noqa : E402
_TOLERANCE = 3.0 # 🙃🙃 _TOLERANCE = 1e-6 # 🎉🎉🎉
pdims = [(1, 8), (8, 1), (4, 2), (2, 4)]
@pytest.mark.distributed @pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("pdims", pdims)
@pytest.mark.parametrize("absolute_painting", [True, False]) @pytest.mark.parametrize("absolute_painting", [True, False])
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
absolute_painting): pdims, absolute_painting):
if absolute_painting:
pytest.skip("Absolute painting is not recommended in distributed mode")
painting_str = "absolute" if absolute_painting else "relative"
print("=" * 50)
print(f"Running with {painting_str} painting and pdims {pdims} ...")
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
# SINGLE DEVICE RUN # SINGLE DEVICE RUN
@ -37,12 +52,12 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
particles, particles,
a=0.1, a=0.1,
order=order) order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
y0 = jnp.stack([particles + dx, p]) y0 = jnp.stack([particles + dx, p])
else: else:
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order) dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order)
ode_fn = ODETerm( ode_fn = ODETerm(make_diffrax_ode(mesh_shape,
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) paint_absolute_pos=False))
y0 = jnp.stack([dx, p]) y0 = jnp.stack([dx, p])
solver = Dopri5() solver = Dopri5()
@ -60,6 +75,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
t1=1.0, t1=1.0,
dt0=None, dt0=None,
y0=y0, y0=y0,
args=cosmo,
stepsize_controller=controller, stepsize_controller=controller,
saveat=saveat) saveat=saveat)
@ -72,7 +88,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
print("Done with single device run") print("Done with single device run")
# MULTI DEVICE RUN # MULTI DEVICE RUN
mesh = jax.make_mesh((1, 8), ('x', 'y')) mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y')) sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2 halo_size = mesh_shape[0] // 2
@ -94,8 +110,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
sharding=sharding) sharding=sharding)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode(cosmo, make_diffrax_ode(mesh_shape,
mesh_shape,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding)) sharding=sharding))
@ -108,8 +123,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode(cosmo, make_diffrax_ode(mesh_shape,
mesh_shape,
paint_absolute_pos=False, paint_absolute_pos=False,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding)) sharding=sharding))
@ -130,16 +144,23 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
t1=1.0, t1=1.0,
dt0=None, dt0=None,
y0=y0, y0=y0,
args=cosmo,
stepsize_controller=controller, stepsize_controller=controller,
saveat=saveat) saveat=saveat)
final_field = solutions.ys[-1, 0]
print(f"Final field sharding is {final_field.sharding}")
assert final_field.sharding.is_equivalent_to(sharding , ndim=3) \
, f"Final field sharding is not correct .. should be {sharding} it is instead {final_field.sharding}"
if absolute_painting: if absolute_painting:
multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape), multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
solutions.ys[-1, 0], final_field,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
else: else:
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0], multi_device_final_field = cic_paint_dx(final_field,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
@ -150,3 +171,230 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
print(f"MSE is {mse}") print(f"MSE is {mse}")
assert mse < _TOLERANCE assert mse < _TOLERANCE
@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("pdims", pdims)
def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
order, nbody_from_lpt1, nbody_from_lpt2, pdims):
mesh_shape, box_shape = simulation_config
# SINGLE DEVICE RUN
cosmo._workspace = {}
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
initial_conditions = lax.with_sharding_constraint(initial_conditions,
sharding)
print(f"sharded initial conditions {initial_conditions.sharding}")
cosmo._workspace = {}
@jax.jit
def forward_model(initial_conditions, cosmo):
dx, p, _ = lpt(cosmo,
initial_conditions,
a=0.1,
order=order,
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(mesh_shape,
paint_absolute_pos=False,
halo_size=halo_size,
sharding=sharding))
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
solver = Dopri5()
controller = PIDController(rtol=1e-8,
atol=1e-8,
pcoeff=0.4,
icoeff=1,
dcoeff=0)
saveat = SaveAt(t1=True)
solutions = diffeqsolve(ode_fn,
solver,
t0=0.1,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
halo_size=halo_size,
sharding=sharding)
return multi_device_final_field
@jax.jit
def model(initial_conditions, cosmo):
final_field = forward_model(initial_conditions, cosmo)
return MSE(final_field,
nbody_from_lpt1 if order == 1 else nbody_from_lpt2)
obs_val = model(initial_conditions, cosmo)
shifted_initial_conditions = initial_conditions + jax.random.normal(
jax.random.key(42), initial_conditions.shape) * 5
good_grads = jax.grad(model)(initial_conditions, cosmo)
off_grads = jax.grad(model)(shifted_initial_conditions, cosmo)
assert good_grads.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
assert off_grads.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
@pytest.mark.distributed
@pytest.mark.parametrize("pdims", pdims)
def test_fwd_rev_gradients(cosmo, pdims):
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
cosmo._workspace = {}
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
initial_conditions = jax.random.normal(jax.random.PRNGKey(42), mesh_shape)
initial_conditions = lax.with_sharding_constraint(initial_conditions,
sharding)
print(f"sharded initial conditions {initial_conditions.sharding}")
cosmo._workspace = {}
@partial(jax.jit, static_argnums=(2, 3, 4))
def compute_forces(initial_conditions,
cosmo,
a=0.5,
halo_size=0,
sharding=None):
paint_absolute_pos = False
particles = jnp.zeros_like(initial_conditions,
shape=(*initial_conditions.shape, 3))
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
initial_conditions = jax.lax.with_sharding_constraint(
initial_conditions, sharding)
delta_k = fft3d(initial_conditions)
out_sharding = get_fft_output_sharding(sharding)
delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
initial_force = pm_forces(particles,
delta=delta_k,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding)
return initial_force[..., 0]
forces = compute_forces(initial_conditions,
cosmo,
halo_size=halo_size,
sharding=sharding)
back_gradient = jax.jacrev(compute_forces)(initial_conditions,
cosmo,
halo_size=halo_size,
sharding=sharding)
fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions,
cosmo,
halo_size=halo_size,
sharding=sharding)
print(f"Forces sharding is {forces.sharding}")
print(f"Backward gradient sharding is {back_gradient.sharding}")
print(f"Forward gradient sharding is {fwd_gradient.sharding}")
assert forces.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
assert back_gradient[0, 0, 0, ...].sharding.is_equivalent_to(
initial_conditions.sharding, ndim=3)
assert fwd_gradient.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
@pytest.mark.distributed
@pytest.mark.parametrize("pdims", pdims)
def test_vmap(cosmo, pdims):
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
cosmo._workspace = {}
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
single_dev_initial_conditions = jax.random.normal(jax.random.PRNGKey(42),
mesh_shape)
initial_conditions = lax.with_sharding_constraint(
single_dev_initial_conditions, sharding)
single_ics = jnp.stack([
single_dev_initial_conditions, single_dev_initial_conditions,
single_dev_initial_conditions
])
sharded_ics = jnp.stack(
[initial_conditions, initial_conditions, initial_conditions])
print(f"unsharded initial conditions batch {single_ics.sharding}")
print(f"sharded initial conditions batch {sharded_ics.sharding}")
cosmo._workspace = {}
@partial(jax.jit, static_argnums=(2, 3, 4))
def compute_forces(initial_conditions,
cosmo,
a=0.5,
halo_size=0,
sharding=None):
paint_absolute_pos = False
particles = jnp.zeros_like(initial_conditions,
shape=(*initial_conditions.shape, 3))
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
initial_conditions = jax.lax.with_sharding_constraint(
initial_conditions, sharding)
delta_k = fft3d(initial_conditions)
out_sharding = get_fft_output_sharding(sharding)
delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
initial_force = pm_forces(particles,
delta=delta_k,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding)
return initial_force[..., 0]
def fn(ic):
return compute_forces(ic,
cosmo,
halo_size=halo_size,
sharding=sharding)
v_compute_forces = jax.vmap(fn)
print(f"single_ics shape {single_ics.shape}")
print(f"sharded_ics shape {sharded_ics.shape}")
single_dev_forces = v_compute_forces(single_ics)
sharded_forces = v_compute_forces(sharded_ics)
assert single_dev_forces.ndim == 4
assert sharded_forces.ndim == 4
print(f"Sharded forces {sharded_forces.sharding}")
assert sharded_forces[0].sharding.is_equivalent_to(
initial_conditions.sharding, ndim=3)
assert sharded_forces.sharding.spec[0] == None

View file

@ -39,7 +39,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
particles, particles,
a=lpt_scale_factor, a=lpt_scale_factor,
order=order) order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
y0 = jnp.stack([particles + dx, p]) y0 = jnp.stack([particles + dx, p])
else: else:
@ -48,7 +48,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
a=lpt_scale_factor, a=lpt_scale_factor,
order=order) order=order)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
y0 = jnp.stack([dx, p]) y0 = jnp.stack([dx, p])
solver = Dopri5() solver = Dopri5()
@ -66,6 +66,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
t1=1.0, t1=1.0,
dt0=None, dt0=None,
y0=y0, y0=y0,
args=cosmo,
adjoint=adjoint, adjoint=adjoint,
stepsize_controller=controller, stepsize_controller=controller,
saveat=saveat) saveat=saveat)