mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-30 00:51:11 +00:00
Fix sharding error (#37)
* Use cosmo as arg for the ODE function * Update examples * format * notebook update * fix tests * add correct annotations for weights in painting and warning for cic_paint in distributed pm * update test_against_fpm * update distributed tests and add jacfwd jacrev and vmap tests * format * add Caveats to notebook readme * final touches * update Growth.py to allow using FastPM solver * fix 2D painting when input is (X , Y , 2) shape * update cic read halo size and notebooks examples * Allow env variable control of caching in growth * Format * update test jax version * update notebooks/03-MultiGPU_PM_Halo.ipynb * update numpy install in wf * update tolerance :) * reorganize install in test workflow * update tests * add mpi4py * update tests.yml * update tests * update wf * format * make normal_field signature consistent with jax.random.normal * update by default normal_field dtype to match JAX * format * debug test workflow * format * debug test workflow * updating tests * fix accuracy * fixed tolerance * adding caching * Update conftest.py * Update tolerance and precision settings in distributed PM tests * revererting back changes to growth.py --------- Co-authored-by: Francois Lanusse <fr.eiffel@gmail.com> Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
This commit is contained in:
parent
cb2a7ab17f
commit
6693e5c725
17 changed files with 675 additions and 298 deletions
|
@ -62,7 +62,7 @@
|
|||
"\n",
|
||||
"This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n",
|
||||
"\n",
|
||||
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. `create_device_mesh(pdims)` initializes this layout across available GPUs.\n",
|
||||
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n",
|
||||
"- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n",
|
||||
"- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n",
|
||||
"\n",
|
||||
|
@ -71,7 +71,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -80,11 +80,10 @@
|
|||
"from jax.sharding import Mesh, NamedSharding\n",
|
||||
"from jax.sharding import PartitionSpec as P\n",
|
||||
"\n",
|
||||
"all_gather = partial(process_allgather, tiled=False)\n",
|
||||
"all_gather = partial(process_allgather, tiled=True)\n",
|
||||
"\n",
|
||||
"pdims = (2, 4)\n",
|
||||
"devices = create_device_mesh(pdims)\n",
|
||||
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
|
||||
"mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
|
||||
"sharding = NamedSharding(mesh, P('x', 'y'))"
|
||||
]
|
||||
},
|
||||
|
@ -124,7 +123,7 @@
|
|||
"\n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" ode_fn = ODETerm(\n",
|
||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
||||
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
|
||||
" solver = LeapfrogMidpoint()\n",
|
||||
"\n",
|
||||
" stepsize_controller = ConstantStepSize()\n",
|
||||
|
@ -288,7 +287,7 @@
|
|||
"\n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" ode_fn = ODETerm(\n",
|
||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
||||
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
|
||||
" solver = Dopri5()\n",
|
||||
"\n",
|
||||
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue