This commit is contained in:
Wassim KABALAN 2025-02-28 09:06:19 +00:00 committed by GitHub
commit 944a147bf0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 40 additions and 45 deletions

View file

@ -172,8 +172,7 @@ def make_ode_fn(mesh_shape,
return nbody_ode return nbody_ode
def make_diffrax_ode(cosmo, def make_diffrax_ode(mesh_shape,
mesh_shape,
paint_absolute_pos=True, paint_absolute_pos=True,
halo_size=0, halo_size=0,
sharding=None): sharding=None):
@ -183,6 +182,7 @@ def make_diffrax_ode(cosmo,
state is a tuple (position, velocities) state is a tuple (position, velocities)
""" """
pos, vel = state pos, vel = state
cosmo = args
forces = pm_forces(pos, forces = pm_forces(pos,
mesh_shape=mesh_shape, mesh_shape=mesh_shape,

View file

@ -32,7 +32,7 @@
"from jaxpm.painting import cic_paint , cic_paint_dx\n", "from jaxpm.painting import cic_paint , cic_paint_dx\n",
"from jaxpm.pm import linear_field, lpt, make_diffrax_ode\n", "from jaxpm.pm import linear_field, lpt, make_diffrax_ode\n",
"from jaxpm.distributed import uniform_particles\n", "from jaxpm.distributed import uniform_particles\n",
"from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve" "from diffrax import PIDController, Tsit5, ODETerm, SaveAt, diffeqsolve"
] ]
}, },
{ {
@ -41,10 +41,9 @@
"source": [ "source": [
"### Particle Mesh Simulation with Diffrax Leapfrog Solver\n", "### Particle Mesh Simulation with Diffrax Leapfrog Solver\n",
"\n", "\n",
"In this setup, we use the `LeapfrogMidpoint` solver from the `diffrax` library to evolve particle displacements over time in our Particle Mesh simulation. The novelty here is the use of a **Leapfrog solver** from `diffrax` for efficient, memory-saving time integration.\n", "In this setup, we use the `Tsit5` solver from the `diffrax` library to evolve particle displacements over time in our Particle Mesh simulation.\n",
"\n", "\n",
"- **Leapfrog Integration**: This symplectic integrator is well-suited for simulations of gravitational dynamics, preserving energy over long timescales and allowing larger time steps without sacrificing accuracy.\n", "- **Efficient Displacement Tracking**: We initialize only displacements (`dx`) rather than absolute positions, which uses a the `pmwd` cic_painting algorithm which is more memory efficient at the cost of being slightly slower\n"
"- **Efficient Displacement Tracking**: We initialize only displacements (`dx`) rather than absolute positions, which, combined with Leapfrogs stability, enhances memory efficiency and speeds up computation.\n"
] ]
}, },
{ {
@ -84,10 +83,10 @@
" \n", " \n",
" # Evolve the simulation forward\n", " # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n", " ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", " make_diffrax_ode(mesh_shape, paint_absolute_pos=False))\n",
" solver = LeapfrogMidpoint()\n", " solver = Tsit5()\n",
"\n", "\n",
" stepsize_controller = ConstantStepSize()\n", " stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n",
" res = diffeqsolve(ode_fn,\n", " res = diffeqsolve(ode_fn,\n",
" solver,\n", " solver,\n",
" t0=0.1,\n", " t0=0.1,\n",
@ -257,10 +256,10 @@
" \n", " \n",
" # Evolve the simulation forward\n", " # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n", " ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape))\n", " make_diffrax_ode(mesh_shape))\n",
" solver = LeapfrogMidpoint()\n", " solver = Tsit5()\n",
"\n", "\n",
" stepsize_controller = ConstantStepSize()\n", " stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n",
" res = diffeqsolve(ode_fn,\n", " res = diffeqsolve(ode_fn,\n",
" solver,\n", " solver,\n",
" t0=0.1,\n", " t0=0.1,\n",

View file

@ -90,7 +90,7 @@
"\n", "\n",
"This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n", "This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n",
"\n", "\n",
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. `create_device_mesh(pdims)` initializes this layout across available GPUs.\n", "- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n",
"- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n", "- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n",
"- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n", "- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n",
"\n", "\n",
@ -99,21 +99,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": null,
"id": "9edd2246", "id": "9edd2246",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from jax.experimental.mesh_utils import create_device_mesh\n",
"from jax.experimental.multihost_utils import process_allgather\n", "from jax.experimental.multihost_utils import process_allgather\n",
"from jax.sharding import Mesh, NamedSharding\n", "from jax.sharding import PartitionSpec as P, NamedSharding\n",
"from jax.sharding import PartitionSpec as P\n",
"\n", "\n",
"all_gather = partial(process_allgather, tiled=False)\n", "all_gather = partial(process_allgather, tiled=True)\n",
"\n", "\n",
"pdims = (2, 4)\n", "pdims = (2, 4)\n",
"devices = create_device_mesh(pdims)\n", "mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
"sharding = NamedSharding(mesh, P('x', 'y'))" "sharding = NamedSharding(mesh, P('x', 'y'))"
] ]
}, },
@ -180,10 +177,10 @@
"\n", "\n",
" # Evolve the simulation forward\n", " # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n", " ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", " make_diffrax_ode(mesh_shape, paint_absolute_pos=False , sharding=sharding))\n",
" solver = LeapfrogMidpoint()\n", " solver = Tsit5()\n",
"\n", "\n",
" stepsize_controller = ConstantStepSize()\n", " stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n",
" res = diffeqsolve(ode_fn,\n", " res = diffeqsolve(ode_fn,\n",
" solver,\n", " solver,\n",
" t0=0.1,\n", " t0=0.1,\n",
@ -410,10 +407,10 @@
"\n", "\n",
" # Evolve the simulation forward\n", " # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n", " ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", " make_diffrax_ode(mesh_shape, paint_absolute_pos=False , sharding=sharding , halo_size=halo_size))\n",
" solver = LeapfrogMidpoint()\n", " solver = Tsit5()\n",
"\n", "\n",
" stepsize_controller = ConstantStepSize()\n", " stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n",
" res = diffeqsolve(ode_fn,\n", " res = diffeqsolve(ode_fn,\n",
" solver,\n", " solver,\n",
" t0=0.1,\n", " t0=0.1,\n",
@ -689,7 +686,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.4" "version": "3.11.11"
} }
}, },
"nbformat": 4, "nbformat": 4,

