Update examples

This commit is contained in:
Wassim Kabalan 2025-02-26 14:32:47 +01:00
parent 0b08c6f59a
commit 8e0f300572
5 changed files with 132 additions and 95 deletions

View file

@ -71,20 +71,17 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from jax.experimental.mesh_utils import create_device_mesh\n",
"from jax.experimental.multihost_utils import process_allgather\n",
"from jax.sharding import Mesh, NamedSharding\n",
"from jax.sharding import PartitionSpec as P\n",
"from jax.sharding import PartitionSpec as P, NamedSharding\n",
"\n",
"all_gather = partial(process_allgather, tiled=False)\n",
"\n",
"pdims = (2, 4)\n",
"devices = create_device_mesh(pdims)\n",
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
"mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
"sharding = NamedSharding(mesh, P('x', 'y'))"
]
},
@ -124,7 +121,7 @@
"\n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))\n",
" solver = LeapfrogMidpoint()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
@ -288,7 +285,7 @@
"\n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))\n",
" solver = Dopri5()\n",
"\n",
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",