Wassim KABALAN 2025-05-09 20:08:35 +00:00 committed by GitHub
commit 037b465824
11 changed files with 461 additions and 154 deletions

View file

@ -119,7 +119,7 @@ def growth_factor(cosmo, a):
if cosmo._flags["gamma_growth"]:
return _growth_factor_gamma(cosmo, a)
else:
return _growth_factor_ODE(cosmo, a)
return _growth_factor_ODE(cosmo, a)[0]
def growth_factor_second(cosmo, a):
@ -225,7 +225,7 @@ def growth_rate_second(cosmo, a):
return _growth_rate_second_ODE(cosmo, a)
def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=256, eps=1e-4):
"""Compute linear growth factor D(a) at a given scale factor,
normalised such that D(a=1) = 1.
@ -243,57 +243,56 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
Growth factor computed at requested scale factor
"""
# Check if growth has already been computed
if not "background.growth_factor" in cosmo._workspace.keys():
# Compute tabulated array
atab = np.logspace(log10_amin, 0.0, steps)
#if not "background.growth_factor" in cosmo._workspace.keys():
# Compute tabulated array
atab = np.logspace(log10_amin, 0.0, steps)
def D_derivs(y, x):
q = (2.0 - 0.5 *
(Omega_m_a(cosmo, x) +
(1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x))) / x
r = 1.5 * Omega_m_a(cosmo, x) / x / x
def D_derivs(y, x):
q = (2.0 - 0.5 *
(Omega_m_a(cosmo, x) +
(1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x))) / x
r = 1.5 * Omega_m_a(cosmo, x) / x / x
g1, g2 = y[0]
f1, f2 = y[1]
dy1da = [f1, -q * f1 + r * g1]
dy2da = [f2, -q * f2 + r * g2 - r * g1**2]
return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]])
g1, g2 = y[0]
f1, f2 = y[1]
dy1da = [f1, -q * f1 + r * g1]
dy2da = [f2, -q * f2 + r * g2 - r * g1**2]
return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]])
y0 = np.array([[atab[0], -3.0 / 7 * atab[0]**2],
[1.0, -6.0 / 7 * atab[0]]])
y = odeint(D_derivs, y0, atab)
y0 = np.array([[atab[0], -3.0 / 7 * atab[0]**2],
[1.0, -6.0 / 7 * atab[0]]])
y = odeint(D_derivs, y0, atab)
# compute second order derivatives growth
dyda2 = D_derivs(np.transpose(y, (1, 2, 0)), atab)
dyda2 = np.transpose(dyda2, (2, 0, 1))
# compute second order derivatives growth
dyda2 = D_derivs(np.transpose(y, (1, 2, 0)), atab)
dyda2 = np.transpose(dyda2, (2, 0, 1))
# Normalize results
y1 = y[:, 0, 0]
gtab = y1 / y1[-1]
y2 = y[:, 0, 1]
g2tab = y2 / y2[-1]
# To transform from dD/da to dlnD/dlna: dlnD/dlna = a / D dD/da
ftab = y[:, 1, 0] / y1[-1] * atab / gtab
f2tab = y[:, 1, 1] / y2[-1] * atab / g2tab
# Similarly for second order derivatives
# Note: these factors are not accessible as parent functions yet
# since it is unclear what to refer to them with.
htab = dyda2[:, 1, 0] / y1[-1] * atab / gtab
h2tab = dyda2[:, 1, 1] / y2[-1] * atab / g2tab
# Normalize results
y1 = y[:, 0, 0]
gtab = y1 / y1[-1]
y2 = y[:, 0, 1]
g2tab = y2 / y2[-1]
# To transform from dD/da to dlnD/dlna: dlnD/dlna = a / D dD/da
ftab = y[:, 1, 0] / y1[-1] * atab / gtab
f2tab = y[:, 1, 1] / y2[-1] * atab / g2tab
# Similarly for second order derivatives
# Note: these factors are not accessible as parent functions yet
# since it is unclear what to refer to them with.
htab = dyda2[:, 1, 0] / y1[-1] * atab / gtab
h2tab = dyda2[:, 1, 1] / y2[-1] * atab / g2tab
cache = {
"a": atab,
"g": gtab,
"f": ftab,
"h": htab,
"g2": g2tab,
"f2": f2tab,
"h2": h2tab,
}
cosmo._workspace["background.growth_factor"] = cache
else:
cache = cosmo._workspace["background.growth_factor"]
return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0)
cache = {
"a": atab,
"g": gtab,
"f": ftab,
"h": htab,
"g2": g2tab,
"f2": f2tab,
"h2": h2tab,
}
return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0), cache
def _growth_rate_ODE(cosmo, a):
@ -314,12 +313,10 @@ def _growth_rate_ODE(cosmo, a):
Growth rate computed at requested scale factor
"""
# Check if growth has already been computed, if not, compute it
if not "background.growth_factor" in cosmo._workspace.keys():
_growth_factor_ODE(cosmo, np.atleast_1d(1.0))
cache = cosmo._workspace["background.growth_factor"]
cache = _growth_factor_ODE(cosmo, np.atleast_1d(1.0))[1]
return interp(a, cache["a"], cache["f"])
def _growth_factor_second_ODE(cosmo, a):
"""Compute second order growth factor D2(a) at a given scale factor,
normalised such that D(a=1) = 1.
@ -338,36 +335,12 @@ def _growth_factor_second_ODE(cosmo, a):
Second order growth factor computed at requested scale factor
"""
# Check if growth has already been computed, if not, compute it
if not "background.growth_factor" in cosmo._workspace.keys():
_growth_factor_ODE(cosmo, np.atleast_1d(1.0))
cache = cosmo._workspace["background.growth_factor"]
#if not "background.growth_factor" in cosmo._workspace.keys():
# _growth_factor_ODE(cosmo, np.atleast_1d(1.0))
cache = _growth_factor_ODE(cosmo, a)[1]
return interp(a, cache["a"], cache["g2"])
def _growth_rate_ODE(cosmo, a):
"""Compute growth rate dD/dlna at a given scale factor by solving the linear
growth ODE.
Parameters
----------
cosmo: `Cosmology`
Cosmology object
a: array_like
Scale factor
Returns
-------
f: ndarray, or float if input scalar
Second order growth rate computed at requested scale factor
"""
# Check if growth has already been computed, if not, compute it
if not "background.growth_factor" in cosmo._workspace.keys():
_growth_factor_ODE(cosmo, np.atleast_1d(1.0))
cache = cosmo._workspace["background.growth_factor"]
return interp(a, cache["a"], cache["f"])
def _growth_rate_second_ODE(cosmo, a):
"""Compute second order growth rate dD2/dlna at a given scale factor by solving the linear
growth ODE.
@ -386,9 +359,9 @@ def _growth_rate_second_ODE(cosmo, a):
Second order growth rate computed at requested scale factor
"""
# Check if growth has already been computed, if not, compute it
if not "background.growth_factor" in cosmo._workspace.keys():
_growth_factor_ODE(cosmo, np.atleast_1d(1.0))
cache = cosmo._workspace["background.growth_factor"]
#if not "background.growth_factor" in cosmo._workspace.keys():
# _growth_factor_ODE(cosmo, np.atleast_1d(1.0))
cache = _growth_factor_ODE(cosmo, a)[1]
return interp(a, cache["a"], cache["f2"])
@ -521,6 +494,34 @@ def Gf2(cosmo, a):
return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
def gp(cosmo, a):
r""" Derivative of D1 against a
Parameters
----------
cosmo: dict
Cosmology dictionary.
a : array_like
Scale factor.
Returns
-------
Scalar float Tensor : the derivative of D1 against a.
Notes
-----
The expression for :math:`gp(a)` is:
.. math::
gp(a)=\frac{dD1}{da}= D'_{1norm}/a
"""
f1 = growth_rate(cosmo, a)
g1 = growth_factor(cosmo, a)
D1f = f1 * g1 / a
return D1f
def dGfa(cosmo, a):
r""" Derivative of Gf against a
@ -549,7 +550,8 @@ def dGfa(cosmo, a):
f1 = growth_rate(cosmo, a)
g1 = growth_factor(cosmo, a)
D1f = f1 * g1 / a
cache = cosmo._workspace['background.growth_factor']
#cache = cosmo._workspace['background.growth_factor']
cache = _growth_factor_ODE(cosmo, a)[1]
f1p = cache['h'] / cache['a'] * cache['g']
f1p = interp(np.log(a), np.log(cache['a']), f1p)
Ea = E(cosmo, a)
@ -584,9 +586,10 @@ def dGf2a(cosmo, a):
f2 = growth_rate_second(cosmo, a)
g2 = growth_factor_second(cosmo, a)
D2f = f2 * g2 / a
cache = cosmo._workspace['background.growth_factor']
#cache = cosmo._workspace['background.growth_factor']
cache = _growth_factor_ODE(cosmo, a)[1]
f2p = cache['h2'] / cache['a'] * cache['g2']
f2p = interp(np.log(a), np.log(cache['a']), f2p)
E_a = E(cosmo, a)
return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) +
3 * a**2 * E_a * D2f)
3 * a**2 * E_a * D2f)

View file

@ -12,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk
from jaxpm.painting_utils import gather, scatter
def _cic_paint_impl(grid_mesh, positions, weight=None):
def _cic_paint_impl(grid_mesh, positions, weight=1.):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
displacement field: [nx, ny, nz, 3]
@ -27,12 +27,10 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
if weight is not None:
if jnp.isscalar(weight):
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
else:
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
kernel)
if jnp.isscalar(weight):
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
else:
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), kernel)
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
@ -48,7 +46,13 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
@partial(jax.jit, static_argnums=(3, 4))
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
def cic_paint(grid_mesh, positions, weight=1., halo_size=0, sharding=None):
if sharding is not None:
print("""
WARNING : absolute painting is not recommended in multi-device mode.
Please use relative painting instead.
""")
positions = positions.reshape((*grid_mesh.shape, 3))
@ -57,9 +61,11 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
weight_spec = P() if jnp.isscalar(weight) else spec
grid_mesh = autoshmap(_cic_paint_impl,
gpu_mesh=gpu_mesh,
in_specs=(spec, spec, P()),
in_specs=(spec, spec, weight_spec),
out_specs=spec)(grid_mesh, positions, weight)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,
@ -128,6 +134,7 @@ def cic_paint_2d(mesh, positions, weight):
positions: [npart, 2]
weight: [npart]
"""
positions = positions.reshape([-1, 2])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
@ -136,7 +143,7 @@ def cic_paint_2d(mesh, positions, weight):
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1]
if weight is not None:
kernel = kernel * weight[..., jnp.newaxis]
kernel = kernel * weight.reshape(*positions.shape[:-1])
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
@ -151,7 +158,10 @@ def cic_paint_2d(mesh, positions, weight):
return mesh
def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
def _cic_paint_dx_impl(displacements,
weight=1.,
halo_size=0,
chunk_size=2**24):
halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1]
@ -190,13 +200,13 @@ def cic_paint_dx(displacements,
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
weight_spec = P() if jnp.isscalar(weight) else spec
grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
halo_size=halo_size,
weight=weight,
chunk_size=chunk_size),
gpu_mesh=gpu_mesh,
in_specs=spec,
out_specs=spec)(displacements)
in_specs=(spec, weight_spec),
out_specs=spec)(displacements, weight)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,

View file

@ -172,8 +172,7 @@ def make_ode_fn(mesh_shape,
return nbody_ode
def make_diffrax_ode(cosmo,
mesh_shape,
def make_diffrax_ode(mesh_shape,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
@ -183,6 +182,7 @@ def make_diffrax_ode(cosmo,
state is a tuple (position, velocities)
"""
pos, vel = state
cosmo = args
forces = pm_forces(pos,
mesh_shape=mesh_shape,

View file

@ -32,7 +32,7 @@
"from jaxpm.painting import cic_paint , cic_paint_dx\n",
"from jaxpm.pm import linear_field, lpt, make_diffrax_ode\n",
"from jaxpm.distributed import uniform_particles\n",
"from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve"
"from diffrax import PIDController, Tsit5, ODETerm, SaveAt, diffeqsolve"
]
},
{
@ -41,10 +41,9 @@
"source": [
"### Particle Mesh Simulation with Diffrax Leapfrog Solver\n",
"\n",
"In this setup, we use the `LeapfrogMidpoint` solver from the `diffrax` library to evolve particle displacements over time in our Particle Mesh simulation. The novelty here is the use of a **Leapfrog solver** from `diffrax` for efficient, memory-saving time integration.\n",
"In this setup, we use the `Tsit5` solver from the `diffrax` library to evolve particle displacements over time in our Particle Mesh simulation.\n",
"\n",
"- **Leapfrog Integration**: This symplectic integrator is well-suited for simulations of gravitational dynamics, preserving energy over long timescales and allowing larger time steps without sacrificing accuracy.\n",
"- **Efficient Displacement Tracking**: We initialize only displacements (`dx`) rather than absolute positions, which, combined with Leapfrogs stability, enhances memory efficiency and speeds up computation.\n"
"- **Efficient Displacement Tracking**: We initialize only displacements (`dx`) rather than absolute positions, which uses a the `pmwd` cic_painting algorithm which is more memory efficient at the cost of being slightly slower\n"
]
},
{
@ -84,10 +83,10 @@
" \n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" solver = LeapfrogMidpoint()\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False))\n",
" solver = Tsit5()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
" stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n",
" res = diffeqsolve(ode_fn,\n",
" solver,\n",
" t0=0.1,\n",
@ -257,10 +256,10 @@
" \n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape))\n",
" solver = LeapfrogMidpoint()\n",
" make_diffrax_ode(mesh_shape))\n",
" solver = Tsit5()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
" stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n",
" res = diffeqsolve(ode_fn,\n",
" solver,\n",
" t0=0.1,\n",

View file

@ -90,7 +90,7 @@
"\n",
"This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n",
"\n",
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. `create_device_mesh(pdims)` initializes this layout across available GPUs.\n",
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n",
"- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n",
"- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n",
"\n",
@ -99,21 +99,18 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "9edd2246",
"metadata": {},
"outputs": [],
"source": [
"from jax.experimental.mesh_utils import create_device_mesh\n",
"from jax.experimental.multihost_utils import process_allgather\n",
"from jax.sharding import Mesh, NamedSharding\n",
"from jax.sharding import PartitionSpec as P\n",
"from jax.sharding import PartitionSpec as P, NamedSharding\n",
"\n",
"all_gather = partial(process_allgather, tiled=False)\n",
"all_gather = partial(process_allgather, tiled=True)\n",
"\n",
"pdims = (2, 4)\n",
"devices = create_device_mesh(pdims)\n",
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
"mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
"sharding = NamedSharding(mesh, P('x', 'y'))"
]
},
@ -180,10 +177,10 @@
"\n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" solver = LeapfrogMidpoint()\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False , sharding=sharding))\n",
" solver = Tsit5()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
" stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n",
" res = diffeqsolve(ode_fn,\n",
" solver,\n",
" t0=0.1,\n",
@ -410,10 +407,10 @@
"\n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" solver = LeapfrogMidpoint()\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False , sharding=sharding , halo_size=halo_size))\n",
" solver = Tsit5()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
" stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n",
" res = diffeqsolve(ode_fn,\n",
" solver,\n",
" t0=0.1,\n",
@ -689,7 +686,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
"version": "3.11.11"
}
},
"nbformat": 4,

View file

@ -62,7 +62,7 @@
"\n",
"This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n",
"\n",
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. `create_device_mesh(pdims)` initializes this layout across available GPUs.\n",
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n",
"- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n",
"- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n",
"\n",
@ -71,7 +71,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -80,11 +80,10 @@
"from jax.sharding import Mesh, NamedSharding\n",
"from jax.sharding import PartitionSpec as P\n",
"\n",
"all_gather = partial(process_allgather, tiled=False)\n",
"all_gather = partial(process_allgather, tiled=True)\n",
"\n",
"pdims = (2, 4)\n",
"devices = create_device_mesh(pdims)\n",
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
"mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
"sharding = NamedSharding(mesh, P('x', 'y'))"
]
},
@ -124,7 +123,7 @@
"\n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
" solver = LeapfrogMidpoint()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
@ -288,7 +287,7 @@
"\n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
" solver = Dopri5()\n",
"\n",
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",

View file

@ -17,9 +17,8 @@ import jax_cosmo as jc
import numpy as np
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
PIDController, SaveAt, diffeqsolve)
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
from jaxpm.kernels import interpolate_power_spectrum
@ -78,7 +77,7 @@ def parse_arguments():
def create_mesh_and_sharding(pdims):
devices = create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('x', 'y'))
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
return mesh, sharding
@ -106,7 +105,10 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
make_diffrax_ode(mesh_shape,
paint_absolute_pos=False,
sharding=sharding,
halo_size=halo_size))
# Choose solver
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()

View file

@ -37,3 +37,50 @@ Each notebook includes installation instructions and guidelines for configuring
- **SLURM** for job scheduling on clusters (if running multi-host setups)
> **Note**: These notebooks are tested on the **Jean Zay** supercomputer and may require configuration changes for different HPC clusters.
## Caveats
### Cloud-in-Cell (CIC) Painting (Single Device)
There are two ways to perform CIC painting in JAXPM. The first is `cic_paint`, which paints absolute particle positions onto the mesh. The second is `cic_paint_dx`, which paints relative particle positions (displacements from a uniform grid of particles) onto the mesh. The absolute version is faster at the cost of higher memory usage.
In order to use relative painting you need to:
- Set the `particles` argument of the `lpt` function from `jaxpm.pm` to `None`
- Set `paint_absolute_pos` to `False` in the `make_ode_fn` or `make_diffrax_ode` function from `jaxpm.pm` (it is `True` by default)
Otherwise, set `particles` to the starting particles of your choice and leave `paint_absolute_pos` set to `True` (the default).
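A minimal sketch of both modes, assuming `cosmo`, `initial_conditions`, and `mesh_shape` are already defined (argument names follow the tests and notebooks in this repo; note that the cosmology is now passed to the solver separately via `args=cosmo` in `diffeqsolve`):

```python
import jax.numpy as jnp
from diffrax import ODETerm
from jaxpm.distributed import uniform_particles
from jaxpm.pm import lpt, make_diffrax_ode

# Absolute painting (default): provide starting particles, keep paint_absolute_pos=True
particles = uniform_particles(mesh_shape)
dx, p, _ = lpt(cosmo, initial_conditions, particles, a=0.1)
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))  # paint_absolute_pos=True by default
y0 = jnp.stack([particles + dx, p])

# Relative painting: leave particles unset and turn off absolute positions
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1)
ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
y0 = jnp.stack([dx, p])
```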
### Cloud-in-Cell (CIC) Painting (Multi Device)
Both the `cic_paint` and `cic_paint_dx` functions are available in multi-device mode.
You need to set the `sharding` and `halo_size` arguments, which are explained in the notebook [03-MultiGPU_PM_Halo.ipynb](03-MultiGPU_PM_Halo.ipynb).
Note that `cic_paint` is not as accurate as `cic_paint_dx` in multi-device mode and is therefore not recommended.
Using relative painting in multi-device mode works just like in single-device mode.\
You need to set the `particles` argument of the `lpt` function from `jaxpm.pm` to `None` and set `paint_absolute_pos` to `False`.
### Distributed PM
To run a distributed PM simulation, follow the examples in notebooks [03](03-MultiGPU_PM_Halo.ipynb) and, for multi-host setups, [05](05-MultiHost_PM.ipynb).
In short, you need to set the `sharding` and `halo_size` arguments in `lpt`, `linear_field`, the `make_ode` functions, and `pm_forces` if you use it.
Mismatching shardings will give you errors and unexpected results.
You can also use `normal_field` and `uniform_particles` from `jaxpm.distributed` to create fields and particles with a given sharding.
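A minimal sketch of a sharded setup, mirroring the distributed tests in this commit (assuming `cosmo`, `initial_conditions`, and `mesh_shape` already exist and 8 devices are arranged as `(2, 4)`):

```python
import jax
from diffrax import ODETerm
from jax import lax
from jax.sharding import NamedSharding, PartitionSpec as P
from jaxpm.pm import lpt, make_diffrax_ode

mesh = jax.make_mesh((2, 4), axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2

# The initial conditions, and everything derived from them, carry the same sharding
initial_conditions = lax.with_sharding_constraint(initial_conditions, sharding)

# The same sharding and halo_size go into lpt and into the ODE
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1,
               halo_size=halo_size, sharding=sharding)
ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False,
                                  halo_size=halo_size, sharding=sharding))
```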
### Choosing the right pdims
`pdims` are processor dimensions.\
They are explained in more detail in the jaxDecomp paper [here](https://github.com/DifferentiableUniverseInitiative/jaxDecomp).
For 8 devices, three decompositions are possible:
- (1, 8)
- (2, 4) and (4, 2)
- (8, 1)
(1, X) should be the fastest, while (2, X) or (X, 2) is more accurate but slightly slower,\
and (X, 1) gives the least accurate results for some reason, so it is not recommended.
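For reference, a given `pdims` choice translates into a device mesh and sharding with the same pattern used in the notebooks above:

```python
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

pdims = (2, 4)  # or (1, 8), (4, 2), (8, 1)
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
# this sharding is then passed to linear_field, lpt and make_diffrax_ode
```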

View file

@ -76,7 +76,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
solver = Dopri5()
controller = PIDController(rtol=1e-8,
@ -95,6 +95,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
@ -121,8 +122,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
# Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
solver = Dopri5()
controller = PIDController(rtol=1e-9,
@ -141,6 +141,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)

View file

@ -2,8 +2,11 @@ from conftest import initialize_distributed
initialize_distributed() # ignore : E402
from functools import partial # noqa : E402
import jax # noqa : E402
import jax.numpy as jnp # noqa : E402
import jax_cosmo as jc # noqa : E402
import pytest # noqa : E402
from diffrax import SaveAt # noqa : E402
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
@ -12,19 +15,31 @@ from jax import lax # noqa : E402
from jax.experimental.multihost_utils import process_allgather # noqa : E402
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P # noqa : E402
from jaxdecomp import get_fft_output_sharding
from jaxpm.distributed import uniform_particles # noqa : E402
from jaxpm.distributed import fft3d, ifft3d
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode, pm_forces # noqa : E402
_TOLERANCE = 3.0 # 🙃🙃
_TOLERANCE = 1e-1 # 🙃🙃
pdims = [(1, 8), (8, 1), (4, 2), (2, 4)]
@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("pdims", pdims)
@pytest.mark.parametrize("absolute_painting", [True, False])
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
absolute_painting):
pdims, absolute_painting):
if absolute_painting:
pytest.skip("Absolute painting is not recommended in distributed mode")
painting_str = "absolute" if absolute_painting else "relative"
print("=" * 50)
print(f"Running with {painting_str} painting and pdims {pdims} ...")
mesh_shape, box_shape = simulation_config
# SINGLE DEVICE RUN
@ -37,12 +52,12 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
particles,
a=0.1,
order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
y0 = jnp.stack([particles + dx, p])
else:
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
ode_fn = ODETerm(make_diffrax_ode(mesh_shape,
paint_absolute_pos=False))
y0 = jnp.stack([dx, p])
solver = Dopri5()
@ -60,6 +75,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
@ -72,7 +88,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
print("Done with single device run")
# MULTI DEVICE RUN
mesh = jax.make_mesh((1, 8), ('x', 'y'))
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
@ -94,8 +110,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(cosmo,
mesh_shape,
make_diffrax_ode(mesh_shape,
halo_size=halo_size,
sharding=sharding))
@ -108,8 +123,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(cosmo,
mesh_shape,
make_diffrax_ode(mesh_shape,
paint_absolute_pos=False,
halo_size=halo_size,
sharding=sharding))
@ -130,16 +144,23 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
final_field = solutions.ys[-1, 0]
print(f"Final field sharding is {final_field.sharding}")
assert final_field.sharding.is_equivalent_to(sharding , ndim=3) \
, f"Final field sharding is not correct .. should be {sharding} it is instead {final_field.sharding}"
if absolute_painting:
multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
solutions.ys[-1, 0],
final_field,
halo_size=halo_size,
sharding=sharding)
else:
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
multi_device_final_field = cic_paint_dx(final_field,
halo_size=halo_size,
sharding=sharding)
@ -150,3 +171,230 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
print(f"MSE is {mse}")
assert mse < _TOLERANCE
@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("pdims", pdims)
def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
order, nbody_from_lpt1, nbody_from_lpt2, pdims):
mesh_shape, box_shape = simulation_config
# SINGLE DEVICE RUN
cosmo._workspace = {}
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
initial_conditions = lax.with_sharding_constraint(initial_conditions,
sharding)
print(f"sharded initial conditions {initial_conditions.sharding}")
cosmo._workspace = {}
@jax.jit
def forward_model(initial_conditions, cosmo):
dx, p, _ = lpt(cosmo,
initial_conditions,
a=0.1,
order=order,
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(mesh_shape,
paint_absolute_pos=False,
halo_size=halo_size,
sharding=sharding))
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
solver = Dopri5()
controller = PIDController(rtol=1e-8,
atol=1e-8,
pcoeff=0.4,
icoeff=1,
dcoeff=0)
saveat = SaveAt(t1=True)
solutions = diffeqsolve(ode_fn,
solver,
t0=0.1,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
halo_size=halo_size,
sharding=sharding)
return multi_device_final_field
@jax.jit
def model(initial_conditions, cosmo):
final_field = forward_model(initial_conditions, cosmo)
return MSE(final_field,
nbody_from_lpt1 if order == 1 else nbody_from_lpt2)
obs_val = model(initial_conditions, cosmo)
shifted_initial_conditions = initial_conditions + jax.random.normal(
jax.random.key(42), initial_conditions.shape) * 5
good_grads = jax.grad(model)(initial_conditions, cosmo)
off_grads = jax.grad(model)(shifted_initial_conditions, cosmo)
assert good_grads.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
assert off_grads.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
@pytest.mark.distributed
@pytest.mark.parametrize("pdims", pdims)
def test_fwd_rev_gradients(cosmo, pdims):
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
cosmo._workspace = {}
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
initial_conditions = jax.random.normal(jax.random.PRNGKey(42), mesh_shape)
initial_conditions = lax.with_sharding_constraint(initial_conditions,
sharding)
print(f"sharded initial conditions {initial_conditions.sharding}")
cosmo._workspace = {}
@partial(jax.jit, static_argnums=(2, 3, 4))
def compute_forces(initial_conditions,
cosmo,
a=0.5,
halo_size=0,
sharding=None):
paint_absolute_pos = False
particles = jnp.zeros_like(initial_conditions,
shape=(*initial_conditions.shape, 3))
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
initial_conditions = jax.lax.with_sharding_constraint(
initial_conditions, sharding)
delta_k = fft3d(initial_conditions)
out_sharding = get_fft_output_sharding(sharding)
delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
initial_force = pm_forces(particles,
delta=delta_k,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding)
return initial_force[..., 0]
forces = compute_forces(initial_conditions,
cosmo,
halo_size=halo_size,
sharding=sharding)
back_gradient = jax.jacrev(compute_forces)(initial_conditions,
cosmo,
halo_size=halo_size,
sharding=sharding)
fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions,
cosmo,
halo_size=halo_size,
sharding=sharding)
print(f"Forces sharding is {forces.sharding}")
print(f"Backward gradient sharding is {back_gradient.sharding}")
print(f"Forward gradient sharding is {fwd_gradient.sharding}")
assert forces.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
assert back_gradient[0, 0, 0, ...].sharding.is_equivalent_to(
initial_conditions.sharding, ndim=3)
assert fwd_gradient.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
@pytest.mark.distributed
@pytest.mark.parametrize("pdims", pdims)
def test_vmap(cosmo, pdims):
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
cosmo._workspace = {}
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
single_dev_initial_conditions = jax.random.normal(jax.random.PRNGKey(42),
mesh_shape)
initial_conditions = lax.with_sharding_constraint(
single_dev_initial_conditions, sharding)
single_ics = jnp.stack([
single_dev_initial_conditions, single_dev_initial_conditions,
single_dev_initial_conditions
])
sharded_ics = jnp.stack(
[initial_conditions, initial_conditions, initial_conditions])
print(f"unsharded initial conditions batch {single_ics.sharding}")
print(f"sharded initial conditions batch {sharded_ics.sharding}")
cosmo._workspace = {}
@partial(jax.jit, static_argnums=(2, 3, 4))
def compute_forces(initial_conditions,
cosmo,
a=0.5,
halo_size=0,
sharding=None):
paint_absolute_pos = False
particles = jnp.zeros_like(initial_conditions,
shape=(*initial_conditions.shape, 3))
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
initial_conditions = jax.lax.with_sharding_constraint(
initial_conditions, sharding)
delta_k = fft3d(initial_conditions)
out_sharding = get_fft_output_sharding(sharding)
delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
initial_force = pm_forces(particles,
delta=delta_k,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding)
return initial_force[..., 0]
def fn(ic):
return compute_forces(ic,
cosmo,
halo_size=halo_size,
sharding=sharding)
v_compute_forces = jax.vmap(fn)
print(f"single_ics shape {single_ics.shape}")
print(f"sharded_ics shape {sharded_ics.shape}")
single_dev_forces = v_compute_forces(single_ics)
sharded_forces = v_compute_forces(sharded_ics)
assert single_dev_forces.ndim == 4
assert sharded_forces.ndim == 4
print(f"Sharded forces {sharded_forces.sharding}")
assert sharded_forces[0].sharding.is_equivalent_to(
initial_conditions.sharding, ndim=3)
assert sharded_forces.sharding.spec[0] == None

View file

@ -39,7 +39,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
particles,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
y0 = jnp.stack([particles + dx, p])
else:
@ -48,7 +48,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
y0 = jnp.stack([dx, p])
solver = Dopri5()
@ -66,6 +66,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
adjoint=adjoint,
stepsize_controller=controller,
saveat=saveat)