This commit is contained in:
Wassim KABALAN 2025-02-28 13:03:47 +00:00 committed by GitHub
commit c369cd6e38
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 372 additions and 69 deletions

View file

@ -12,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk
from jaxpm.painting_utils import gather, scatter from jaxpm.painting_utils import gather, scatter
def _cic_paint_impl(grid_mesh, positions, weight=None): def _cic_paint_impl(grid_mesh, positions, weight=1.):
""" Paints positions onto mesh """ Paints positions onto mesh
mesh: [nx, ny, nz] mesh: [nx, ny, nz]
displacement field: [nx, ny, nz, 3] displacement field: [nx, ny, nz, 3]
@ -27,12 +27,10 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
neighboor_coords = floor + connection neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
if weight is not None: if jnp.isscalar(weight):
if jnp.isscalar(weight): kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel) else:
else: kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), kernel)
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
kernel)
neighboor_coords = jnp.mod( neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 8, 3]).astype('int32'), neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
@ -48,7 +46,13 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
@partial(jax.jit, static_argnums=(3, 4)) @partial(jax.jit, static_argnums=(3, 4))
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None): def cic_paint(grid_mesh, positions, weight=1., halo_size=0, sharding=None):
if sharding is not None:
print("""
WARNING : absolute painting is not recommended in multi-device mode.
Please use relative painting instead.
""")
positions = positions.reshape((*grid_mesh.shape, 3)) positions = positions.reshape((*grid_mesh.shape, 3))
@ -57,9 +61,11 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P() spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
weight_spec = P() if jnp.isscalar(weight) else spec
grid_mesh = autoshmap(_cic_paint_impl, grid_mesh = autoshmap(_cic_paint_impl,
gpu_mesh=gpu_mesh, gpu_mesh=gpu_mesh,
in_specs=(spec, spec, P()), in_specs=(spec, spec, weight_spec),
out_specs=spec)(grid_mesh, positions, weight) out_specs=spec)(grid_mesh, positions, weight)
grid_mesh = halo_exchange(grid_mesh, grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents, halo_extents=halo_extents,
@ -151,7 +157,10 @@ def cic_paint_2d(mesh, positions, weight):
return mesh return mesh
def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24): def _cic_paint_dx_impl(displacements,
weight=1.,
halo_size=0,
chunk_size=2**24):
halo_x, _ = halo_size[0] halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1] halo_y, _ = halo_size[1]
@ -190,13 +199,13 @@ def cic_paint_dx(displacements,
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P() spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
weight_spec = P() if jnp.isscalar(weight) else spec
grid_mesh = autoshmap(partial(_cic_paint_dx_impl, grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
halo_size=halo_size, halo_size=halo_size,
weight=weight,
chunk_size=chunk_size), chunk_size=chunk_size),
gpu_mesh=gpu_mesh, gpu_mesh=gpu_mesh,
in_specs=spec, in_specs=(spec, weight_spec),
out_specs=spec)(displacements) out_specs=spec)(displacements, weight)
grid_mesh = halo_exchange(grid_mesh, grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents, halo_extents=halo_extents,

View file

@ -172,8 +172,7 @@ def make_ode_fn(mesh_shape,
return nbody_ode return nbody_ode
def make_diffrax_ode(cosmo, def make_diffrax_ode(mesh_shape,
mesh_shape,
paint_absolute_pos=True, paint_absolute_pos=True,
halo_size=0, halo_size=0,
sharding=None): sharding=None):
@ -183,6 +182,7 @@ def make_diffrax_ode(cosmo,
state is a tuple (position, velocities) state is a tuple (position, velocities)
""" """
pos, vel = state pos, vel = state
cosmo = args
forces = pm_forces(pos, forces = pm_forces(pos,
mesh_shape=mesh_shape, mesh_shape=mesh_shape,

View file

@ -32,7 +32,7 @@
"from jaxpm.painting import cic_paint , cic_paint_dx\n", "from jaxpm.painting import cic_paint , cic_paint_dx\n",
"from jaxpm.pm import linear_field, lpt, make_diffrax_ode\n", "from jaxpm.pm import linear_field, lpt, make_diffrax_ode\n",
"from jaxpm.distributed import uniform_particles\n", "from jaxpm.distributed import uniform_particles\n",
"from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve" "from diffrax import PIDController, Tsit5, ODETerm, SaveAt, diffeqsolve"
] ]
}, },
{ {
@ -41,10 +41,9 @@
"source": [ "source": [
"### Particle Mesh Simulation with Diffrax Leapfrog Solver\n", "### Particle Mesh Simulation with Diffrax Leapfrog Solver\n",
"\n", "\n",
"In this setup, we use the `LeapfrogMidpoint` solver from the `diffrax` library to evolve particle displacements over time in our Particle Mesh simulation. The novelty here is the use of a **Leapfrog solver** from `diffrax` for efficient, memory-saving time integration.\n", "In this setup, we use the `Tsit5` solver from the `diffrax` library to evolve particle displacements over time in our Particle Mesh simulation.\n",
"\n", "\n",
"- **Leapfrog Integration**: This symplectic integrator is well-suited for simulations of gravitational dynamics, preserving energy over long timescales and allowing larger time steps without sacrificing accuracy.\n", "- **Efficient Displacement Tracking**: We initialize only displacements (`dx`) rather than absolute positions, which uses a the `pmwd` cic_painting algorithm which is more memory efficient at the cost of being slightly slower\n"
"- **Efficient Displacement Tracking**: We initialize only displacements (`dx`) rather than absolute positions, which, combined with Leapfrogs stability, enhances memory efficiency and speeds up computation.\n"
] ]
}, },
{ {
@ -84,10 +83,10 @@
" \n", " \n",
" # Evolve the simulation forward\n", " # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n", " ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", " make_diffrax_ode(mesh_shape, paint_absolute_pos=False))\n",
" solver = LeapfrogMidpoint()\n", " solver = Tsit5()\n",
"\n", "\n",
" stepsize_controller = ConstantStepSize()\n", " stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n",
" res = diffeqsolve(ode_fn,\n", " res = diffeqsolve(ode_fn,\n",
" solver,\n", " solver,\n",
" t0=0.1,\n", " t0=0.1,\n",
@ -257,10 +256,10 @@
" \n", " \n",
" # Evolve the simulation forward\n", " # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n", " ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape))\n", " make_diffrax_ode(mesh_shape))\n",
" solver = LeapfrogMidpoint()\n", " solver = Tsit5()\n",
"\n", "\n",
" stepsize_controller = ConstantStepSize()\n", " stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n",
" res = diffeqsolve(ode_fn,\n", " res = diffeqsolve(ode_fn,\n",
" solver,\n", " solver,\n",
" t0=0.1,\n", " t0=0.1,\n",

View file

@ -90,7 +90,7 @@
"\n", "\n",
"This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n", "This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n",
"\n", "\n",
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. `create_device_mesh(pdims)` initializes this layout across available GPUs.\n", "- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n",
"- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n", "- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n",
"- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n", "- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n",
"\n", "\n",
@ -99,21 +99,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": null,
"id": "9edd2246", "id": "9edd2246",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from jax.experimental.mesh_utils import create_device_mesh\n",
"from jax.experimental.multihost_utils import process_allgather\n", "from jax.experimental.multihost_utils import process_allgather\n",
"from jax.sharding import Mesh, NamedSharding\n", "from jax.sharding import PartitionSpec as P, NamedSharding\n",
"from jax.sharding import PartitionSpec as P\n",
"\n", "\n",
"all_gather = partial(process_allgather, tiled=False)\n", "all_gather = partial(process_allgather, tiled=True)\n",
"\n", "\n",
"pdims = (2, 4)\n", "pdims = (2, 4)\n",
"devices = create_device_mesh(pdims)\n", "mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
"sharding = NamedSharding(mesh, P('x', 'y'))" "sharding = NamedSharding(mesh, P('x', 'y'))"
] ]
}, },
@ -180,10 +177,10 @@
"\n", "\n",
" # Evolve the simulation forward\n", " # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n", " ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", " make_diffrax_ode(mesh_shape, paint_absolute_pos=False , sharding=sharding))\n",
" solver = LeapfrogMidpoint()\n", " solver = Tsit5()\n",
"\n", "\n",
" stepsize_controller = ConstantStepSize()\n", " stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n",
" res = diffeqsolve(ode_fn,\n", " res = diffeqsolve(ode_fn,\n",
" solver,\n", " solver,\n",
" t0=0.1,\n", " t0=0.1,\n",
@ -410,10 +407,10 @@
"\n", "\n",
" # Evolve the simulation forward\n", " # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n", " ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", " make_diffrax_ode(mesh_shape, paint_absolute_pos=False , sharding=sharding , halo_size=halo_size))\n",
" solver = LeapfrogMidpoint()\n", " solver = Tsit5()\n",
"\n", "\n",
" stepsize_controller = ConstantStepSize()\n", " stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n",
" res = diffeqsolve(ode_fn,\n", " res = diffeqsolve(ode_fn,\n",
" solver,\n", " solver,\n",
" t0=0.1,\n", " t0=0.1,\n",
@ -689,7 +686,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.4" "version": "3.11.11"
} }
}, },
"nbformat": 4, "nbformat": 4,

