diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index a7da0ee..9a12773 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -43,7 +43,6 @@ def interpolate_power_spectrum(input, k, pk, sharding=None): def gradient_kernel(kvec, direction, order=1): """ Computes the gradient kernel in the requested direction - Parameters ----------- kvec: list @@ -98,12 +97,10 @@ def longrange_kernel(kvec, r_split): List of wave-vectors r_split: float Splitting radius - Returns -------- wts: array Complex kernel values - TODO: @modichirag add documentation """ if r_split != 0: @@ -124,7 +121,6 @@ def cic_compensation(kvec): ----------- kvec: list List of wave-vectors - Returns: -------- wts: array diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 3155467..8457baf 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -157,6 +157,27 @@ def make_ode_fn(mesh_shape, halo_size=0, sharding=None): return nbody_ode +def get_ode_fn(cosmo:Cosmology, mesh_shape): + + def nbody_ode(a, state, args): + """ + State is an array [position, velocities] + + Compatible with [Diffrax API](https://docs.kidger.site/diffrax/) + """ + pos, vel = state + forces = pm_forces(pos, mesh_shape) * 1.5 * cosmo.Omega_m + + # Computes the update of position (drift) + dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel + + # Computes the update of velocity (kick) + dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces + + return jnp.stack([dpos, dvel]) + + return nbody_ode + def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None):