Small fix

This commit is contained in:
Wassim KABALAN 2024-10-22 13:00:25 -04:00
parent 85cca44fb0
commit d28982eec7

View file

@ -157,7 +157,8 @@ def make_ode_fn(mesh_shape, halo_size=0, sharding=None):
return nbody_ode
def get_ode_fn(cosmo:Cosmology, mesh_shape):
def get_ode_fn(cosmo, mesh_shape):
def nbody_ode(a, state, args):
"""
@ -170,7 +171,7 @@ def get_ode_fn(cosmo:Cosmology, mesh_shape):
# Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
# Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces