mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-18 17:10:54 +00:00
Merge remote-tracking branch 'upstream/main' into ASKabalan/jaxdecomp_proto
This commit is contained in:
commit
85cca44fb0
2 changed files with 21 additions and 4 deletions
|
@ -43,7 +43,6 @@ def interpolate_power_spectrum(input, k, pk, sharding=None):
|
||||||
def gradient_kernel(kvec, direction, order=1):
|
def gradient_kernel(kvec, direction, order=1):
|
||||||
"""
|
"""
|
||||||
Computes the gradient kernel in the requested direction
|
Computes the gradient kernel in the requested direction
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
-----------
|
-----------
|
||||||
kvec: list
|
kvec: list
|
||||||
|
@ -98,12 +97,10 @@ def longrange_kernel(kvec, r_split):
|
||||||
List of wave-vectors
|
List of wave-vectors
|
||||||
r_split: float
|
r_split: float
|
||||||
Splitting radius
|
Splitting radius
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
--------
|
--------
|
||||||
wts: array
|
wts: array
|
||||||
Complex kernel values
|
Complex kernel values
|
||||||
|
|
||||||
TODO: @modichirag add documentation
|
TODO: @modichirag add documentation
|
||||||
"""
|
"""
|
||||||
if r_split != 0:
|
if r_split != 0:
|
||||||
|
@ -124,7 +121,6 @@ def cic_compensation(kvec):
|
||||||
-----------
|
-----------
|
||||||
kvec: list
|
kvec: list
|
||||||
List of wave-vectors
|
List of wave-vectors
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
--------
|
--------
|
||||||
wts: array
|
wts: array
|
||||||
|
|
21
jaxpm/pm.py
21
jaxpm/pm.py
|
@ -157,6 +157,27 @@ def make_ode_fn(mesh_shape, halo_size=0, sharding=None):
|
||||||
|
|
||||||
return nbody_ode
|
return nbody_ode
|
||||||
|
|
||||||
|
def get_ode_fn(cosmo:Cosmology, mesh_shape):
|
||||||
|
|
||||||
|
def nbody_ode(a, state, args):
|
||||||
|
"""
|
||||||
|
State is an array [position, velocities]
|
||||||
|
|
||||||
|
Compatible with [Diffrax API](https://docs.kidger.site/diffrax/)
|
||||||
|
"""
|
||||||
|
pos, vel = state
|
||||||
|
forces = pm_forces(pos, mesh_shape) * 1.5 * cosmo.Omega_m
|
||||||
|
|
||||||
|
# Computes the update of position (drift)
|
||||||
|
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||||
|
|
||||||
|
# Computes the update of velocity (kick)
|
||||||
|
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||||
|
|
||||||
|
return jnp.stack([dpos, dvel])
|
||||||
|
|
||||||
|
return nbody_ode
|
||||||
|
|
||||||
|
|
||||||
def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None):
|
def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None):
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue