mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-14 12:01:12 +00:00
update examples
This commit is contained in:
parent
8e0f300572
commit
51ee4dd937
5 changed files with 95 additions and 139 deletions
|
@ -75,7 +75,7 @@ def parse_arguments():
|
|||
|
||||
|
||||
def create_mesh_and_sharding(pdims):
|
||||
|
||||
devices = create_device_mesh(pdims)
|
||||
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
return mesh, sharding
|
||||
|
@ -104,7 +104,7 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
|||
sharding=sharding)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))
|
||||
make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))
|
||||
|
||||
# Choose solver
|
||||
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue