update examples

This commit is contained in:
Wassim Kabalan 2025-02-28 09:54:53 +01:00
parent 8e0f300572
commit 51ee4dd937
5 changed files with 95 additions and 139 deletions

View file

@ -75,7 +75,7 @@ def parse_arguments():
def create_mesh_and_sharding(pdims):
devices = create_device_mesh(pdims)
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
return mesh, sharding
@ -104,7 +104,7 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))
make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))
# Choose solver
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()