Mirror of https://github.com/DifferentiableUniverseInitiative/JaxPM.git (synced 2025-06-28 16:11:11 +00:00)

Commit 81dcf3ef10: Merge e666aada42 into cb2a7ab17f
18 changed files with 732 additions and 344 deletions
.github/workflows/formatting.yml (vendored, 34 changes)

@@ -7,15 +7,37 @@ on:
     branches: [ "main" ]
 
 jobs:
-  build:
+  formatting:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v3
+      - name: Checkout Source
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.11"
+
+      - name: Cache pip dependencies
+        uses: actions/cache@v4
+        with:
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-formatting-pip-${{ hashFiles('.pre-commit-config.yaml') }}
+          restore-keys: |
+            ${{ runner.os }}-formatting-pip-
+
+      - name: Cache pre-commit
+        uses: actions/cache@v4
+        with:
+          path: ~/.cache/pre-commit
+          key: ${{ runner.os }}-pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
+          restore-keys: |
+            ${{ runner.os }}-pre-commit-
+
       - name: Install dependencies
         run: |
-          python -m pip install --upgrade pip isort
-          python -m pip install pre-commit
+          python -m pip install --upgrade pip
+          python -m pip install pre-commit isort
 
       - name: Run pre-commit
         run: python -m pre_commit run --all-files
.github/workflows/tests.yml (vendored, 54 changes)

@@ -10,37 +10,63 @@ on:
 jobs:
   run_tests:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.10" , "3.11" , "3.12"]
+        python-version: ["3.10", "3.11", "3.12"]
 
     steps:
       - name: Checkout Source
-        uses: actions/checkout@v2.3.1
+        uses: actions/checkout@v4
 
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
 
-      - name: Install dependencies
+      - name: Cache pip dependencies
+        uses: actions/cache@v4
+        with:
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements-test.txt', '**/pyproject.toml') }}
+          restore-keys: |
+            ${{ runner.os }}-pip-${{ matrix.python-version }}-
+            ${{ runner.os }}-pip-
+
+      - name: Cache system dependencies
+        uses: actions/cache@v4
+        with:
+          path: /var/cache/apt
+          key: ${{ runner.os }}-apt-${{ hashFiles('.github/workflows/tests.yml') }}
+          restore-keys: |
+            ${{ runner.os }}-apt-
+
+      - name: Install system dependencies
         run: |
           sudo apt-get update
           sudo apt-get install -y libopenmpi-dev
-          python -m pip install --upgrade pip
-          pip install jax==0.4.35
-          pip install numpy setuptools cython wheel
-          pip install git+https://github.com/MP-Gadget/pfft-python
-          pip install git+https://github.com/MP-Gadget/pmesh
-          pip install git+https://github.com/ASKabalan/fastpm-python --no-build-isolation
-          pip install -r requirements-test.txt
-          pip install .
+
+      - name: Install Python dependencies
+        run: |
+          python -m pip install --upgrade pip setuptools wheel
+          # Install JAX first as it's a key dependency
+          pip install jax
+          # Install build dependencies
+          pip install setuptools cython mpi4py
+          # Install test requirements with no-build-isolation for faster builds
+          pip install -r requirements-test.txt --no-build-isolation
+          # Install additional test dependencies
+          pip install pytest diffrax
+          # Install package in development mode
+          pip install -e .
+          echo "numpy version installed:"
+          python -c "import numpy; print(numpy.__version__)"
 
       - name: Run Single Device Tests
         run: |
           cd tests
           pytest -v -m "not distributed"
 
       - name: Run Distributed tests
         run: |
-          pytest -v -m distributed
+          pytest -v tests/test_distributed_pm.py
@@ -166,11 +166,11 @@ def uniform_particles(mesh_shape, sharding=None):
                      axis=-1)
 
 
-def normal_field(mesh_shape, seed, sharding=None):
+def normal_field(seed, shape, sharding=None, dtype=float):
     """Generate a Gaussian random field with the given power spectrum."""
     gpu_mesh = sharding.mesh if sharding is not None else None
     if gpu_mesh is not None and not (gpu_mesh.empty):
-        local_mesh_shape = get_local_shape(mesh_shape, sharding)
+        local_mesh_shape = get_local_shape(shape, sharding)
 
         size = jax.device_count()
         # rank = jax.process_index()

@@ -190,9 +190,9 @@ def normal_field(mesh_shape, seed, sharding=None):
             return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)
 
         return shard_map(
-            partial(normal, shape=local_mesh_shape, dtype='float32'),
+            partial(normal, shape=local_mesh_shape, dtype=dtype),
             mesh=gpu_mesh,
             in_specs=P(None),
             out_specs=spec)(keys)  # yapf: disable
     else:
-        return jax.random.normal(shape=mesh_shape, key=seed)
+        return jax.random.normal(shape=shape, key=seed, dtype=dtype)
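For reference, a minimal sketch of the updated call, assuming the new keyword order from the hunk above (the shape tuple and dtype here are illustrative, and `sharding` may be `None` on a single device):

```python
import jax
import jax.numpy as jnp
from jaxpm.distributed import normal_field

# New signature: the PRNG key comes first, the mesh shape is now `shape`,
# and the dtype is passed through instead of being hard-coded to float32.
field = normal_field(seed=jax.random.PRNGKey(0),
                     shape=(64, 64, 64),
                     sharding=None,
                     dtype=jnp.float32)
```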
jaxpm/growth.py (103 changes)

@@ -1,3 +1,5 @@
+import os
+
 import jax.numpy as np
 from jax.numpy import interp
 from jax_cosmo.background import *

@@ -119,7 +121,7 @@ def growth_factor(cosmo, a):
     if cosmo._flags["gamma_growth"]:
         return _growth_factor_gamma(cosmo, a)
     else:
-        return _growth_factor_ODE(cosmo, a)
+        return _growth_factor_ODE(cosmo, a)[0]
 
 
 def growth_factor_second(cosmo, a):

@@ -225,7 +227,7 @@ def growth_rate_second(cosmo, a):
     return _growth_rate_second_ODE(cosmo, a)
 
 
-def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
+def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=256, eps=1e-4):
     """Compute linear growth factor D(a) at a given scale factor,
     normalised such that D(a=1) = 1.

@@ -243,7 +245,11 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
       Growth factor computed at requested scale factor
     """
     # Check if growth has already been computed
-    if not "background.growth_factor" in cosmo._workspace.keys():
+    CACHING_ACTIVATED = os.environ.get("JC_CACHE", "1") == "1"
+    if CACHING_ACTIVATED and "background.growth_factor" in cosmo._workspace.keys(
+    ):
+        cache = cosmo._workspace["background.growth_factor"]
+    else:
         # Compute tabulated array
         atab = np.logspace(log10_amin, 0.0, steps)

@@ -290,10 +296,10 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
             "f2": f2tab,
             "h2": h2tab,
         }
-        cosmo._workspace["background.growth_factor"] = cache
-    else:
-        cache = cosmo._workspace["background.growth_factor"]
-    return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0)
+        if CACHING_ACTIVATED:
+            cosmo._workspace["background.growth_factor"] = cache
+
+    return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0), cache
 
 
 def _growth_rate_ODE(cosmo, a):

@@ -314,9 +320,8 @@ def _growth_rate_ODE(cosmo, a):
       Growth rate computed at requested scale factor
     """
-    # Check if growth has already been computed, if not, compute it
-    if not "background.growth_factor" in cosmo._workspace.keys():
-        _growth_factor_ODE(cosmo, np.atleast_1d(1.0))
-    cache = cosmo._workspace["background.growth_factor"]
-
+    cache = _growth_factor_ODE(cosmo, np.atleast_1d(1.0))[1]
     return interp(a, cache["a"], cache["f"])

@@ -338,36 +343,12 @@ def _growth_factor_second_ODE(cosmo, a):
       Second order growth factor computed at requested scale factor
     """
-    # Check if growth has already been computed, if not, compute it
-    if not "background.growth_factor" in cosmo._workspace.keys():
-        _growth_factor_ODE(cosmo, np.atleast_1d(1.0))
-    cache = cosmo._workspace["background.growth_factor"]
+    #if not "background.growth_factor" in cosmo._workspace.keys():
+    #    _growth_factor_ODE(cosmo, np.atleast_1d(1.0))
+    cache = _growth_factor_ODE(cosmo, a)[1]
     return interp(a, cache["a"], cache["g2"])
 
 
-def _growth_rate_ODE(cosmo, a):
-    """Compute growth rate dD/dlna at a given scale factor by solving the linear
-    growth ODE.
-
-    Parameters
-    ----------
-    cosmo: `Cosmology`
-      Cosmology object
-
-    a: array_like
-      Scale factor
-
-    Returns
-    -------
-    f: ndarray, or float if input scalar
-      Second order growth rate computed at requested scale factor
-    """
-    # Check if growth has already been computed, if not, compute it
-    if not "background.growth_factor" in cosmo._workspace.keys():
-        _growth_factor_ODE(cosmo, np.atleast_1d(1.0))
-    cache = cosmo._workspace["background.growth_factor"]
-    return interp(a, cache["a"], cache["f"])
-
-
 def _growth_rate_second_ODE(cosmo, a):
     """Compute second order growth rate dD2/dlna at a given scale factor by solving the linear
     growth ODE.

@@ -386,9 +367,9 @@ def _growth_rate_second_ODE(cosmo, a):
       Second order growth rate computed at requested scale factor
     """
-    # Check if growth has already been computed, if not, compute it
-    if not "background.growth_factor" in cosmo._workspace.keys():
-        _growth_factor_ODE(cosmo, np.atleast_1d(1.0))
-    cache = cosmo._workspace["background.growth_factor"]
+    #if not "background.growth_factor" in cosmo._workspace.keys():
+    #    _growth_factor_ODE(cosmo, np.atleast_1d(1.0))
+    cache = _growth_factor_ODE(cosmo, a)[1]
     return interp(a, cache["a"], cache["f2"])

@@ -411,7 +392,11 @@ def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128):
     """
     # Check if growth has already been computed, if not, compute it
-    if not "background.growth_factor" in cosmo._workspace.keys():
+    CACHING_ACTIVATED = os.environ.get("JC_CACHE", "1") == "1"
+    if CACHING_ACTIVATED and "background.growth_factor" in cosmo._workspace.keys(
+    ):
+        cache = cosmo._workspace["background.growth_factor"]
+    else:
         # Compute tabulated array
         atab = np.logspace(log10_amin, 0.0, steps)

@@ -422,9 +407,8 @@ def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128):
         gtab = np.exp(odeint(integrand, np.log(atab[0]), np.log(atab)))
         gtab = gtab / gtab[-1]  # Normalize to a=1.
         cache = {"a": atab, "g": gtab}
-        cosmo._workspace["background.growth_factor"] = cache
-    else:
-        cache = cosmo._workspace["background.growth_factor"]
+        if CACHING_ACTIVATED:
+            cosmo._workspace["background.growth_factor"] = cache
     return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0)

@@ -521,6 +505,35 @@ def Gf2(cosmo, a):
     return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
 
 
+def gp(cosmo, a):
+    r""" Derivative of D1 against a
+
+    Parameters
+    ----------
+    cosmo: dict
+      Cosmology dictionary.
+
+    a : array_like
+       Scale factor.
+
+    Returns
+    -------
+    Scalar float Tensor : the derivative of D1 against a.
+
+    Notes
+    -----
+
+    The expression for :math:`gp(a)` is:
+
+    .. math::
+        gp(a) = \frac{dD1}{da} = D'_{1norm} / a
+    """
+    f1 = growth_rate(cosmo, a)
+    g1 = growth_factor(cosmo, a)
+    D1f = f1 * g1 / a
+    return D1f
+
+
 def dGfa(cosmo, a):
     r""" Derivative of Gf against a

@@ -549,7 +562,8 @@ def dGfa(cosmo, a):
     f1 = growth_rate(cosmo, a)
     g1 = growth_factor(cosmo, a)
     D1f = f1 * g1 / a
-    cache = cosmo._workspace['background.growth_factor']
+    #cache = cosmo._workspace['background.growth_factor']
+    cache = _growth_factor_ODE(cosmo, a)[1]
     f1p = cache['h'] / cache['a'] * cache['g']
     f1p = interp(np.log(a), np.log(cache['a']), f1p)
     Ea = E(cosmo, a)

@@ -584,7 +598,8 @@ def dGf2a(cosmo, a):
     f2 = growth_rate_second(cosmo, a)
     g2 = growth_factor_second(cosmo, a)
     D2f = f2 * g2 / a
-    cache = cosmo._workspace['background.growth_factor']
+    #cache = cosmo._workspace['background.growth_factor']
+    cache = _growth_factor_ODE(cosmo, a)[1]
     f2p = cache['h2'] / cache['a'] * cache['g2']
     f2p = interp(np.log(a), np.log(cache['a']), f2p)
     E_a = E(cosmo, a)
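The caching change above gates all `cosmo._workspace` lookups behind a `JC_CACHE` environment variable, on by default. A minimal sketch of how a caller might disable it, based on how the variable is read in the diff:

```python
import os

# Assumption from the diff: JC_CACHE defaults to "1" (caching on).
# Setting it to "0" forces the growth tables to be recomputed on each call,
# so _growth_factor_ODE returns (growth, cache) without touching
# cosmo._workspace.
os.environ["JC_CACHE"] = "0"
```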
@@ -12,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk
 from jaxpm.painting_utils import gather, scatter
 
 
-def _cic_paint_impl(grid_mesh, positions, weight=None):
+def _cic_paint_impl(grid_mesh, positions, weight=1.):
     """ Paints positions onto mesh
     mesh: [nx, ny, nz]
     displacement field: [nx, ny, nz, 3]

@@ -27,12 +27,10 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
     neighboor_coords = floor + connection
     kernel = 1. - jnp.abs(positions - neighboor_coords)
     kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
-    if weight is not None:
-        if jnp.isscalar(weight):
-            kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
-        else:
-            kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
-                                  kernel)
+    if jnp.isscalar(weight):
+        kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
+    else:
+        kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), kernel)
 
     neighboor_coords = jnp.mod(
         neighboor_coords.reshape([-1, 8, 3]).astype('int32'),

@@ -48,7 +46,13 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
 
 @partial(jax.jit, static_argnums=(3, 4))
-def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
+def cic_paint(grid_mesh, positions, weight=1., halo_size=0, sharding=None):
+
+    if sharding is not None:
+        print("""
+        WARNING : absolute painting is not recommended in multi-device mode.
+        Please use relative painting instead.
+        """)
 
     positions = positions.reshape((*grid_mesh.shape, 3))

@@ -57,9 +61,11 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
     gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
     spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
+    weight_spec = P() if jnp.isscalar(weight) else spec
 
     grid_mesh = autoshmap(_cic_paint_impl,
                           gpu_mesh=gpu_mesh,
-                          in_specs=(spec, spec, P()),
+                          in_specs=(spec, spec, weight_spec),
                           out_specs=spec)(grid_mesh, positions, weight)
     grid_mesh = halo_exchange(grid_mesh,
                               halo_extents=halo_extents,

@@ -128,6 +134,7 @@ def cic_paint_2d(mesh, positions, weight):
     positions: [npart, 2]
     weight: [npart]
     """
+    positions = positions.reshape([-1, 2])
     positions = jnp.expand_dims(positions, 1)
     floor = jnp.floor(positions)
     connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])

@@ -136,7 +143,7 @@ def cic_paint_2d(mesh, positions, weight):
     kernel = 1. - jnp.abs(positions - neighboor_coords)
     kernel = kernel[..., 0] * kernel[..., 1]
     if weight is not None:
-        kernel = kernel * weight[..., jnp.newaxis]
+        kernel = kernel * weight.reshape(*positions.shape[:-1])
 
     neighboor_coords = jnp.mod(
         neighboor_coords.reshape([-1, 4, 2]).astype('int32'),

@@ -151,13 +158,16 @@ def cic_paint_2d(mesh, positions, weight):
     return mesh
 
 
-def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
+def _cic_paint_dx_impl(displacements,
+                       weight=1.,
+                       halo_size=0,
+                       chunk_size=2**24):
 
     halo_x, _ = halo_size[0]
     halo_y, _ = halo_size[1]
 
     original_shape = displacements.shape
-    particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
+    particle_mesh = jnp.zeros(original_shape[:-1], dtype=displacements.dtype)
+    if not jnp.isscalar(weight):
+        if weight.shape != original_shape[:-1]:
+            raise ValueError("Weight shape must match particle shape")

@@ -175,7 +185,7 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
     return scatter(pmid.reshape([-1, 3]),
                    displacements.reshape([-1, 3]),
                    particle_mesh,
-                   chunk_size=2**24,
+                   chunk_size=chunk_size,
                    val=weight)

@@ -190,13 +200,13 @@ def cic_paint_dx(displacements,
     gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
     spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
+    weight_spec = P() if jnp.isscalar(weight) else spec
     grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
                                   halo_size=halo_size,
-                                  weight=weight,
                                   chunk_size=chunk_size),
                           gpu_mesh=gpu_mesh,
-                          in_specs=spec,
-                          out_specs=spec)(displacements)
+                          in_specs=(spec, weight_spec),
+                          out_specs=spec)(displacements, weight)
 
     grid_mesh = halo_exchange(grid_mesh,
                               halo_extents=halo_extents,

@@ -230,6 +240,12 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
 def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
 
     halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
+    # Halo size is halved for the read operation
+    # We only need to read the density field
+    # while in the painting operation we need to exchange and reduce the halo
+    # We chose to do that since it is much easier to write a custom jvp rule for exchange
+    # while it is a bit harder if there is a reduction involved
+    halo_size = jax.tree.map(lambda x: x // 2, halo_size)
     grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
     grid_mesh = halo_exchange(grid_mesh,
                               halo_extents=halo_extents,
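The weight handling above replaces the old `None` sentinel with a scalar default of `1.`, so scalar and per-particle weights now take different shard specs (`P()` versus the field spec). A sketch of both call styles, assuming a small grid of uniform particles (shapes illustrative):

```python
import jax
import jax.numpy as jnp
from jaxpm.painting import cic_paint_dx

displacements = jnp.zeros((32, 32, 32, 3))  # relative positions on the grid

# Scalar weight: broadcast to every particle, sharded with spec P().
field_unweighted = cic_paint_dx(displacements, weight=1.)

# Per-particle weights: must match the particle grid shape (32, 32, 32),
# otherwise the new shape check in _cic_paint_dx_impl raises ValueError.
weights = jax.random.uniform(jax.random.PRNGKey(0), (32, 32, 32))
field_weighted = cic_paint_dx(displacements, weight=weights)
```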
@@ -30,15 +30,12 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
     """Multilinear enmeshing."""
     base_indices = jnp.asarray(base_indices)
     displacements = jnp.asarray(displacements)
-    with jax.experimental.enable_x64():
-        cell_size = jnp.float64(
-            cell_size) if new_cell_size is not None else jnp.array(
-                cell_size, dtype=displacements.dtype)
+    cell_size = jnp.array(cell_size, dtype=displacements.dtype)
     if base_shape is not None:
         base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
-    offset = jnp.float64(offset)
+    offset = offset.astype(base_indices.dtype)
     if new_cell_size is not None:
-        new_cell_size = jnp.float64(new_cell_size)
+        new_cell_size = jnp.array(new_cell_size, dtype=displacements.dtype)
     if new_shape is not None:
         new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
@@ -131,7 +131,7 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
     Generate initial conditions.
     """
     # Initialize a random field with one slice on each gpu
-    field = normal_field(mesh_shape, seed=seed, sharding=sharding)
+    field = normal_field(seed=seed, shape=mesh_shape, sharding=sharding)
     field = fft3d(field)
     kvec = fftk(field)
     kmesh = sum((kk / box_size[i] * mesh_shape[i])**2

@@ -172,8 +172,7 @@ def make_ode_fn(mesh_shape,
     return nbody_ode
 
 
-def make_diffrax_ode(cosmo,
-                     mesh_shape,
+def make_diffrax_ode(mesh_shape,
                      paint_absolute_pos=True,
                      halo_size=0,
                      sharding=None):

@@ -183,6 +182,7 @@ def make_diffrax_ode(cosmo,
         state is a tuple (position, velocities)
         """
         pos, vel = state
+        cosmo = args
 
         forces = pm_forces(pos,
                            mesh_shape=mesh_shape,
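The signature change above moves `cosmo` out of `make_diffrax_ode`: the returned ODE now reads the cosmology from diffrax's `args` (`cosmo = args`). A minimal sketch of the updated calling convention, mirroring the test code later in this diff (`initial_conditions` is assumed to come from `linear_field` and is elided here):

```python
import jax.numpy as jnp
import jax_cosmo as jc
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
from jaxpm.pm import lpt, make_diffrax_ode

mesh_shape = (32, 32, 32)
cosmo = jc.Planck15()
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=1)
y0 = jnp.stack([dx, p])

# cosmo is no longer an argument of make_diffrax_ode ...
ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False))

# ... it is supplied at solve time through diffrax's `args`,
# which the ODE body unpacks as `cosmo = args`.
sol = diffeqsolve(ode_fn, Dopri5(), t0=0.1, t1=1.0, dt0=None, y0=y0,
                  args=cosmo,
                  stepsize_controller=PIDController(rtol=1e-5, atol=1e-5),
                  saveat=SaveAt(t1=True))
```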
File diff suppressed because one or more lines are too long
@@ -62,7 +62,7 @@
     "\n",
     "This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n",
     "\n",
-    "- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. `create_device_mesh(pdims)` initializes this layout across available GPUs.\n",
+    "- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n",
     "- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n",
     "- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n",
     "\n",

@@ -71,7 +71,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [

@@ -80,11 +80,10 @@
    "from jax.sharding import Mesh, NamedSharding\n",
    "from jax.sharding import PartitionSpec as P\n",
    "\n",
-    "all_gather = partial(process_allgather, tiled=False)\n",
+    "all_gather = partial(process_allgather, tiled=True)\n",
    "\n",
    "pdims = (2, 4)\n",
-    "devices = create_device_mesh(pdims)\n",
-    "mesh = Mesh(devices, axis_names=('x', 'y'))\n",
+    "mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
    "sharding = NamedSharding(mesh, P('x', 'y'))"
   ]
  },

@@ -124,7 +123,7 @@
    "\n",
    "  # Evolve the simulation forward\n",
    "  ode_fn = ODETerm(\n",
-    "      make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
+    "      make_diffrax_ode(mesh_shape, paint_absolute_pos=False, sharding=sharding, halo_size=halo_size))\n",
    "  solver = LeapfrogMidpoint()\n",
    "\n",
    "  stepsize_controller = ConstantStepSize()\n",

@@ -288,7 +287,7 @@
    "\n",
    "  # Evolve the simulation forward\n",
    "  ode_fn = ODETerm(\n",
-    "      make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
+    "      make_diffrax_ode(mesh_shape, paint_absolute_pos=False, sharding=sharding, halo_size=halo_size))\n",
    "  solver = Dopri5()\n",
    "\n",
    "  stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",
@@ -17,9 +17,8 @@ import jax_cosmo as jc
 import numpy as np
 from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
                      PIDController, SaveAt, diffeqsolve)
-from jax.experimental.mesh_utils import create_device_mesh
 from jax.experimental.multihost_utils import process_allgather
-from jax.sharding import Mesh, NamedSharding
+from jax.sharding import NamedSharding
 from jax.sharding import PartitionSpec as P
 
 from jaxpm.kernels import interpolate_power_spectrum

@@ -78,7 +77,7 @@ def parse_arguments():
 
 def create_mesh_and_sharding(pdims):
-    devices = create_device_mesh(pdims)
-    mesh = Mesh(devices, axis_names=('x', 'y'))
+    mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
     sharding = NamedSharding(mesh, P('x', 'y'))
     return mesh, sharding

@@ -106,7 +105,10 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
                    sharding=sharding)
 
     ode_fn = ODETerm(
-        make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
+        make_diffrax_ode(mesh_shape,
+                         paint_absolute_pos=False,
+                         sharding=sharding,
+                         halo_size=halo_size))
 
     # Choose solver
     solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
@@ -37,3 +37,50 @@ Each notebook includes installation instructions and guidelines for configuring
 - **SLURM** for job scheduling on clusters (if running multi-host setups)
 
 > **Note**: These notebooks are tested on the **Jean Zay** supercomputer and may require configuration changes for different HPC clusters.

## Caveats

### Cloud-in-Cell (CIC) Painting (Single Device)

There are two ways to perform CIC painting in JAXPM. The first is `cic_paint`, which paints absolute particle positions to the mesh. The second is `cic_paint_dx`, which paints relative particle positions to the mesh (using uniform particles). The absolute version is faster at the cost of more memory usage.

In order to use relative painting you need to (see the sketch after this list):

- Set the `particles` argument in the `lpt` function from `jaxpm.pm` to `None`
- Set `paint_absolute_pos` to `False` in the `make_ode_fn` or `make_diffrax_ode` function from `jaxpm.pm` (it is `True` by default)

Otherwise, set `particles` to the starting particles of your choice and leave `paint_absolute_pos` as `True` (the default).
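A minimal sketch of the relative-painting setup described above (argument names as in `jaxpm.pm`; `pk_fn` stands in for an interpolated linear power spectrum and the shapes are illustrative):

```python
import jax
import jax_cosmo as jc
from jaxpm.pm import linear_field, lpt, make_diffrax_ode

mesh_shape, box_size = (64, 64, 64), (100., 100., 100.)
cosmo = jc.Planck15()
ics = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))

# Relative painting: no explicit particle grid, displacements only.
dx, p, _ = lpt(cosmo, ics, particles=None, a=0.1, order=1)
ode_fn = make_diffrax_ode(mesh_shape, paint_absolute_pos=False)
```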
### Cloud-in-Cell (CIC) Painting (Multi Device)

Both the `cic_paint` and `cic_paint_dx` functions are available in multi-device mode.

You need to set the `sharding` and `halo_size` arguments, as explained in the notebook [03-MultiGPU_PM_Halo.ipynb](03-MultiGPU_PM_Halo.ipynb).

One thing to note: `cic_paint` is not as accurate as `cic_paint_dx` in multi-device mode and is therefore not recommended.

Using relative painting in multi-device mode works just like in single-device mode:\
set the `particles` argument in the `lpt` function from `jaxpm.pm` to `None` and set `paint_absolute_pos` to `False`.
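A short multi-device painting sketch combining the two requirements, assuming 8 devices and a mesh shape divisible by the decomposition (`displacements` is an `(mx, my, mz, 3)` displacement field produced upstream):

```python
import jax
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
from jaxpm.painting import cic_paint_dx

mesh = jax.make_mesh((2, 4), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = 32  # e.g. mesh_shape[0] // 2, as in the test suite

field = cic_paint_dx(displacements, halo_size=halo_size, sharding=sharding)
```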
### Distributed PM

To run a distributed PM, follow the examples in notebooks [03](03-MultiGPU_PM_Halo.ipynb) and [05](05-MultiHost_PM.ipynb) for multi-host.

In short, you need to set the `sharding` and `halo_size` arguments in `lpt`, `linear_field`, the `make_ode` functions, and `pm_forces` if you use it.

Mismatching the shardings will give you errors and unexpected results.

You can also use `normal_field` and `uniform_particles` from `jaxpm.pm.distributed` to create the fields and particles with a sharding.
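Putting the pieces together, a condensed sketch of a sharded forward model (function names from `jaxpm.pm`; `pk_fn` and the shapes are placeholders):

```python
import jax
import jax_cosmo as jc
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
from jaxpm.pm import linear_field, lpt, make_diffrax_ode

mesh_shape, box_size, halo_size = (256,) * 3, (200.,) * 3, 128
mesh = jax.make_mesh((2, 4), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
cosmo = jc.Planck15()

# The same sharding and halo_size must flow through every stage;
# mismatching them is what produces the errors mentioned above.
ics = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0),
                   sharding=sharding)
dx, p, _ = lpt(cosmo, ics, a=0.1, order=1,
               halo_size=halo_size, sharding=sharding)
ode_fn = make_diffrax_ode(mesh_shape, paint_absolute_pos=False,
                          halo_size=halo_size, sharding=sharding)
```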
### Choosing the right pdims

pdims are processor dimensions.\
They are explained in more detail in the jaxDecomp paper [here](https://github.com/DifferentiableUniverseInitiative/jaxDecomp).

For 8 devices, three decompositions are possible:

- (1, 8)
- (2, 4), (4, 2)
- (8, 1)

(1, X) should be the fastest; (2, X) or (X, 2) is more accurate but slightly slower,\
and (X, 1) gives the least accurate results for some reason, so it is not recommended.
@@ -1,5 +1,5 @@
 pytest>=8.0.0
-diffrax
 pfft-python @ git+https://github.com/MP-Gadget/pfft-python
 pmesh @ git+https://github.com/MP-Gadget/pmesh
 fastpm @ git+https://github.com/ASKabalan/fastpm-python
+numpy==2.2.6
+diffrax
@@ -44,7 +44,7 @@ def simulation_config(request):
     return request.param
 
 
-@pytest.fixture(scope="session", params=[0.1, 0.5, 0.8])
+@pytest.fixture(scope="session", params=[0.1, 0.2])
 def lpt_scale_factor(request):
     return request.param

@@ -151,7 +151,7 @@ def nbody_from_lpt1(solver, fpm_lpt1, particle_mesh, lpt_scale_factor):
     if lpt_scale_factor == 0.8:
         pytest.skip("Do not run nbody simulation from scale factor 0.8")
 
-    stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)
+    stages = np.linspace(lpt_scale_factor, 1.0, 100, endpoint=True)
 
     finalstate = solver.nbody(fpm_lpt1, leapfrog(stages))
     fpm_mesh = particle_mesh.paint(finalstate.X).value

@@ -167,7 +167,7 @@ def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
     if lpt_scale_factor == 0.8:
         pytest.skip("Do not run nbody simulation from scale factor 0.8")
 
-    stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)
+    stages = np.linspace(lpt_scale_factor, 1.0, 100, endpoint=True)
 
     finalstate = solver.nbody(fpm_lpt2, leapfrog(stages))
     fpm_mesh = particle_mesh.paint(finalstate.X).value
@@ -2,6 +2,7 @@ import pytest
 from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
 from helpers import MSE, MSRE
 from jax import numpy as jnp
+from numpy.testing import assert_allclose
 
 from jaxpm.distributed import uniform_particles
 from jaxpm.painting import cic_paint, cic_paint_dx

@@ -10,6 +11,8 @@ from jaxpm.utils import power_spectrum
 
 _TOLERANCE = 1e-4
 _PM_TOLERANCE = 1e-3
+_FIELD_RTOL = 1e-4
+_FIELD_ATOL = 1e-3
 
 
 @pytest.mark.single_device

@@ -34,7 +37,10 @@ def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
     _, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
     _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
 
-    assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
+    assert_allclose(lpt_field,
+                    fpm_ref_field,
+                    rtol=_FIELD_RTOL,
+                    atol=_FIELD_ATOL)
     assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE

@@ -55,7 +61,10 @@ def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
     _, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
     _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
 
-    assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
+    assert_allclose(lpt_field,
+                    fpm_ref_field,
+                    rtol=_FIELD_RTOL,
+                    atol=_FIELD_ATOL)
     assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE

@@ -76,7 +85,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
                a=lpt_scale_factor,
                order=order)
 
-    ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
+    ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
 
     solver = Dopri5()
     controller = PIDController(rtol=1e-8,

@@ -95,6 +104,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
                             t1=1.0,
                             dt0=None,
                             y0=y0,
+                            args=cosmo,
                             stepsize_controller=controller,
                             saveat=saveat)

@@ -105,7 +115,10 @@ def test_nbody_absolute(simulation_config, initial_conditions,
     _, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
     _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
 
-    assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
+    assert_allclose(final_field,
+                    fpm_ref_field,
+                    rtol=_FIELD_RTOL,
+                    atol=_FIELD_ATOL)
     assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE

@@ -121,8 +134,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
     # Initial displacement
     dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
 
-    ode_fn = ODETerm(
-        make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
+    ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
 
     solver = Dopri5()
     controller = PIDController(rtol=1e-9,

@@ -141,6 +153,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
                             t1=1.0,
                             dt0=None,
                             y0=y0,
+                            args=cosmo,
                             stepsize_controller=controller,
                             saveat=saveat)

@@ -151,5 +164,8 @@ def test_nbody_relative(simulation_config, initial_conditions,
     _, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
     _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
 
-    assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
+    assert_allclose(final_field,
+                    fpm_ref_field,
+                    rtol=_FIELD_RTOL,
+                    atol=_FIELD_ATOL)
     assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
@@ -2,8 +2,11 @@ from conftest import initialize_distributed
 
 initialize_distributed()  # ignore : E402
 
+from functools import partial  # noqa : E402
+
 import jax  # noqa : E402
 import jax.numpy as jnp  # noqa : E402
+import jax_cosmo as jc  # noqa : E402
 import pytest  # noqa : E402
 from diffrax import SaveAt  # noqa : E402
 from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve

@@ -12,19 +15,31 @@ from jax import lax  # noqa : E402
 from jax.experimental.multihost_utils import process_allgather  # noqa : E402
 from jax.sharding import NamedSharding
 from jax.sharding import PartitionSpec as P  # noqa : E402
+from jaxdecomp import get_fft_output_sharding
 
 from jaxpm.distributed import uniform_particles  # noqa : E402
+from jaxpm.distributed import fft3d, ifft3d
 from jaxpm.painting import cic_paint, cic_paint_dx  # noqa : E402
-from jaxpm.pm import lpt, make_diffrax_ode  # noqa : E402
+from jaxpm.pm import lpt, make_diffrax_ode, pm_forces  # noqa : E402
 
-_TOLERANCE = 3.0  # 🙃🙃
+_TOLERANCE = 1e-6  # 🎉🎉🎉
 
+pdims = [(1, 8), (8, 1), (4, 2), (2, 4)]
+
 
 @pytest.mark.distributed
 @pytest.mark.parametrize("order", [1, 2])
+@pytest.mark.parametrize("pdims", pdims)
 @pytest.mark.parametrize("absolute_painting", [True, False])
 def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
-                       absolute_painting):
+                       pdims, absolute_painting):
+
+    if absolute_painting:
+        pytest.skip("Absolute painting is not recommended in distributed mode")
+
+    painting_str = "absolute" if absolute_painting else "relative"
+    print("=" * 50)
+    print(f"Running with {painting_str} painting and pdims {pdims} ...")
 
     mesh_shape, box_shape = simulation_config
     # SINGLE DEVICE RUN

@@ -37,12 +52,12 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
                        particles,
                        a=0.1,
                        order=order)
-        ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
+        ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
         y0 = jnp.stack([particles + dx, p])
     else:
         dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order)
-        ode_fn = ODETerm(
-            make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
+        ode_fn = ODETerm(make_diffrax_ode(mesh_shape,
+                                          paint_absolute_pos=False))
         y0 = jnp.stack([dx, p])
 
     solver = Dopri5()

@@ -60,6 +75,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
                             t1=1.0,
                             dt0=None,
                             y0=y0,
+                            args=cosmo,
                             stepsize_controller=controller,
                             saveat=saveat)

@@ -72,7 +88,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
     print("Done with single device run")
     # MULTI DEVICE RUN
 
-    mesh = jax.make_mesh((1, 8), ('x', 'y'))
+    mesh = jax.make_mesh(pdims, ('x', 'y'))
     sharding = NamedSharding(mesh, P('x', 'y'))
     halo_size = mesh_shape[0] // 2

@@ -94,8 +110,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
                 sharding=sharding)
 
         ode_fn = ODETerm(
-            make_diffrax_ode(cosmo,
-                             mesh_shape,
+            make_diffrax_ode(mesh_shape,
                              halo_size=halo_size,
                              sharding=sharding))

@@ -108,8 +123,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
                        halo_size=halo_size,
                        sharding=sharding)
         ode_fn = ODETerm(
-            make_diffrax_ode(cosmo,
-                             mesh_shape,
+            make_diffrax_ode(mesh_shape,
                              paint_absolute_pos=False,
                              halo_size=halo_size,
                              sharding=sharding))

@@ -130,16 +144,23 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
                             t1=1.0,
                             dt0=None,
                             y0=y0,
+                            args=cosmo,
                             stepsize_controller=controller,
                             saveat=saveat)
 
+    final_field = solutions.ys[-1, 0]
+    print(f"Final field sharding is {final_field.sharding}")
+
+    assert final_field.sharding.is_equivalent_to(sharding , ndim=3) \
+    , f"Final field sharding is not correct .. should be {sharding} it is instead {final_field.sharding}"
+
     if absolute_painting:
         multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
-                                             solutions.ys[-1, 0],
+                                             final_field,
                                              halo_size=halo_size,
                                              sharding=sharding)
     else:
-        multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
+        multi_device_final_field = cic_paint_dx(final_field,
                                                 halo_size=halo_size,
                                                 sharding=sharding)

@@ -150,3 +171,230 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
     print(f"MSE is {mse}")
 
     assert mse < _TOLERANCE
+
+
+@pytest.mark.distributed
+@pytest.mark.parametrize("order", [1, 2])
+@pytest.mark.parametrize("pdims", pdims)
+def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
+                              order, nbody_from_lpt1, nbody_from_lpt2, pdims):
+
+    mesh_shape, box_shape = simulation_config
+    # SINGLE DEVICE RUN
+    cosmo._workspace = {}
+
+    mesh = jax.make_mesh(pdims, ('x', 'y'))
+    sharding = NamedSharding(mesh, P('x', 'y'))
+    halo_size = mesh_shape[0] // 2
+
+    initial_conditions = lax.with_sharding_constraint(initial_conditions,
+                                                      sharding)
+
+    print(f"sharded initial conditions {initial_conditions.sharding}")
+    cosmo._workspace = {}
+
+    @jax.jit
+    def forward_model(initial_conditions, cosmo):
+
+        dx, p, _ = lpt(cosmo,
+                       initial_conditions,
+                       a=0.1,
+                       order=order,
+                       halo_size=halo_size,
+                       sharding=sharding)
+        ode_fn = ODETerm(
+            make_diffrax_ode(mesh_shape,
+                             paint_absolute_pos=False,
+                             halo_size=halo_size,
+                             sharding=sharding))
+        y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
+
+        solver = Dopri5()
+        controller = PIDController(rtol=1e-8,
+                                   atol=1e-8,
+                                   pcoeff=0.4,
+                                   icoeff=1,
+                                   dcoeff=0)
+
+        saveat = SaveAt(t1=True)
+        solutions = diffeqsolve(ode_fn,
+                                solver,
+                                t0=0.1,
+                                t1=1.0,
+                                dt0=None,
+                                y0=y0,
+                                args=cosmo,
+                                stepsize_controller=controller,
+                                saveat=saveat)
+
+        multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
+                                                halo_size=halo_size,
+                                                sharding=sharding)
+
+        return multi_device_final_field
+
+    @jax.jit
+    def model(initial_conditions, cosmo):
+        final_field = forward_model(initial_conditions, cosmo)
+        return MSE(final_field,
+                   nbody_from_lpt1 if order == 1 else nbody_from_lpt2)
+
+    obs_val = model(initial_conditions, cosmo)
+
+    shifted_initial_conditions = initial_conditions + jax.random.normal(
+        jax.random.key(42), initial_conditions.shape) * 5
+
+    good_grads = jax.grad(model)(initial_conditions, cosmo)
+    off_grads = jax.grad(model)(shifted_initial_conditions, cosmo)
+
+    assert good_grads.sharding.is_equivalent_to(initial_conditions.sharding,
+                                                ndim=3)
+    assert off_grads.sharding.is_equivalent_to(initial_conditions.sharding,
+                                               ndim=3)
+
+
+@pytest.mark.distributed
+@pytest.mark.parametrize("pdims", pdims)
+def test_fwd_rev_gradients(cosmo, pdims):
+
+    mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
+    cosmo._workspace = {}
+
+    mesh = jax.make_mesh(pdims, ('x', 'y'))
+    sharding = NamedSharding(mesh, P('x', 'y'))
+    halo_size = mesh_shape[0] // 2
+
+    initial_conditions = jax.random.normal(jax.random.PRNGKey(42), mesh_shape)
+    initial_conditions = lax.with_sharding_constraint(initial_conditions,
+                                                      sharding)
+    print(f"sharded initial conditions {initial_conditions.sharding}")
+    cosmo._workspace = {}
+
+    @partial(jax.jit, static_argnums=(2, 3, 4))
+    def compute_forces(initial_conditions,
+                       cosmo,
+                       a=0.5,
+                       halo_size=0,
+                       sharding=None):
+
+        paint_absolute_pos = False
+        particles = jnp.zeros_like(initial_conditions,
+                                   shape=(*initial_conditions.shape, 3))
+
+        a = jnp.atleast_1d(a)
+        E = jnp.sqrt(jc.background.Esqr(cosmo, a))
+
+        initial_conditions = jax.lax.with_sharding_constraint(
+            initial_conditions, sharding)
+        delta_k = fft3d(initial_conditions)
+        out_sharding = get_fft_output_sharding(sharding)
+        delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
+
+        initial_force = pm_forces(particles,
+                                  delta=delta_k,
+                                  paint_absolute_pos=paint_absolute_pos,
+                                  halo_size=halo_size,
+                                  sharding=sharding)
+
+        return initial_force[..., 0]
+
+    forces = compute_forces(initial_conditions,
+                            cosmo,
+                            halo_size=halo_size,
+                            sharding=sharding)
+    back_gradient = jax.jacrev(compute_forces)(initial_conditions,
+                                               cosmo,
+                                               halo_size=halo_size,
+                                               sharding=sharding)
+    fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions,
+                                              cosmo,
+                                              halo_size=halo_size,
+                                              sharding=sharding)
+
+    print(f"Forces sharding is {forces.sharding}")
+    print(f"Backward gradient sharding is {back_gradient.sharding}")
+    print(f"Forward gradient sharding is {fwd_gradient.sharding}")
+    assert forces.sharding.is_equivalent_to(initial_conditions.sharding,
+                                            ndim=3)
+    assert back_gradient[0, 0, 0, ...].sharding.is_equivalent_to(
+        initial_conditions.sharding, ndim=3)
+    assert fwd_gradient.sharding.is_equivalent_to(initial_conditions.sharding,
+                                                  ndim=3)
+
+
+@pytest.mark.distributed
+@pytest.mark.parametrize("pdims", pdims)
+def test_vmap(cosmo, pdims):
+
+    mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
+    cosmo._workspace = {}
+
+    mesh = jax.make_mesh(pdims, ('x', 'y'))
+    sharding = NamedSharding(mesh, P('x', 'y'))
+    halo_size = mesh_shape[0] // 2
+
+    single_dev_initial_conditions = jax.random.normal(jax.random.PRNGKey(42),
+                                                      mesh_shape)
+    initial_conditions = lax.with_sharding_constraint(
+        single_dev_initial_conditions, sharding)
+
+    single_ics = jnp.stack([
+        single_dev_initial_conditions, single_dev_initial_conditions,
+        single_dev_initial_conditions
+    ])
+    sharded_ics = jnp.stack(
+        [initial_conditions, initial_conditions, initial_conditions])
+    print(f"unsharded initial conditions batch {single_ics.sharding}")
+    print(f"sharded initial conditions batch {sharded_ics.sharding}")
+    cosmo._workspace = {}
+
+    @partial(jax.jit, static_argnums=(2, 3, 4))
+    def compute_forces(initial_conditions,
+                       cosmo,
+                       a=0.5,
+                       halo_size=0,
+                       sharding=None):
+
+        paint_absolute_pos = False
+        particles = jnp.zeros_like(initial_conditions,
+                                   shape=(*initial_conditions.shape, 3))
+
+        a = jnp.atleast_1d(a)
+        E = jnp.sqrt(jc.background.Esqr(cosmo, a))
+
+        initial_conditions = jax.lax.with_sharding_constraint(
+            initial_conditions, sharding)
+        delta_k = fft3d(initial_conditions)
+        out_sharding = get_fft_output_sharding(sharding)
+        delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
+
+        initial_force = pm_forces(particles,
+                                  delta=delta_k,
+                                  paint_absolute_pos=paint_absolute_pos,
+                                  halo_size=halo_size,
+                                  sharding=sharding)
+
+        return initial_force[..., 0]
+
+    def fn(ic):
+        return compute_forces(ic,
+                              cosmo,
+                              halo_size=halo_size,
+                              sharding=sharding)
+
+    v_compute_forces = jax.vmap(fn)
+
+    print(f"single_ics shape {single_ics.shape}")
+    print(f"sharded_ics shape {sharded_ics.shape}")
+
+    single_dev_forces = v_compute_forces(single_ics)
+    sharded_forces = v_compute_forces(sharded_ics)
+
+    assert single_dev_forces.ndim == 4
+    assert sharded_forces.ndim == 4
+
+    print(f"Sharded forces {sharded_forces.sharding}")
+
+    assert sharded_forces[0].sharding.is_equivalent_to(
+        initial_conditions.sharding, ndim=3)
+    assert sharded_forces.sharding.spec[0] == None
@@ -39,7 +39,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
                        particles,
                        a=lpt_scale_factor,
                        order=order)
-        ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
+        ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
         y0 = jnp.stack([particles + dx, p])
 
     else:

@@ -48,7 +48,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
                        a=lpt_scale_factor,
                        order=order)
         ode_fn = ODETerm(
-            make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
+            make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
         y0 = jnp.stack([dx, p])
 
     solver = Dopri5()

@@ -66,6 +66,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
                             t1=1.0,
                             dt0=None,
                             y0=y0,
+                            args=cosmo,
                             adjoint=adjoint,
                             stepsize_controller=controller,
                             saveat=saveat)