This commit is contained in:
Wassim Kabalan 2025-02-28 09:56:00 +01:00
parent 51ee4dd937
commit b43cb373a0

View file

@ -18,7 +18,8 @@ import numpy as np
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
PIDController, SaveAt, diffeqsolve)
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import PartitionSpec as P, NamedSharding
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
@ -104,7 +105,10 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))
make_diffrax_ode(mesh_shape,
paint_absolute_pos=False,
sharding=sharding,
halo_size=halo_size))
# Choose solver
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()