mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-04 11:10:53 +00:00
Merge f8325b1c67
into cb2a7ab17f
This commit is contained in:
commit
944a147bf0
6 changed files with 40 additions and 45 deletions
|
@ -172,8 +172,7 @@ def make_ode_fn(mesh_shape,
|
|||
return nbody_ode
|
||||
|
||||
|
||||
def make_diffrax_ode(cosmo,
|
||||
mesh_shape,
|
||||
def make_diffrax_ode(mesh_shape,
|
||||
paint_absolute_pos=True,
|
||||
halo_size=0,
|
||||
sharding=None):
|
||||
|
@ -183,6 +182,7 @@ def make_diffrax_ode(cosmo,
|
|||
state is a tuple (position, velocities)
|
||||
"""
|
||||
pos, vel = state
|
||||
cosmo = args
|
||||
|
||||
forces = pm_forces(pos,
|
||||
mesh_shape=mesh_shape,
|
||||
|
|
|
@ -32,7 +32,7 @@
|
|||
"from jaxpm.painting import cic_paint , cic_paint_dx\n",
|
||||
"from jaxpm.pm import linear_field, lpt, make_diffrax_ode\n",
|
||||
"from jaxpm.distributed import uniform_particles\n",
|
||||
"from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve"
|
||||
"from diffrax import PIDController, Tsit5, ODETerm, SaveAt, diffeqsolve"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -41,10 +41,9 @@
|
|||
"source": [
|
||||
"### Particle Mesh Simulation with Diffrax Leapfrog Solver\n",
|
||||
"\n",
|
||||
"In this setup, we use the `LeapfrogMidpoint` solver from the `diffrax` library to evolve particle displacements over time in our Particle Mesh simulation. The novelty here is the use of a **Leapfrog solver** from `diffrax` for efficient, memory-saving time integration.\n",
|
||||
"In this setup, we use the `Tsit5` solver from the `diffrax` library to evolve particle displacements over time in our Particle Mesh simulation.\n",
|
||||
"\n",
|
||||
"- **Leapfrog Integration**: This symplectic integrator is well-suited for simulations of gravitational dynamics, preserving energy over long timescales and allowing larger time steps without sacrificing accuracy.\n",
|
||||
"- **Efficient Displacement Tracking**: We initialize only displacements (`dx`) rather than absolute positions, which, combined with Leapfrog’s stability, enhances memory efficiency and speeds up computation.\n"
|
||||
"- **Efficient Displacement Tracking**: We initialize only displacements (`dx`) rather than absolute positions, which uses a the `pmwd` cic_painting algorithm which is more memory efficient at the cost of being slightly slower\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -84,10 +83,10 @@
|
|||
" \n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" ode_fn = ODETerm(\n",
|
||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
||||
" solver = LeapfrogMidpoint()\n",
|
||||
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False))\n",
|
||||
" solver = Tsit5()\n",
|
||||
"\n",
|
||||
" stepsize_controller = ConstantStepSize()\n",
|
||||
" stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n",
|
||||
" res = diffeqsolve(ode_fn,\n",
|
||||
" solver,\n",
|
||||
" t0=0.1,\n",
|
||||
|
@ -257,10 +256,10 @@
|
|||
" \n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" ode_fn = ODETerm(\n",
|
||||
" make_diffrax_ode(cosmo, mesh_shape))\n",
|
||||
" solver = LeapfrogMidpoint()\n",
|
||||
" make_diffrax_ode(mesh_shape))\n",
|
||||
" solver = Tsit5()\n",
|
||||
"\n",
|
||||
" stepsize_controller = ConstantStepSize()\n",
|
||||
" stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n",
|
||||
" res = diffeqsolve(ode_fn,\n",
|
||||
" solver,\n",
|
||||
" t0=0.1,\n",
|
||||
|
|
|
@ -90,7 +90,7 @@
|
|||
"\n",
|
||||
"This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n",
|
||||
"\n",
|
||||
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. `create_device_mesh(pdims)` initializes this layout across available GPUs.\n",
|
||||
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n",
|
||||
"- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n",
|
||||
"- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n",
|
||||
"\n",
|
||||
|
@ -99,21 +99,18 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"id": "9edd2246",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from jax.experimental.mesh_utils import create_device_mesh\n",
|
||||
"from jax.experimental.multihost_utils import process_allgather\n",
|
||||
"from jax.sharding import Mesh, NamedSharding\n",
|
||||
"from jax.sharding import PartitionSpec as P\n",
|
||||
"from jax.sharding import PartitionSpec as P, NamedSharding\n",
|
||||
"\n",
|
||||
"all_gather = partial(process_allgather, tiled=False)\n",
|
||||
"all_gather = partial(process_allgather, tiled=True)\n",
|
||||
"\n",
|
||||
"pdims = (2, 4)\n",
|
||||
"devices = create_device_mesh(pdims)\n",
|
||||
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
|
||||
"mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
|
||||
"sharding = NamedSharding(mesh, P('x', 'y'))"
|
||||
]
|
||||
},
|
||||
|
@ -180,10 +177,10 @@
|
|||
"\n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" ode_fn = ODETerm(\n",
|
||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
||||
" solver = LeapfrogMidpoint()\n",
|
||||
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False , sharding=sharding))\n",
|
||||
" solver = Tsit5()\n",
|
||||
"\n",
|
||||
" stepsize_controller = ConstantStepSize()\n",
|
||||
" stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n",
|
||||
" res = diffeqsolve(ode_fn,\n",
|
||||
" solver,\n",
|
||||
" t0=0.1,\n",
|
||||
|
@ -410,10 +407,10 @@
|
|||
"\n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" ode_fn = ODETerm(\n",
|
||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
||||
" solver = LeapfrogMidpoint()\n",
|
||||
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False , sharding=sharding , halo_size=halo_size))\n",
|
||||
" solver = Tsit5()\n",
|
||||
"\n",
|
||||
" stepsize_controller = ConstantStepSize()\n",
|
||||
" stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n",
|
||||
" res = diffeqsolve(ode_fn,\n",
|
||||
" solver,\n",
|
||||
" t0=0.1,\n",
|
||||
|
@ -689,7 +686,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.4"
|
||||
"version": "3.11.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
@ -62,7 +62,7 @@
|
|||
"\n",
|
||||
"This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n",
|
||||
"\n",
|
||||
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. `create_device_mesh(pdims)` initializes this layout across available GPUs.\n",
|
||||
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n",
|
||||
"- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n",
|
||||
"- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n",
|
||||
"\n",
|
||||
|
@ -71,7 +71,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -80,11 +80,10 @@
|
|||
"from jax.sharding import Mesh, NamedSharding\n",
|
||||
"from jax.sharding import PartitionSpec as P\n",
|
||||
"\n",
|
||||
"all_gather = partial(process_allgather, tiled=False)\n",
|
||||
"all_gather = partial(process_allgather, tiled=True)\n",
|
||||
"\n",
|
||||
"pdims = (2, 4)\n",
|
||||
"devices = create_device_mesh(pdims)\n",
|
||||
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
|
||||
"mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
|
||||
"sharding = NamedSharding(mesh, P('x', 'y'))"
|
||||
]
|
||||
},
|
||||
|
@ -124,7 +123,7 @@
|
|||
"\n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" ode_fn = ODETerm(\n",
|
||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
||||
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
|
||||
" solver = LeapfrogMidpoint()\n",
|
||||
"\n",
|
||||
" stepsize_controller = ConstantStepSize()\n",
|
||||
|
@ -288,7 +287,7 @@
|
|||
"\n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" ode_fn = ODETerm(\n",
|
||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
||||
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
|
||||
" solver = Dopri5()\n",
|
||||
"\n",
|
||||
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",
|
||||
|
|
|
@ -17,9 +17,8 @@ import jax_cosmo as jc
|
|||
import numpy as np
|
||||
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
||||
PIDController, SaveAt, diffeqsolve)
|
||||
from jax.experimental.mesh_utils import create_device_mesh
|
||||
from jax.experimental.multihost_utils import process_allgather
|
||||
from jax.sharding import Mesh, NamedSharding
|
||||
from jax.sharding import NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jaxpm.kernels import interpolate_power_spectrum
|
||||
|
@ -78,7 +77,7 @@ def parse_arguments():
|
|||
|
||||
def create_mesh_and_sharding(pdims):
|
||||
devices = create_device_mesh(pdims)
|
||||
mesh = Mesh(devices, axis_names=('x', 'y'))
|
||||
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
return mesh, sharding
|
||||
|
||||
|
@ -106,7 +105,10 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
|||
sharding=sharding)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
make_diffrax_ode(mesh_shape,
|
||||
paint_absolute_pos=False,
|
||||
sharding=sharding,
|
||||
halo_size=halo_size))
|
||||
|
||||
# Choose solver
|
||||
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
|
||||
|
|
|
@ -37,12 +37,12 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
|||
particles,
|
||||
a=0.1,
|
||||
order=order)
|
||||
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
|
||||
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
|
||||
y0 = jnp.stack([particles + dx, p])
|
||||
else:
|
||||
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
ode_fn = ODETerm(make_diffrax_ode(mesh_shape,
|
||||
paint_absolute_pos=False))
|
||||
y0 = jnp.stack([dx, p])
|
||||
|
||||
solver = Dopri5()
|
||||
|
@ -94,8 +94,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
|||
sharding=sharding)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo,
|
||||
mesh_shape,
|
||||
make_diffrax_ode(mesh_shape,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
|
||||
|
@ -108,8 +107,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
|||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo,
|
||||
mesh_shape,
|
||||
make_diffrax_ode(mesh_shape,
|
||||
paint_absolute_pos=False,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
|
|
Loading…
Add table
Reference in a new issue