From d8a1dbe210d838ddc6467b41a1e8ae4be93306ea Mon Sep 17 00:00:00 2001 From: denise lanzieri Date: Sat, 11 Jun 2022 14:28:30 +0200 Subject: [PATCH] neural ode added --- jaxpm/pm.py | 38 +++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index d9870f7..d54a252 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -93,4 +93,40 @@ def pgd_correction(pos, params): dpos_pgd = forces_pgd*alpha - return dpos_pgd \ No newline at end of file + return dpos_pgd + + +def make_neural_ode_fn(model, mesh_shape): + def neural_nbody_ode(state, a, cosmo, params): + """ + state is a tuple (position, velocities) + """ + pos, vel = state + kvec = fftk(mesh_shape) + + delta = cic_paint(jnp.zeros(mesh_shape), pos) + + delta_k = jnp.fft.rfftn(delta) + + # Computes gravitational potential + pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=0) + + # Apply a correction filter + kk = jnp.sqrt(sum((ki/jnp.pi)**2 for ki in kvec)) + pot_k = pot_k *(1. + model.apply(params, kk, jnp.atleast_1d(a))) + + # Computes gravitational forces + forces = jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), pos) + for i in range(3)],axis=-1) + + forces = forces * 1.5 * cosmo.Omega_m + + # Computes the update of position (drift) + dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel + + # Computes the update of velocity (kick) + dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces + + return dpos, dvel + return neural_nbody_ode +