diff --git a/jaxpm/pm.py b/jaxpm/pm.py index e34d584..9951e1c 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -172,8 +172,7 @@ def make_ode_fn(mesh_shape, return nbody_ode -def make_diffrax_ode(cosmo, - mesh_shape, +def make_diffrax_ode(mesh_shape, paint_absolute_pos=True, halo_size=0, sharding=None): @@ -183,6 +182,7 @@ def make_diffrax_ode(cosmo, state is a tuple (position, velocities) """ pos, vel = state + cosmo = args forces = pm_forces(pos, mesh_shape=mesh_shape, diff --git a/notebooks/02-Advanced_usage.ipynb b/notebooks/02-Advanced_usage.ipynb index cf7f611..9027ef2 100644 --- a/notebooks/02-Advanced_usage.ipynb +++ b/notebooks/02-Advanced_usage.ipynb @@ -32,7 +32,7 @@ "from jaxpm.painting import cic_paint , cic_paint_dx\n", "from jaxpm.pm import linear_field, lpt, make_diffrax_ode\n", "from jaxpm.distributed import uniform_particles\n", - "from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve" + "from diffrax import PIDController, Tsit5, ODETerm, SaveAt, diffeqsolve" ] }, { @@ -41,10 +41,9 @@ "source": [ "### Particle Mesh Simulation with Diffrax Leapfrog Solver\n", "\n", - "In this setup, we use the `LeapfrogMidpoint` solver from the `diffrax` library to evolve particle displacements over time in our Particle Mesh simulation. The novelty here is the use of a **Leapfrog solver** from `diffrax` for efficient, memory-saving time integration.\n", + "In this setup, we use the `Tsit5` solver from the `diffrax` library to evolve particle displacements over time in our Particle Mesh simulation.\n", "\n", - "- **Leapfrog Integration**: This symplectic integrator is well-suited for simulations of gravitational dynamics, preserving energy over long timescales and allowing larger time steps without sacrificing accuracy.\n", - "- **Efficient Displacement Tracking**: We initialize only displacements (`dx`) rather than absolute positions, which, combined with Leapfrog’s stability, enhances memory efficiency and speeds up computation.\n" + "- **Efficient Displacement Tracking**: We initialize only displacements (`dx`) rather than absolute positions, which uses a the `pmwd` cic_painting algorithm which is more memory efficient at the cost of being slightly slower\n" ] }, { @@ -84,10 +83,10 @@ " \n", " # Evolve the simulation forward\n", " ode_fn = ODETerm(\n", - " make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", - " solver = LeapfrogMidpoint()\n", + " make_diffrax_ode(mesh_shape, paint_absolute_pos=False))\n", + " solver = Tsit5()\n", "\n", - " stepsize_controller = ConstantStepSize()\n", + " stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n", " res = diffeqsolve(ode_fn,\n", " solver,\n", " t0=0.1,\n", @@ -257,10 +256,10 @@ " \n", " # Evolve the simulation forward\n", " ode_fn = ODETerm(\n", - " make_diffrax_ode(cosmo, mesh_shape))\n", - " solver = LeapfrogMidpoint()\n", + " make_diffrax_ode(mesh_shape))\n", + " solver = Tsit5()\n", "\n", - " stepsize_controller = ConstantStepSize()\n", + " stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n", " res = diffeqsolve(ode_fn,\n", " solver,\n", " t0=0.1,\n", diff --git a/notebooks/03-MultiGPU_PM_Halo.ipynb b/notebooks/03-MultiGPU_PM_Halo.ipynb index 0a652d2..be0bfb8 100644 --- a/notebooks/03-MultiGPU_PM_Halo.ipynb +++ b/notebooks/03-MultiGPU_PM_Halo.ipynb @@ -90,7 +90,7 @@ "\n", "This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n", "\n", - "- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. `create_device_mesh(pdims)` initializes this layout across available GPUs.\n", + "- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n", "- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n", "- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n", "\n", @@ -99,21 +99,18 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "9edd2246", "metadata": {}, "outputs": [], "source": [ - "from jax.experimental.mesh_utils import create_device_mesh\n", "from jax.experimental.multihost_utils import process_allgather\n", - "from jax.sharding import Mesh, NamedSharding\n", - "from jax.sharding import PartitionSpec as P\n", + "from jax.sharding import PartitionSpec as P, NamedSharding\n", "\n", - "all_gather = partial(process_allgather, tiled=False)\n", + "all_gather = partial(process_allgather, tiled=True)\n", "\n", "pdims = (2, 4)\n", - "devices = create_device_mesh(pdims)\n", - "mesh = Mesh(devices, axis_names=('x', 'y'))\n", + "mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n", "sharding = NamedSharding(mesh, P('x', 'y'))" ] }, @@ -180,10 +177,10 @@ "\n", " # Evolve the simulation forward\n", " ode_fn = ODETerm(\n", - " make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", - " solver = LeapfrogMidpoint()\n", + " make_diffrax_ode(mesh_shape, paint_absolute_pos=False , sharding=sharding))\n", + " solver = Tsit5()\n", "\n", - " stepsize_controller = ConstantStepSize()\n", + " stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n", " res = diffeqsolve(ode_fn,\n", " solver,\n", " t0=0.1,\n", @@ -410,10 +407,10 @@ "\n", " # Evolve the simulation forward\n", " ode_fn = ODETerm(\n", - " make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", - " solver = LeapfrogMidpoint()\n", + " make_diffrax_ode(mesh_shape, paint_absolute_pos=False , sharding=sharding , halo_size=halo_size))\n", + " solver = Tsit5()\n", "\n", - " stepsize_controller = ConstantStepSize()\n", + " stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n", " res = diffeqsolve(ode_fn,\n", " solver,\n", " t0=0.1,\n", @@ -689,7 +686,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.4" + "version": "3.11.11" } }, "nbformat": 4, diff --git a/notebooks/04-MultiGPU_PM_Solvers.ipynb b/notebooks/04-MultiGPU_PM_Solvers.ipynb index 7671bc7..1c22f33 100644 --- a/notebooks/04-MultiGPU_PM_Solvers.ipynb +++ b/notebooks/04-MultiGPU_PM_Solvers.ipynb @@ -62,7 +62,7 @@ "\n", "This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n", "\n", - "- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. `create_device_mesh(pdims)` initializes this layout across available GPUs.\n", + "- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n", "- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n", "- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n", "\n", @@ -71,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -80,11 +80,10 @@ "from jax.sharding import Mesh, NamedSharding\n", "from jax.sharding import PartitionSpec as P\n", "\n", - "all_gather = partial(process_allgather, tiled=False)\n", + "all_gather = partial(process_allgather, tiled=True)\n", "\n", "pdims = (2, 4)\n", - "devices = create_device_mesh(pdims)\n", - "mesh = Mesh(devices, axis_names=('x', 'y'))\n", + "mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n", "sharding = NamedSharding(mesh, P('x', 'y'))" ] }, @@ -124,7 +123,7 @@ "\n", " # Evolve the simulation forward\n", " ode_fn = ODETerm(\n", - " make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", + " make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n", " solver = LeapfrogMidpoint()\n", "\n", " stepsize_controller = ConstantStepSize()\n", @@ -288,7 +287,7 @@ "\n", " # Evolve the simulation forward\n", " ode_fn = ODETerm(\n", - " make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", + " make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n", " solver = Dopri5()\n", "\n", " stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n", diff --git a/notebooks/05-MultiHost_PM.py b/notebooks/05-MultiHost_PM.py index da3964e..c41d1cf 100644 --- a/notebooks/05-MultiHost_PM.py +++ b/notebooks/05-MultiHost_PM.py @@ -17,9 +17,8 @@ import jax_cosmo as jc import numpy as np from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm, PIDController, SaveAt, diffeqsolve) -from jax.experimental.mesh_utils import create_device_mesh from jax.experimental.multihost_utils import process_allgather -from jax.sharding import Mesh, NamedSharding +from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P from jaxpm.kernels import interpolate_power_spectrum @@ -78,7 +77,7 @@ def parse_arguments(): def create_mesh_and_sharding(pdims): devices = create_device_mesh(pdims) - mesh = Mesh(devices, axis_names=('x', 'y')) + mesh = jax.make_mesh(pdims, axis_names=('x', 'y')) sharding = NamedSharding(mesh, P('x', 'y')) return mesh, sharding @@ -106,7 +105,10 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size, sharding=sharding) ode_fn = ODETerm( - make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) + make_diffrax_ode(mesh_shape, + paint_absolute_pos=False, + sharding=sharding, + halo_size=halo_size)) # Choose solver solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5() diff --git a/tests/test_distributed_pm.py b/tests/test_distributed_pm.py index fd683ab..eb44456 100644 --- a/tests/test_distributed_pm.py +++ b/tests/test_distributed_pm.py @@ -37,12 +37,12 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, particles, a=0.1, order=order) - ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) + ode_fn = ODETerm(make_diffrax_ode(mesh_shape)) y0 = jnp.stack([particles + dx, p]) else: dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order) - ode_fn = ODETerm( - make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) + ode_fn = ODETerm(make_diffrax_ode(mesh_shape, + paint_absolute_pos=False)) y0 = jnp.stack([dx, p]) solver = Dopri5() @@ -94,8 +94,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, sharding=sharding) ode_fn = ODETerm( - make_diffrax_ode(cosmo, - mesh_shape, + make_diffrax_ode(mesh_shape, halo_size=halo_size, sharding=sharding)) @@ -108,8 +107,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, halo_size=halo_size, sharding=sharding) ode_fn = ODETerm( - make_diffrax_ode(cosmo, - mesh_shape, + make_diffrax_ode(mesh_shape, paint_absolute_pos=False, halo_size=halo_size, sharding=sharding))