mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 12:20:54 +00:00
put back old functgion
This commit is contained in:
parent
a742065ffd
commit
6408aff1de
1 changed files with 27 additions and 0 deletions
27
jaxpm/pm.py
27
jaxpm/pm.py
|
@ -98,3 +98,30 @@ def make_ode_fn(mesh_shape):
|
|||
return dpos, dvel
|
||||
|
||||
return nbody_ode
|
||||
|
||||
|
||||
def pgd_correction(pos, params):
|
||||
"""
|
||||
improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671
|
||||
args:
|
||||
pos: particle positions [npart, 3]
|
||||
params: [alpha, kl, ks] pgd parameters
|
||||
"""
|
||||
kvec = fftk(mesh_shape)
|
||||
|
||||
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
||||
alpha, kl, ks = params
|
||||
delta_k = jnp.fft.rfftn(delta)
|
||||
PGD_range = PGD_kernel(kvec, kl, ks)
|
||||
|
||||
pot_k_pgd = (delta_k * laplace_kernel(kvec)) * PGD_range
|
||||
|
||||
forces_pgd = jnp.stack([
|
||||
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k_pgd), pos)
|
||||
for i in range(3)
|
||||
],
|
||||
axis=-1)
|
||||
|
||||
dpos_pgd = forces_pgd * alpha
|
||||
|
||||
return dpos_pgd
|
Loading…
Add table
Reference in a new issue