View file

@ -62,7 +62,7 @@
"\n", "\n",
"This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n", "This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n",
"\n", "\n",
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. `create_device_mesh(pdims)` initializes this layout across available GPUs.\n", "- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n",
"- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n", "- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n",
"- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n", "- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n",
"\n", "\n",
@ -71,7 +71,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -80,11 +80,10 @@
"from jax.sharding import Mesh, NamedSharding\n", "from jax.sharding import Mesh, NamedSharding\n",
"from jax.sharding import PartitionSpec as P\n", "from jax.sharding import PartitionSpec as P\n",
"\n", "\n",
"all_gather = partial(process_allgather, tiled=False)\n", "all_gather = partial(process_allgather, tiled=True)\n",
"\n", "\n",
"pdims = (2, 4)\n", "pdims = (2, 4)\n",
"devices = create_device_mesh(pdims)\n", "mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
"sharding = NamedSharding(mesh, P('x', 'y'))" "sharding = NamedSharding(mesh, P('x', 'y'))"
] ]
}, },
@ -124,7 +123,7 @@
"\n", "\n",
" # Evolve the simulation forward\n", " # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n", " ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", " make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
" solver = LeapfrogMidpoint()\n", " solver = LeapfrogMidpoint()\n",
"\n", "\n",
" stepsize_controller = ConstantStepSize()\n", " stepsize_controller = ConstantStepSize()\n",
@ -288,7 +287,7 @@
"\n", "\n",
" # Evolve the simulation forward\n", " # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n", " ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", " make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
" solver = Dopri5()\n", " solver = Dopri5()\n",
"\n", "\n",
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n", " stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",

