Merge remote-tracking branch 'upstream/main' into ASKabalan/jaxdecomp_proto

This commit is contained in:
Wassim KABALAN 2024-10-22 12:59:50 -04:00
commit 85cca44fb0
2 changed files with 21 additions and 4 deletions

View file

@ -43,7 +43,6 @@ def interpolate_power_spectrum(input, k, pk, sharding=None):
def gradient_kernel(kvec, direction, order=1):
"""
Computes the gradient kernel in the requested direction
Parameters
-----------
kvec: list
@ -98,12 +97,10 @@ def longrange_kernel(kvec, r_split):
List of wave-vectors
r_split: float
Splitting radius
Returns
--------
wts: array
Complex kernel values
TODO: @modichirag add documentation
"""
if r_split != 0:
@ -124,7 +121,6 @@ def cic_compensation(kvec):
-----------
kvec: list
List of wave-vectors
Returns:
--------
wts: array

View file

@ -157,6 +157,27 @@ def make_ode_fn(mesh_shape, halo_size=0, sharding=None):
return nbody_ode
def get_ode_fn(cosmo:Cosmology, mesh_shape):
def nbody_ode(a, state, args):
"""
State is an array [position, velocities]
Compatible with [Diffrax API](https://docs.kidger.site/diffrax/)
"""
pos, vel = state
forces = pm_forces(pos, mesh_shape) * 1.5 * cosmo.Omega_m
# Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
# Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
return jnp.stack([dpos, dvel])
return nbody_ode
def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None):