diff --git a/jaxpm/pm.py b/jaxpm/pm.py index e34d584..9951e1c 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -172,8 +172,7 @@ def make_ode_fn(mesh_shape, return nbody_ode -def make_diffrax_ode(cosmo, - mesh_shape, +def make_diffrax_ode(mesh_shape, paint_absolute_pos=True, halo_size=0, sharding=None): @@ -183,6 +182,7 @@ def make_diffrax_ode(cosmo, state is a tuple (position, velocities) """ pos, vel = state + cosmo = args forces = pm_forces(pos, mesh_shape=mesh_shape,