update examples

This commit is contained in:
Wassim Kabalan 2025-02-28 09:54:53 +01:00
parent 8e0f300572
commit 51ee4dd937
5 changed files with 95 additions and 139 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -62,7 +62,7 @@
"\n",
"This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n",
"\n",
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. `create_device_mesh(pdims)` initializes this layout across available GPUs.\n",
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n",
"- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n",
"- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n",
"\n",
@ -75,10 +75,12 @@
"metadata": {},
"outputs": [],
"source": [
"from jax.experimental.mesh_utils import create_device_mesh\n",
"from jax.experimental.multihost_utils import process_allgather\n",
"from jax.sharding import PartitionSpec as P, NamedSharding\n",
"from jax.sharding import Mesh, NamedSharding\n",
"from jax.sharding import PartitionSpec as P\n",
"\n",
"all_gather = partial(process_allgather, tiled=False)\n",
"all_gather = partial(process_allgather, tiled=True)\n",
"\n",
"pdims = (2, 4)\n",
"mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
@ -121,7 +123,7 @@
"\n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
" solver = LeapfrogMidpoint()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
@ -285,7 +287,7 @@
"\n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
" solver = Dopri5()\n",
"\n",
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",

View file

@ -75,7 +75,7 @@ def parse_arguments():
def create_mesh_and_sharding(pdims):
devices = create_device_mesh(pdims)
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
return mesh, sharding
@ -104,7 +104,7 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))
make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))
# Choose solver
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()