mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-04 11:10:53 +00:00
Compare commits
5 commits
9b0e988910
...
944a147bf0
Author | SHA1 | Date | |
---|---|---|---|
|
944a147bf0 | ||
|
f8325b1c67 | ||
|
f3b8f4160e | ||
|
b43cb373a0 | ||
|
51ee4dd937 |
6 changed files with 105 additions and 147 deletions
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -62,7 +62,7 @@
|
|||
"\n",
|
||||
"This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n",
|
||||
"\n",
|
||||
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. `create_device_mesh(pdims)` initializes this layout across available GPUs.\n",
|
||||
"- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n",
|
||||
"- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n",
|
||||
"- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n",
|
||||
"\n",
|
||||
|
@ -75,10 +75,12 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from jax.experimental.mesh_utils import create_device_mesh\n",
|
||||
"from jax.experimental.multihost_utils import process_allgather\n",
|
||||
"from jax.sharding import PartitionSpec as P, NamedSharding\n",
|
||||
"from jax.sharding import Mesh, NamedSharding\n",
|
||||
"from jax.sharding import PartitionSpec as P\n",
|
||||
"\n",
|
||||
"all_gather = partial(process_allgather, tiled=False)\n",
|
||||
"all_gather = partial(process_allgather, tiled=True)\n",
|
||||
"\n",
|
||||
"pdims = (2, 4)\n",
|
||||
"mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
|
||||
|
@ -121,7 +123,7 @@
|
|||
"\n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" ode_fn = ODETerm(\n",
|
||||
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))\n",
|
||||
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
|
||||
" solver = LeapfrogMidpoint()\n",
|
||||
"\n",
|
||||
" stepsize_controller = ConstantStepSize()\n",
|
||||
|
@ -285,7 +287,7 @@
|
|||
"\n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" ode_fn = ODETerm(\n",
|
||||
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))\n",
|
||||
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n",
|
||||
" solver = Dopri5()\n",
|
||||
"\n",
|
||||
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",
|
||||
|
|
|
@ -18,7 +18,8 @@ import numpy as np
|
|||
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
||||
PIDController, SaveAt, diffeqsolve)
|
||||
from jax.experimental.multihost_utils import process_allgather
|
||||
from jax.sharding import PartitionSpec as P, NamedSharding
|
||||
from jax.sharding import NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jaxpm.kernels import interpolate_power_spectrum
|
||||
from jaxpm.painting import cic_paint_dx
|
||||
|
@ -75,7 +76,7 @@ def parse_arguments():
|
|||
|
||||
|
||||
def create_mesh_and_sharding(pdims):
|
||||
|
||||
devices = create_device_mesh(pdims)
|
||||
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
return mesh, sharding
|
||||
|
@ -104,7 +105,10 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
|||
sharding=sharding)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))
|
||||
make_diffrax_ode(mesh_shape,
|
||||
paint_absolute_pos=False,
|
||||
sharding=sharding,
|
||||
halo_size=halo_size))
|
||||
|
||||
# Choose solver
|
||||
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
|
||||
|
|
|
@ -37,12 +37,12 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
|||
particles,
|
||||
a=0.1,
|
||||
order=order)
|
||||
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
|
||||
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
|
||||
y0 = jnp.stack([particles + dx, p])
|
||||
else:
|
||||
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
ode_fn = ODETerm(make_diffrax_ode(mesh_shape,
|
||||
paint_absolute_pos=False))
|
||||
y0 = jnp.stack([dx, p])
|
||||
|
||||
solver = Dopri5()
|
||||
|
@ -94,8 +94,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
|||
sharding=sharding)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo,
|
||||
mesh_shape,
|
||||
make_diffrax_ode(mesh_shape,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
|
||||
|
@ -108,8 +107,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
|||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo,
|
||||
mesh_shape,
|
||||
make_diffrax_ode(mesh_shape,
|
||||
paint_absolute_pos=False,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
|
|
Loading…
Add table
Reference in a new issue