forked from Aquila-Consortium/JaxPM_highres
neural ode added
This commit is contained in:
parent
20fc4a5562
commit
d8a1dbe210
1 changed files with 37 additions and 1 deletions
38
jaxpm/pm.py
38
jaxpm/pm.py
|
@ -93,4 +93,40 @@ def pgd_correction(pos, params):
|
||||||
|
|
||||||
dpos_pgd = forces_pgd*alpha
|
dpos_pgd = forces_pgd*alpha
|
||||||
|
|
||||||
return dpos_pgd
|
return dpos_pgd
|
||||||
|
|
||||||
|
|
||||||
|
def make_neural_ode_fn(model, mesh_shape):
|
||||||
|
def neural_nbody_ode(state, a, cosmo, params):
|
||||||
|
"""
|
||||||
|
state is a tuple (position, velocities)
|
||||||
|
"""
|
||||||
|
pos, vel = state
|
||||||
|
kvec = fftk(mesh_shape)
|
||||||
|
|
||||||
|
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
||||||
|
|
||||||
|
delta_k = jnp.fft.rfftn(delta)
|
||||||
|
|
||||||
|
# Computes gravitational potential
|
||||||
|
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=0)
|
||||||
|
|
||||||
|
# Apply a correction filter
|
||||||
|
kk = jnp.sqrt(sum((ki/jnp.pi)**2 for ki in kvec))
|
||||||
|
pot_k = pot_k *(1. + model.apply(params, kk, jnp.atleast_1d(a)))
|
||||||
|
|
||||||
|
# Computes gravitational forces
|
||||||
|
forces = jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), pos)
|
||||||
|
for i in range(3)],axis=-1)
|
||||||
|
|
||||||
|
forces = forces * 1.5 * cosmo.Omega_m
|
||||||
|
|
||||||
|
# Computes the update of position (drift)
|
||||||
|
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||||
|
|
||||||
|
# Computes the update of velocity (kick)
|
||||||
|
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||||
|
|
||||||
|
return dpos, dvel
|
||||||
|
return neural_nbody_ode
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue