diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 8457baf..7d5b3a1 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -157,7 +157,8 @@ def make_ode_fn(mesh_shape, halo_size=0, sharding=None): return nbody_ode -def get_ode_fn(cosmo:Cosmology, mesh_shape): + +def get_ode_fn(cosmo, mesh_shape): def nbody_ode(a, state, args): """ @@ -170,7 +171,7 @@ def get_ode_fn(cosmo:Cosmology, mesh_shape): # Computes the update of position (drift) dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel - + # Computes the update of velocity (kick) dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces