From d28982eec7efdd050330fcce66aaa34bff920e6b Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Tue, 22 Oct 2024 13:00:25 -0400 Subject: [PATCH] Small fix --- jaxpm/pm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 8457baf..7d5b3a1 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -157,7 +157,8 @@ def make_ode_fn(mesh_shape, halo_size=0, sharding=None): return nbody_ode -def get_ode_fn(cosmo:Cosmology, mesh_shape): + +def get_ode_fn(cosmo, mesh_shape): def nbody_ode(a, state, args): """ @@ -170,7 +171,7 @@ def get_ode_fn(cosmo:Cosmology, mesh_shape): # Computes the update of position (drift) dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel - + # Computes the update of velocity (kick) dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces