diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 7d5b3a1..3155467 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -158,28 +158,6 @@ def make_ode_fn(mesh_shape, halo_size=0, sharding=None): return nbody_ode -def get_ode_fn(cosmo, mesh_shape): - - def nbody_ode(a, state, args): - """ - State is an array [position, velocities] - - Compatible with [Diffrax API](https://docs.kidger.site/diffrax/) - """ - pos, vel = state - forces = pm_forces(pos, mesh_shape) * 1.5 * cosmo.Omega_m - - # Computes the update of position (drift) - dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel - - # Computes the update of velocity (kick) - dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces - - return jnp.stack([dpos, dvel]) - - return nbody_ode - - def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None): def nbody_ode(a, state, args):