mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
put back old functgion
This commit is contained in:
parent
a742065ffd
commit
6408aff1de
1 changed files with 27 additions and 0 deletions
27
jaxpm/pm.py
27
jaxpm/pm.py
|
@ -98,3 +98,30 @@ def make_ode_fn(mesh_shape):
|
||||||
return dpos, dvel
|
return dpos, dvel
|
||||||
|
|
||||||
return nbody_ode
|
return nbody_ode
|
||||||
|
|
||||||
|
|
||||||
|
def pgd_correction(pos, params):
|
||||||
|
"""
|
||||||
|
improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671
|
||||||
|
args:
|
||||||
|
pos: particle positions [npart, 3]
|
||||||
|
params: [alpha, kl, ks] pgd parameters
|
||||||
|
"""
|
||||||
|
kvec = fftk(mesh_shape)
|
||||||
|
|
||||||
|
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
||||||
|
alpha, kl, ks = params
|
||||||
|
delta_k = jnp.fft.rfftn(delta)
|
||||||
|
PGD_range = PGD_kernel(kvec, kl, ks)
|
||||||
|
|
||||||
|
pot_k_pgd = (delta_k * laplace_kernel(kvec)) * PGD_range
|
||||||
|
|
||||||
|
forces_pgd = jnp.stack([
|
||||||
|
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k_pgd), pos)
|
||||||
|
for i in range(3)
|
||||||
|
],
|
||||||
|
axis=-1)
|
||||||
|
|
||||||
|
dpos_pgd = forces_pgd * alpha
|
||||||
|
|
||||||
|
return dpos_pgd
|
Loading…
Add table
Reference in a new issue