View file

@ -62,7 +62,7 @@
"\n", "\n",
"This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n", "This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n",
"\n", "\n",
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. `create_device_mesh(pdims)` initializes this layout across available GPUs.\n", "- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n",
"- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n", "- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n",
"- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n", "- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n",
"\n", "\n",
@ -71,7 +71,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -80,11 +80,10 @@
"from jax.sharding import Mesh, NamedSharding\n", "from jax.sharding import Mesh, NamedSharding\n",
"from jax.sharding import PartitionSpec as P\n", "from jax.sharding import PartitionSpec as P\n",
"\n", "\n",
"all_gather = partial(process_allgather, tiled=False)\n", "all_gather = partial(process_allgather, tiled=True)\n",
"\n", "\n",
"pdims = (2, 4)\n", "pdims = (2, 4)\n",
"devices = create_device_mesh(pdims)\n", "mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
"sharding = NamedSharding(mesh, P('x', 'y'))" "sharding = NamedSharding(mesh, P('x', 'y'))"
] ]
}, },
@ -124,7 +123,7 @@
"\n", "\n",
" # Evolve the simulation forward\n", " # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n", " ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", " make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
" solver = LeapfrogMidpoint()\n", " solver = LeapfrogMidpoint()\n",
"\n", "\n",
" stepsize_controller = ConstantStepSize()\n", " stepsize_controller = ConstantStepSize()\n",
@ -288,7 +287,7 @@
"\n", "\n",
" # Evolve the simulation forward\n", " # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n", " ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", " make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
" solver = Dopri5()\n", " solver = Dopri5()\n",
"\n", "\n",
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n", " stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",

View file

@ -17,9 +17,8 @@ import jax_cosmo as jc
import numpy as np import numpy as np
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm, from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
PIDController, SaveAt, diffeqsolve) PIDController, SaveAt, diffeqsolve)
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
from jaxpm.kernels import interpolate_power_spectrum from jaxpm.kernels import interpolate_power_spectrum
@ -78,7 +77,7 @@ def parse_arguments():
def create_mesh_and_sharding(pdims): def create_mesh_and_sharding(pdims):
devices = create_device_mesh(pdims) devices = create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('x', 'y')) mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y')) sharding = NamedSharding(mesh, P('x', 'y'))
return mesh, sharding return mesh, sharding
@ -106,7 +105,10 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
sharding=sharding) sharding=sharding)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) make_diffrax_ode(mesh_shape,
paint_absolute_pos=False,
sharding=sharding,
halo_size=halo_size))
# Choose solver # Choose solver
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5() solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()

View file

@ -37,12 +37,12 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
particles, particles,
a=0.1, a=0.1,
order=order) order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
y0 = jnp.stack([particles + dx, p]) y0 = jnp.stack([particles + dx, p])
else: else:
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order) dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order)
ode_fn = ODETerm( ode_fn = ODETerm(make_diffrax_ode(mesh_shape,
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) paint_absolute_pos=False))
y0 = jnp.stack([dx, p]) y0 = jnp.stack([dx, p])
solver = Dopri5() solver = Dopri5()
@ -94,8 +94,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
sharding=sharding) sharding=sharding)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode(cosmo, make_diffrax_ode(mesh_shape,
mesh_shape,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding)) sharding=sharding))
@ -108,8 +107,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode(cosmo, make_diffrax_ode(mesh_shape,
mesh_shape,
paint_absolute_pos=False, paint_absolute_pos=False,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding)) sharding=sharding))