View file

@ -17,9 +17,8 @@ import jax_cosmo as jc
import numpy as np import numpy as np
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm, from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
PIDController, SaveAt, diffeqsolve) PIDController, SaveAt, diffeqsolve)
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
from jaxpm.kernels import interpolate_power_spectrum from jaxpm.kernels import interpolate_power_spectrum
@ -78,7 +77,7 @@ def parse_arguments():
def create_mesh_and_sharding(pdims): def create_mesh_and_sharding(pdims):
devices = create_device_mesh(pdims) devices = create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('x', 'y')) mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y')) sharding = NamedSharding(mesh, P('x', 'y'))
return mesh, sharding return mesh, sharding
@ -106,7 +105,10 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
sharding=sharding) sharding=sharding)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) make_diffrax_ode(mesh_shape,
paint_absolute_pos=False,
sharding=sharding,
halo_size=halo_size))
# Choose solver # Choose solver
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5() solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()

View file

@ -37,3 +37,50 @@ Each notebook includes installation instructions and guidelines for configuring
- **SLURM** for job scheduling on clusters (if running multi-host setups) - **SLURM** for job scheduling on clusters (if running multi-host setups)
> **Note**: These notebooks are tested on the **Jean Zay** supercomputer and may require configuration changes for different HPC clusters. > **Note**: These notebooks are tested on the **Jean Zay** supercomputer and may require configuration changes for different HPC clusters.
## Caveats
### Cloud-in-Cell (CIC) Painting (Single Device)
There is two ways to perform the CIC painting in JAXPM. The first one is to use the `cic_paint` which paints absolute particle positions to the mesh. The second one is to use the `cic_paint_dx` which paints relative particle positions to the mesh (using uniform particles). The absolute version is faster at the cost of more memory usage.
inorder to use relative painting you need to :
- Set the `particles` argument in `lpt` function from `jaxpm.pm` to `None`
- Set `paint_absolute_pos` to `False` in `make_ode_fn` or `make_diffrax_ode` function from `jaxpm.pm` (it is True by default)
Otherwise you set `particles` to the starting particles of your choice and leave `paint_absolute_pos` to `True` (default value).
### Cloud-in-Cell (CIC) Painting (Multi Device)
Both `cic_paint` and `cic_paint_dx` functions are available in multi-device mode.
You need to set the arguments `sharding` and `halo_size` which is explained in the notebook [03-MultiGPU_PM_Halo.ipynb](03-MultiGPU_PM_Halo.ipynb).
One thing to note that `cic_paint` is not as accurate as `cic_paint_dx` in multi-device mode and therefor is not recommended.
Using relative painting in multi-device mode is just like in single device mode.\
You need to set the `particles` argument in `lpt` function from `jaxpm.pm` to `None` and set `paint_absolute_pos` to `False`
### Distributed PM
To run a distributed PM follow the examples in notebooks [03](03-MultiGPU_PM_Halo.ipynb) and [05](05-MultiHost_PM.ipynb) for multi-host.
In short you need to set the arguments `sharding` and `halo_size` in `lpt` , `linear_field` the `make_ode` functions and `pm_forces` if you use it.
Missmatching the shardings will give you errors and unexpected results.
You can also use `normal_field` and `uniform_particles` from `jaxpm.pm.distributed` to create the fields and particles with a sharding.
### Choosing the right pdims
pdims are processor dimensions.\
Explained more in the jaxdecomp paper [here](https://github.com/DifferentiableUniverseInitiative/jaxDecomp).
For 8 devices there are three decompositions that are possible:
- (1 , 8)
- (2 , 4) , (4 , 2)
- (8 , 1)
(1 , X) should be the fastest (2 , X) or (X , 2) is more accurate but slightly slower.\
and (X , 1) is giving the least accurate results for some reason so it is not recommended.

View file

@ -76,7 +76,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
a=lpt_scale_factor, a=lpt_scale_factor,
order=order) order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
solver = Dopri5() solver = Dopri5()
controller = PIDController(rtol=1e-8, controller = PIDController(rtol=1e-8,
@ -95,6 +95,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
t1=1.0, t1=1.0,
dt0=None, dt0=None,
y0=y0, y0=y0,
args=cosmo,
stepsize_controller=controller, stepsize_controller=controller,
saveat=saveat) saveat=saveat)
@ -121,8 +122,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
# Initial displacement # Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order) dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
ode_fn = ODETerm( ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
solver = Dopri5() solver = Dopri5()
controller = PIDController(rtol=1e-9, controller = PIDController(rtol=1e-9,
@ -141,6 +141,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
t1=1.0, t1=1.0,
dt0=None, dt0=None,
y0=y0, y0=y0,
args=cosmo,
stepsize_controller=controller, stepsize_controller=controller,
saveat=saveat) saveat=saveat)

View file

@ -2,8 +2,11 @@ from conftest import initialize_distributed
initialize_distributed() # ignore : E402 initialize_distributed() # ignore : E402
from functools import partial # noqa : E402
import jax # noqa : E402 import jax # noqa : E402
import jax.numpy as jnp # noqa : E402 import jax.numpy as jnp # noqa : E402
import jax_cosmo as jc # noqa : E402
import pytest # noqa : E402 import pytest # noqa : E402
from diffrax import SaveAt # noqa : E402 from diffrax import SaveAt # noqa : E402
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
@ -12,19 +15,31 @@ from jax import lax # noqa : E402
from jax.experimental.multihost_utils import process_allgather # noqa : E402 from jax.experimental.multihost_utils import process_allgather # noqa : E402
from jax.sharding import NamedSharding from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P # noqa : E402 from jax.sharding import PartitionSpec as P # noqa : E402
from jaxdecomp import get_fft_output_sharding
from jaxpm.distributed import uniform_particles # noqa : E402 from jaxpm.distributed import uniform_particles # noqa : E402
from jaxpm.distributed import fft3d, ifft3d
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402 from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402 from jaxpm.pm import lpt, make_diffrax_ode, pm_forces # noqa : E402
_TOLERANCE = 3.0 # 🙃🙃 _TOLERANCE = 1e-1 # 🙃🙃
pdims = [(1, 8), (8, 1), (4, 2), (2, 4)]
@pytest.mark.distributed @pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("pdims", pdims)
@pytest.mark.parametrize("absolute_painting", [True, False]) @pytest.mark.parametrize("absolute_painting", [True, False])
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
absolute_painting): pdims, absolute_painting):
if absolute_painting:
pytest.skip("Absolute painting is not recommended in distributed mode")
painting_str = "absolute" if absolute_painting else "relative"
print("=" * 50)
print(f"Running with {painting_str} painting and pdims {pdims} ...")
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
# SINGLE DEVICE RUN # SINGLE DEVICE RUN
@ -37,12 +52,12 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
particles, particles,
a=0.1, a=0.1,
order=order) order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
y0 = jnp.stack([particles + dx, p]) y0 = jnp.stack([particles + dx, p])
else: else:
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order) dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order)
ode_fn = ODETerm( ode_fn = ODETerm(make_diffrax_ode(mesh_shape,
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) paint_absolute_pos=False))
y0 = jnp.stack([dx, p]) y0 = jnp.stack([dx, p])
solver = Dopri5() solver = Dopri5()
@ -60,6 +75,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
t1=1.0, t1=1.0,
dt0=None, dt0=None,
y0=y0, y0=y0,
args=cosmo,
stepsize_controller=controller, stepsize_controller=controller,
saveat=saveat) saveat=saveat)
@ -72,7 +88,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
print("Done with single device run") print("Done with single device run")
# MULTI DEVICE RUN # MULTI DEVICE RUN
mesh = jax.make_mesh((1, 8), ('x', 'y')) mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y')) sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2 halo_size = mesh_shape[0] // 2
@ -94,8 +110,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
sharding=sharding) sharding=sharding)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode(cosmo, make_diffrax_ode(mesh_shape,
mesh_shape,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding)) sharding=sharding))
@ -108,8 +123,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode(cosmo, make_diffrax_ode(mesh_shape,
mesh_shape,
paint_absolute_pos=False, paint_absolute_pos=False,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding)) sharding=sharding))
@ -130,16 +144,23 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
t1=1.0, t1=1.0,
dt0=None, dt0=None,
y0=y0, y0=y0,
args=cosmo,
stepsize_controller=controller, stepsize_controller=controller,
saveat=saveat) saveat=saveat)
final_field = solutions.ys[-1, 0]
print(f"Final field sharding is {final_field.sharding}")
assert final_field.sharding.is_equivalent_to(sharding , ndim=3) \
, f"Final field sharding is not correct .. should be {sharding} it is instead {final_field.sharding}"
if absolute_painting: if absolute_painting:
multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape), multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
solutions.ys[-1, 0], final_field,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
else: else:
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0], multi_device_final_field = cic_paint_dx(final_field,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
@ -150,3 +171,230 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
print(f"MSE is {mse}") print(f"MSE is {mse}")
assert mse < _TOLERANCE assert mse < _TOLERANCE
@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("pdims", pdims)
def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
order, nbody_from_lpt1, nbody_from_lpt2, pdims):
mesh_shape, box_shape = simulation_config
# SINGLE DEVICE RUN
cosmo._workspace = {}
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
initial_conditions = lax.with_sharding_constraint(initial_conditions,
sharding)
print(f"sharded initial conditions {initial_conditions.sharding}")
cosmo._workspace = {}
@jax.jit
def forward_model(initial_conditions, cosmo):
dx, p, _ = lpt(cosmo,
initial_conditions,
a=0.1,
order=order,
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(mesh_shape,
paint_absolute_pos=False,
halo_size=halo_size,
sharding=sharding))
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
solver = Dopri5()
controller = PIDController(rtol=1e-8,
atol=1e-8,
pcoeff=0.4,
icoeff=1,
dcoeff=0)
saveat = SaveAt(t1=True)
solutions = diffeqsolve(ode_fn,
solver,
t0=0.1,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
halo_size=halo_size,
sharding=sharding)
return multi_device_final_field
@jax.jit
def model(initial_conditions, cosmo):
final_field = forward_model(initial_conditions, cosmo)
return MSE(final_field,
nbody_from_lpt1 if order == 1 else nbody_from_lpt2)
obs_val = model(initial_conditions, cosmo)
shifted_initial_conditions = initial_conditions + jax.random.normal(
jax.random.key(42), initial_conditions.shape) * 5
good_grads = jax.grad(model)(initial_conditions, cosmo)
off_grads = jax.grad(model)(shifted_initial_conditions, cosmo)
assert good_grads.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
assert off_grads.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
@pytest.mark.distributed
@pytest.mark.parametrize("pdims", pdims)
def test_fwd_rev_gradients(cosmo, pdims):
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
cosmo._workspace = {}
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
initial_conditions = jax.random.normal(jax.random.PRNGKey(42), mesh_shape)
initial_conditions = lax.with_sharding_constraint(initial_conditions,
sharding)
print(f"sharded initial conditions {initial_conditions.sharding}")
cosmo._workspace = {}
@partial(jax.jit, static_argnums=(2, 3, 4))
def compute_forces(initial_conditions,
cosmo,
a=0.5,
halo_size=0,
sharding=None):
paint_absolute_pos = False
particles = jnp.zeros_like(initial_conditions,
shape=(*initial_conditions.shape, 3))
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
initial_conditions = jax.lax.with_sharding_constraint(
initial_conditions, sharding)
delta_k = fft3d(initial_conditions)
out_sharding = get_fft_output_sharding(sharding)
delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
initial_force = pm_forces(particles,
delta=delta_k,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding)
return initial_force[..., 0]
forces = compute_forces(initial_conditions,
cosmo,
halo_size=halo_size,
sharding=sharding)
back_gradient = jax.jacrev(compute_forces)(initial_conditions,
cosmo,
halo_size=halo_size,
sharding=sharding)
fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions,
cosmo,
halo_size=halo_size,
sharding=sharding)
print(f"Forces sharding is {forces.sharding}")
print(f"Backward gradient sharding is {back_gradient.sharding}")
print(f"Forward gradient sharding is {fwd_gradient.sharding}")
assert forces.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
assert back_gradient[0, 0, 0, ...].sharding.is_equivalent_to(
initial_conditions.sharding, ndim=3)
assert fwd_gradient.sharding.is_equivalent_to(initial_conditions.sharding,
ndim=3)
@pytest.mark.distributed
@pytest.mark.parametrize("pdims", pdims)
def test_vmap(cosmo, pdims):
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
cosmo._workspace = {}
mesh = jax.make_mesh(pdims, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
single_dev_initial_conditions = jax.random.normal(jax.random.PRNGKey(42),
mesh_shape)
initial_conditions = lax.with_sharding_constraint(
single_dev_initial_conditions, sharding)
single_ics = jnp.stack([
single_dev_initial_conditions, single_dev_initial_conditions,
single_dev_initial_conditions
])
sharded_ics = jnp.stack(
[initial_conditions, initial_conditions, initial_conditions])
print(f"unsharded initial conditions batch {single_ics.sharding}")
print(f"sharded initial conditions batch {sharded_ics.sharding}")
cosmo._workspace = {}
@partial(jax.jit, static_argnums=(2, 3, 4))
def compute_forces(initial_conditions,
cosmo,
a=0.5,
halo_size=0,
sharding=None):
paint_absolute_pos = False
particles = jnp.zeros_like(initial_conditions,
shape=(*initial_conditions.shape, 3))
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
initial_conditions = jax.lax.with_sharding_constraint(
initial_conditions, sharding)
delta_k = fft3d(initial_conditions)
out_sharding = get_fft_output_sharding(sharding)
delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
initial_force = pm_forces(particles,
delta=delta_k,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding)
return initial_force[..., 0]
def fn(ic):
return compute_forces(ic,
cosmo,
halo_size=halo_size,
sharding=sharding)
v_compute_forces = jax.vmap(fn)
print(f"single_ics shape {single_ics.shape}")
print(f"sharded_ics shape {sharded_ics.shape}")
single_dev_forces = v_compute_forces(single_ics)
sharded_forces = v_compute_forces(sharded_ics)
assert single_dev_forces.ndim == 4
assert sharded_forces.ndim == 4
print(f"Sharded forces {sharded_forces.sharding}")
assert sharded_forces[0].sharding.is_equivalent_to(
initial_conditions.sharding, ndim=3)
assert sharded_forces.sharding.spec[0] == None

View file

@ -39,7 +39,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
particles, particles,
a=lpt_scale_factor, a=lpt_scale_factor,
order=order) order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
y0 = jnp.stack([particles + dx, p]) y0 = jnp.stack([particles + dx, p])
else: else:
@ -48,7 +48,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
a=lpt_scale_factor, a=lpt_scale_factor,
order=order) order=order)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
y0 = jnp.stack([dx, p]) y0 = jnp.stack([dx, p])
solver = Dopri5() solver = Dopri5()
@ -66,6 +66,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
t1=1.0, t1=1.0,
dt0=None, dt0=None,
y0=y0, y0=y0,
args=cosmo,
adjoint=adjoint, adjoint=adjoint,
stepsize_controller=controller, stepsize_controller=controller,
saveat=saveat) saveat=saveat)