mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 12:20:54 +00:00
remove duplicate get_ode_fn
This commit is contained in:
parent
82f29877f5
commit
31ca41b0a7
1 changed files with 0 additions and 22 deletions
22
jaxpm/pm.py
22
jaxpm/pm.py
|
@ -158,28 +158,6 @@ def make_ode_fn(mesh_shape, halo_size=0, sharding=None):
|
|||
return nbody_ode
|
||||
|
||||
|
||||
def get_ode_fn(cosmo, mesh_shape):
|
||||
|
||||
def nbody_ode(a, state, args):
|
||||
"""
|
||||
State is an array [position, velocities]
|
||||
|
||||
Compatible with [Diffrax API](https://docs.kidger.site/diffrax/)
|
||||
"""
|
||||
pos, vel = state
|
||||
forces = pm_forces(pos, mesh_shape) * 1.5 * cosmo.Omega_m
|
||||
|
||||
# Computes the update of position (drift)
|
||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||
|
||||
# Computes the update of velocity (kick)
|
||||
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||
|
||||
return jnp.stack([dpos, dvel])
|
||||
|
||||
return nbody_ode
|
||||
|
||||
|
||||
def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None):
|
||||
|
||||
def nbody_ode(a, state, args):
|
||||
|
|
Loading…
Add table
Reference in a new issue