This commit is contained in:
Wassim KABALAN 2025-02-28 13:03:47 +00:00 committed by GitHub
commit c369cd6e38
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 372 additions and 69 deletions

View file

@ -12,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk
from jaxpm.painting_utils import gather, scatter
def _cic_paint_impl(grid_mesh, positions, weight=None):
def _cic_paint_impl(grid_mesh, positions, weight=1.):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
displacement field: [nx, ny, nz, 3]
@ -27,12 +27,10 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
if weight is not None:
if jnp.isscalar(weight):
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
else:
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
kernel)
if jnp.isscalar(weight):
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
else:
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), kernel)
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
@ -48,7 +46,13 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
@partial(jax.jit, static_argnums=(3, 4))
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
def cic_paint(grid_mesh, positions, weight=1., halo_size=0, sharding=None):
if sharding is not None:
print("""
WARNING : absolute painting is not recommended in multi-device mode.
Please use relative painting instead.
""")
positions = positions.reshape((*grid_mesh.shape, 3))
@ -57,9 +61,11 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
weight_spec = P() if jnp.isscalar(weight) else spec
grid_mesh = autoshmap(_cic_paint_impl,
gpu_mesh=gpu_mesh,
in_specs=(spec, spec, P()),
in_specs=(spec, spec, weight_spec),
out_specs=spec)(grid_mesh, positions, weight)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,
@ -151,7 +157,10 @@ def cic_paint_2d(mesh, positions, weight):
return mesh
def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
def _cic_paint_dx_impl(displacements,
weight=1.,
halo_size=0,
chunk_size=2**24):
halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1]
@ -190,13 +199,13 @@ def cic_paint_dx(displacements,
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
weight_spec = P() if jnp.isscalar(weight) else spec
grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
halo_size=halo_size,
weight=weight,
chunk_size=chunk_size),
gpu_mesh=gpu_mesh,
in_specs=spec,
out_specs=spec)(displacements)
in_specs=(spec, weight_spec),
out_specs=spec)(displacements, weight)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,

View file

@ -172,8 +172,7 @@ def make_ode_fn(mesh_shape,
return nbody_ode
def make_diffrax_ode(cosmo,
mesh_shape,
def make_diffrax_ode(mesh_shape,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
@ -183,6 +182,7 @@ def make_diffrax_ode(cosmo,
state is a tuple (position, velocities)
"""
pos, vel = state
cosmo = args
forces = pm_forces(pos,
mesh_shape=mesh_shape,

View file

@ -32,7 +32,7 @@
"from jaxpm.painting import cic_paint , cic_paint_dx\n",
"from jaxpm.pm import linear_field, lpt, make_diffrax_ode\n",
"from jaxpm.distributed import uniform_particles\n",
"from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve"
"from diffrax import PIDController, Tsit5, ODETerm, SaveAt, diffeqsolve"
]
},
{
@ -41,10 +41,9 @@
"source": [
"### Particle Mesh Simulation with Diffrax Leapfrog Solver\n",
"\n",
"In this setup, we use the `LeapfrogMidpoint` solver from the `diffrax` library to evolve particle displacements over time in our Particle Mesh simulation. The novelty here is the use of a **Leapfrog solver** from `diffrax` for efficient, memory-saving time integration.\n",
"In this setup, we use the `Tsit5` solver from the `diffrax` library to evolve particle displacements over time in our Particle Mesh simulation.\n",
"\n",
"- **Leapfrog Integration**: This symplectic integrator is well-suited for simulations of gravitational dynamics, preserving energy over long timescales and allowing larger time steps without sacrificing accuracy.\n",
"- **Efficient Displacement Tracking**: We initialize only displacements (`dx`) rather than absolute positions, which, combined with Leapfrogs stability, enhances memory efficiency and speeds up computation.\n"
"- **Efficient Displacement Tracking**: We initialize only displacements (`dx`) rather than absolute positions, which uses a the `pmwd` cic_painting algorithm which is more memory efficient at the cost of being slightly slower\n"
]
},
{
@ -84,10 +83,10 @@
" \n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" solver = LeapfrogMidpoint()\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False))\n",
" solver = Tsit5()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
" stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n",
" res = diffeqsolve(ode_fn,\n",
" solver,\n",
" t0=0.1,\n",
@ -257,10 +256,10 @@
" \n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape))\n",
" solver = LeapfrogMidpoint()\n",
" make_diffrax_ode(mesh_shape))\n",
" solver = Tsit5()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
" stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n",
" res = diffeqsolve(ode_fn,\n",
" solver,\n",
" t0=0.1,\n",

View file

@ -90,7 +90,7 @@
"\n",
"This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n",
"\n",
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. `create_device_mesh(pdims)` initializes this layout across available GPUs.\n",
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n",
"- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n",
"- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n",
"\n",
@ -99,21 +99,18 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "9edd2246",
"metadata": {},
"outputs": [],
"source": [
"from jax.experimental.mesh_utils import create_device_mesh\n",
"from jax.experimental.multihost_utils import process_allgather\n",
"from jax.sharding import Mesh, NamedSharding\n",
"from jax.sharding import PartitionSpec as P\n",
"from jax.sharding import PartitionSpec as P, NamedSharding\n",
"\n",
"all_gather = partial(process_allgather, tiled=False)\n",
"all_gather = partial(process_allgather, tiled=True)\n",
"\n",
"pdims = (2, 4)\n",
"devices = create_device_mesh(pdims)\n",
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
"mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
"sharding = NamedSharding(mesh, P('x', 'y'))"
]
},
@ -180,10 +177,10 @@
"\n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" solver = LeapfrogMidpoint()\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False , sharding=sharding))\n",
" solver = Tsit5()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
" stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n",
" res = diffeqsolve(ode_fn,\n",
" solver,\n",
" t0=0.1,\n",
@ -410,10 +407,10 @@
"\n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" solver = LeapfrogMidpoint()\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False , sharding=sharding , halo_size=halo_size))\n",
" solver = Tsit5()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
" stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n",
" res = diffeqsolve(ode_fn,\n",
" solver,\n",
" t0=0.1,\n",
@ -689,7 +686,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
"version": "3.11.11"
}
},
"nbformat": 4,

View file

@ -62,7 +62,7 @@
"\n",
"This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n",
"\n",
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. `create_device_mesh(pdims)` initializes this layout across available GPUs.\n",
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n",
"- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n",
"- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n",
"\n",
@ -71,7 +71,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -80,11 +80,10 @@
"from jax.sharding import Mesh, NamedSharding\n",
"from jax.sharding import PartitionSpec as P\n",
"\n",
"all_gather = partial(process_allgather, tiled=False)\n",
"all_gather = partial(process_allgather, tiled=True)\n",
"\n",
"pdims = (2, 4)\n",
"devices = create_device_mesh(pdims)\n",
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
"mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
"sharding = NamedSharding(mesh, P('x', 'y'))"
]
},
@ -124,7 +123,7 @@
"\n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
" solver = LeapfrogMidpoint()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
@ -288,7 +287,7 @@
"\n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
" solver = Dopri5()\n",
"\n",
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",

View file

@ -17,9 +17,8 @@ import jax_cosmo as jc
import numpy as np
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
PIDController, SaveAt, diffeqsolve)
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
from jaxpm.kernels import interpolate_power_spectrum
@ -78,7 +77,7 @@ def parse_arguments():
def create_mesh_and_sharding(pdims):
devices = create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('x', 'y'))
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
return mesh, sharding
@ -106,7 +105,10 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
make_diffrax_ode(mesh_shape,
paint_absolute_pos=False,
sharding=sharding,
halo_size=halo_size))
# Choose solver
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()

View file

@ -37,3 +37,50 @@ Each notebook includes installation instructions and guidelines for configuring
- **SLURM** for job scheduling on clusters (if running multi-host setups)
> **Note**: These notebooks are tested on the **Jean Zay** supercomputer and may require configuration changes for different HPC clusters.
## Caveats
### Cloud-in-Cell (CIC) Painting (Single Device)
There is two ways to perform the CIC painting in JAXPM. The first one is to use the `cic_paint` which paints absolute particle positions to the mesh. The second one is to use the `cic_paint_dx` which paints relative particle positions to the mesh (using uniform particles). The absolute version is faster at the cost of more memory usage.
inorder to use relative painting you need to :
- Set the `particles` argument in `lpt` function from `jaxpm.pm` to `None`
- Set `paint_absolute_pos` to `False` in `make_ode_fn` or `make_diffrax_ode` function from `jaxpm.pm` (it is True by default)
Otherwise you set `particles` to the starting particles of your choice and leave `paint_absolute_pos` to `True` (default value).
### Cloud-in-Cell (CIC) Painting (Multi Device)
Both `cic_paint` and `cic_paint_dx` functions are available in multi-device mode.
You need to set the arguments `sharding` and `halo_size` which is explained in the notebook [03-MultiGPU_PM_Halo.ipynb](03-MultiGPU_PM_Halo.ipynb).
One thing to note that `cic_paint` is not as accurate as `cic_paint_dx` in multi-device mode and therefor is not recommended.
Using relative painting in multi-device mode is just like in single device mode.\
You need to set the `particles` argument in `lpt` function from `jaxpm.pm` to `None` and set `paint_absolute_pos` to `False`
### Distributed PM
To run a distributed PM follow the examples in notebooks [03](03-MultiGPU_PM_Halo.ipynb) and [05](05-MultiHost_PM.ipynb) for multi-host.
In short you need to set the arguments `sharding` and `halo_size` in `lpt` , `linear_field` the `make_ode` functions and `pm_forces` if you use it.
Missmatching the shardings will give you errors and unexpected results.
You can also use `normal_field` and `uniform_particles` from `jaxpm.pm.distributed` to create the fields and particles with a sharding.
### Choosing the right pdims
pdims are processor dimensions.\
Explained more in the jaxdecomp paper [here](https://github.com/DifferentiableUniverseInitiative/jaxDecomp).
For 8 devices there are three decompositions that are possible:
- (1 , 8)
- (2 , 4) , (4 , 2)
- (8 , 1)
(1 , X) should be the fastest (2 , X) or (X , 2) is more accurate but slightly slower.\
and (X , 1) is giving the least accurate results for some reason so it is not recommended.

View file

@ -76,7 +76,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
solver = Dopri5()
controller = PIDController(rtol=1e-8,
@ -95,6 +95,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
@ -121,8 +122,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
# Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
solver = Dopri5()
controller = PIDController(rtol=1e-9,
@ -141,6 +141,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)

View file

@ -2,8 +2,11 @@ from conftest import initialize_distributed
initialize_distributed() # ignore : E402
from functools import partial # noqa : E402
import jax # noqa : E402
import jax.numpy as jnp # noqa : E402
import jax_cosmo as jc # noqa : E402
import pytest # noqa : E402
from diffrax import SaveAt # noqa : E402
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
@ -12,19 +15,31 @@ from jax import lax # noqa : E402
from jax.experimental.multihost_utils import process_allgather # noqa : E402
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P # noqa : E402
from jaxdecomp import get_fft_output_sharding
from jaxpm.distributed import uniform_particles # noqa : E402
from jaxpm.distributed import fft3d, ifft3d
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode, pm_forces # noqa : E402
_TOLERANCE = 3.0 # 🙃🙃
_TOLERANCE = 1e-1 # 🙃🙃
pdims = [(1, 8), (8, 1), (4, 2), (2, 4)]
@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("pdims", pdims)
@pytest.mark.parametrize("absolute_painting", [True, False])
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
absolute_painting):
pdims, absolute_painting):
if absolute_painting:
pytest.skip("Absolute painting is not recommended in distributed mode")
painting_str = "absolute" if absolute_painting else "relative"
print("=" * 50)
print(f"Running with {painting_str} painting and pdims {pdims} ...")
mesh_shape, box_shape = simulation_config
# SINGLE DEVICE RUN
@ -37,12 +52,12 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
particles,
a=0.1,
order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
y0 = jnp.stack([particles + dx, p])
else:
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
ode_fn = ODETerm(make_diffrax_ode(mesh_shape,
paint_absolute_pos=False))
y0 = jnp.stack([dx, p])
solver = Dopri5()
@ -60,6 +75,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
@ -72,7 +88,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
print("Done with single device run")
# MULTI DEVICE RUN
mesh = jax.make_mesh((1, 8), ('x', 'y'))
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
@ -94,8 +110,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(cosmo,
mesh_shape,
make_diffrax_ode(mesh_shape,
halo_size=halo_size,
sharding=sharding))
@ -108,8 +123,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(cosmo,
mesh_shape,
make_diffrax_ode(mesh_shape,
paint_absolute_pos=False,
halo_size=halo_size,
sharding=sharding))
@ -130,16 +144,23 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
final_field = solutions.ys[-1, 0]
print(f"Final field sharding is {final_field.sharding}")
assert final_field.sharding.is_equivalent_to(sharding , ndim=3) \
, f"Final field sharding is not correct .. should be {sharding} it is instead {final_field.sharding}"
if absolute_painting:
multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
solutions.ys[-1, 0],
final_field,
halo_size=halo_size,
sharding=sharding)
else:
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
multi_device_final_field = cic_paint_dx(final_field,
halo_size=halo_size,
sharding=sharding)
@ -150,3 +171,230 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
print(f"MSE is {mse}")
assert mse < _TOLERANCE
@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("pdims", pdims)
def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
order, nbody_from_lpt1, nbody_from_lpt2, pdims):
mesh_shape, box_shape = simulation_config
# SINGLE DEVICE RUN
cosmo._workspace = {}
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
initial_conditions = lax.with_sharding_constraint(initial_conditions,
sharding)
print(f"sharded initial conditions {initial_conditions.sharding}")
cosmo._workspace = {}
@jax.jit
def forward_model(initial_conditions, cosmo):
dx, p, _ = lpt(cosmo,
initial_conditions,
a=0.1,
order=order,
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(mesh_shape,
paint_absolute_pos=False,
halo_size=halo_size,
sharding=sharding))
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
solver = Dopri5()
controller = PIDController(rtol=1e-8,
atol=1e-8,
pcoeff=0.4,
icoeff=1,
dcoeff=0)
saveat = SaveAt(t1=True)
solutions = diffeqsolve(ode_fn,
solver,
t0=0.1,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
halo_size=halo_size,
sharding=sharding)
return multi_device_final_field
@jax.jit
def model(initial_conditions, cosmo):
final_field = forward_model(initial_conditions, cosmo)
return MSE(final_field,
nbody_from_lpt1 if order == 1 else nbody_from_lpt2)
obs_val = model(initial_conditions, cosmo)
shifted_initial_conditions = initial_conditions + jax.random.normal(
jax.random.key(42), initial_conditions.shape) * 5
good_grads = jax.grad(model)(initial_conditions, cosmo)
off_grads = jax.grad(model)(shifted_initial_conditions, cosmo)
assert good_grads.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
assert off_grads.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
@pytest.mark.distributed
@pytest.mark.parametrize("pdims", pdims)
def test_fwd_rev_gradients(cosmo, pdims):
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
cosmo._workspace = {}
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
initial_conditions = jax.random.normal(jax.random.PRNGKey(42), mesh_shape)
initial_conditions = lax.with_sharding_constraint(initial_conditions,
sharding)
print(f"sharded initial conditions {initial_conditions.sharding}")
cosmo._workspace = {}
@partial(jax.jit, static_argnums=(2, 3, 4))
def compute_forces(initial_conditions,
cosmo,
a=0.5,
halo_size=0,
sharding=None):
paint_absolute_pos = False
particles = jnp.zeros_like(initial_conditions,
shape=(*initial_conditions.shape, 3))
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
initial_conditions = jax.lax.with_sharding_constraint(
initial_conditions, sharding)
delta_k = fft3d(initial_conditions)
out_sharding = get_fft_output_sharding(sharding)
delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
initial_force = pm_forces(particles,
delta=delta_k,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding)
return initial_force[..., 0]
forces = compute_forces(initial_conditions,
cosmo,
halo_size=halo_size,
sharding=sharding)
back_gradient = jax.jacrev(compute_forces)(initial_conditions,
cosmo,
halo_size=halo_size,
sharding=sharding)
fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions,
cosmo,
halo_size=halo_size,
sharding=sharding)
print(f"Forces sharding is {forces.sharding}")
print(f"Backward gradient sharding is {back_gradient.sharding}")
print(f"Forward gradient sharding is {fwd_gradient.sharding}")
assert forces.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
assert back_gradient[0, 0, 0, ...].sharding.is_equivalent_to(
initial_conditions.sharding, ndim=3)
assert fwd_gradient.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
@pytest.mark.distributed
@pytest.mark.parametrize("pdims", pdims)
def test_vmap(cosmo, pdims):
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
cosmo._workspace = {}
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
single_dev_initial_conditions = jax.random.normal(jax.random.PRNGKey(42),
mesh_shape)
initial_conditions = lax.with_sharding_constraint(
single_dev_initial_conditions, sharding)
single_ics = jnp.stack([
single_dev_initial_conditions, single_dev_initial_conditions,
single_dev_initial_conditions
])
sharded_ics = jnp.stack(
[initial_conditions, initial_conditions, initial_conditions])
print(f"unsharded initial conditions batch {single_ics.sharding}")
print(f"sharded initial conditions batch {sharded_ics.sharding}")
cosmo._workspace = {}
@partial(jax.jit, static_argnums=(2, 3, 4))
def compute_forces(initial_conditions,
cosmo,
a=0.5,
halo_size=0,
sharding=None):
paint_absolute_pos = False
particles = jnp.zeros_like(initial_conditions,
shape=(*initial_conditions.shape, 3))
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
initial_conditions = jax.lax.with_sharding_constraint(
initial_conditions, sharding)
delta_k = fft3d(initial_conditions)
out_sharding = get_fft_output_sharding(sharding)
delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
initial_force = pm_forces(particles,
delta=delta_k,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding)
return initial_force[..., 0]
def fn(ic):
return compute_forces(ic,
cosmo,
halo_size=halo_size,
sharding=sharding)
v_compute_forces = jax.vmap(fn)
print(f"single_ics shape {single_ics.shape}")
print(f"sharded_ics shape {sharded_ics.shape}")
single_dev_forces = v_compute_forces(single_ics)
sharded_forces = v_compute_forces(sharded_ics)
assert single_dev_forces.ndim == 4
assert sharded_forces.ndim == 4
print(f"Sharded forces {sharded_forces.sharding}")
assert sharded_forces[0].sharding.is_equivalent_to(
initial_conditions.sharding, ndim=3)
assert sharded_forces.sharding.spec[0] == None

View file

@ -39,7 +39,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
particles,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
y0 = jnp.stack([particles + dx, p])
else:
@ -48,7 +48,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
y0 = jnp.stack([dx, p])
solver = Dopri5()
@ -66,6 +66,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
adjoint=adjoint,
stepsize_controller=controller,
saveat=saveat)