mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
Small fix
This commit is contained in:
parent
85cca44fb0
commit
d28982eec7
1 changed files with 3 additions and 2 deletions
|
@ -157,7 +157,8 @@ def make_ode_fn(mesh_shape, halo_size=0, sharding=None):
|
||||||
|
|
||||||
return nbody_ode
|
return nbody_ode
|
||||||
|
|
||||||
def get_ode_fn(cosmo:Cosmology, mesh_shape):
|
|
||||||
|
def get_ode_fn(cosmo, mesh_shape):
|
||||||
|
|
||||||
def nbody_ode(a, state, args):
|
def nbody_ode(a, state, args):
|
||||||
"""
|
"""
|
||||||
|
@ -170,7 +171,7 @@ def get_ode_fn(cosmo:Cosmology, mesh_shape):
|
||||||
|
|
||||||
# Computes the update of position (drift)
|
# Computes the update of position (drift)
|
||||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||||
|
|
||||||
# Computes the update of velocity (kick)
|
# Computes the update of velocity (kick)
|
||||||
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue