This commit is contained in:
Wassim KABALAN 2025-02-26 14:49:51 +01:00 committed by GitHub
commit 9b0e988910
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 134 additions and 97 deletions

View file

@ -172,8 +172,7 @@ def make_ode_fn(mesh_shape,
return nbody_ode
def make_diffrax_ode(cosmo,
mesh_shape,
def make_diffrax_ode(mesh_shape,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
@ -183,6 +182,7 @@ def make_diffrax_ode(cosmo,
state is a tuple (position, velocities)
"""
pos, vel = state
cosmo = args
forces = pm_forces(pos,
mesh_shape=mesh_shape,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -71,20 +71,17 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from jax.experimental.mesh_utils import create_device_mesh\n",
"from jax.experimental.multihost_utils import process_allgather\n",
"from jax.sharding import Mesh, NamedSharding\n",
"from jax.sharding import PartitionSpec as P\n",
"from jax.sharding import PartitionSpec as P, NamedSharding\n",
"\n",
"all_gather = partial(process_allgather, tiled=False)\n",
"\n",
"pdims = (2, 4)\n",
"devices = create_device_mesh(pdims)\n",
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
"mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
"sharding = NamedSharding(mesh, P('x', 'y'))"
]
},
@ -124,7 +121,7 @@
"\n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))\n",
" solver = LeapfrogMidpoint()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
@ -288,7 +285,7 @@
"\n",
" # Evolve the simulation forward\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))\n",
" solver = Dopri5()\n",
"\n",
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",

View file

@ -17,10 +17,8 @@ import jax_cosmo as jc
import numpy as np
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
PIDController, SaveAt, diffeqsolve)
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jax.sharding import PartitionSpec as P, NamedSharding
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
@ -77,8 +75,8 @@ def parse_arguments():
def create_mesh_and_sharding(pdims):
devices = create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('x', 'y'))
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
return mesh, sharding
@ -106,7 +104,7 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))
# Choose solver
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()