update examples

This commit is contained in:
Wassim Kabalan 2025-02-28 09:54:53 +01:00
parent 8e0f300572
commit 51ee4dd937
5 changed files with 95 additions and 139 deletions

View file

@ -62,7 +62,7 @@
"\n",
"This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n",
"\n",
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. `create_device_mesh(pdims)` initializes this layout across available GPUs.\n",
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n",
"- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n",
"- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n",
"\n",
@ -75,10 +75,12 @@
"metadata": {},
"outputs": [],
"source": [
"from jax.experimental.mesh_utils import create_device_mesh\n",
"from jax.experimental.multihost_utils import process_allgather\n",
"from jax.sharding import PartitionSpec as P, NamedSharding\n",
"from jax.sharding import Mesh, NamedSharding\n",
"from jax.sharding import PartitionSpec as P\n",
"\n",
"all_gather = partial(process_allgather, tiled=False)\n",
"all_gather = partial(process_allgather, tiled=True)\n",
"\n",
"pdims = (2, 4)\n",
"mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
@ -121,7 +123,7 @@
"\n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
" solver = LeapfrogMidpoint()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
@ -285,7 +287,7 @@
"\n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
" solver = Dopri5()\n",
"\n",
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",