forked from guilhem_lavaux/JaxPM
fix: use proper kernels for PM integration
This commit is contained in:
parent
cb2a7ab17f
commit
5868c71522
2 changed files with 17 additions and 11 deletions
|
@ -38,7 +38,7 @@ def interpolate_power_spectrum(input, k, pk, sharding=None):
|
|||
out_specs=out_specs)(input)
|
||||
|
||||
|
||||
def gradient_kernel(kvec, direction, order=1):
|
||||
def gradient_kernel(kvec, direction, fd=True):
|
||||
"""
|
||||
Computes the gradient kernel in the requested direction
|
||||
Parameters
|
||||
|
@ -53,16 +53,17 @@ def gradient_kernel(kvec, direction, order=1):
|
|||
wts: array
|
||||
Complex kernel values
|
||||
"""
|
||||
if order == 0:
|
||||
if fd == False:
|
||||
wts = 1j * kvec[direction]
|
||||
wts = jnp.squeeze(wts)
|
||||
wts[len(wts) // 2] = 0
|
||||
wts = wts.at[len(wts) // 2].set(0)
|
||||
wts = wts.reshape(kvec[direction].shape)
|
||||
return wts
|
||||
else:
|
||||
w = kvec[direction]
|
||||
a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
|
||||
wts = a * 1j
|
||||
#a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
|
||||
wts = jnp.sin(w) * 1j
|
||||
#wts = a * 1j
|
||||
return wts
|
||||
|
||||
|
||||
|
@ -85,7 +86,9 @@ def invlaplace_kernel(kvec, fd=False):
|
|||
Complex kernel values
|
||||
"""
|
||||
if fd:
|
||||
kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
|
||||
#kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
|
||||
print("new kernel")
|
||||
kk = sum(4*(jnp.sin(ki/2)**2) for ki in kvec)
|
||||
else:
|
||||
kk = sum(ki**2 for ki in kvec)
|
||||
kk_nozeros = jnp.where(kk == 0, 1, kk)
|
||||
|
|
13
jaxpm/pm.py
13
jaxpm/pm.py
|
@ -15,6 +15,7 @@ def pm_forces(positions,
|
|||
r_split=0,
|
||||
paint_absolute_pos=True,
|
||||
halo_size=0,
|
||||
fd=False,
|
||||
sharding=None):
|
||||
"""
|
||||
Computes gravitational forces on particles using a PM scheme
|
||||
|
@ -48,11 +49,11 @@ def pm_forces(positions,
|
|||
|
||||
kvec = fftk(delta_k)
|
||||
# Computes gravitational potential
|
||||
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(
|
||||
pot_k = delta_k * invlaplace_kernel(kvec, fd=fd) * longrange_kernel(
|
||||
kvec, r_split=r_split)
|
||||
# Computes gravitational forces
|
||||
forces = jnp.stack([
|
||||
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions
|
||||
read_fn(ifft3d(-gradient_kernel(kvec, i, fd=fd) * pot_k),positions
|
||||
) for i in range(3)], axis=-1) # yapf: disable
|
||||
|
||||
return forces
|
||||
|
@ -81,6 +82,7 @@ def lpt(cosmo,
|
|||
delta=delta_k,
|
||||
paint_absolute_pos=paint_absolute_pos,
|
||||
halo_size=halo_size,
|
||||
fd=False,
|
||||
sharding=sharding)
|
||||
dx = growth_factor(cosmo, a) * initial_force
|
||||
p = a**2 * growth_rate(cosmo, a) * E * dx
|
||||
|
@ -95,7 +97,7 @@ def lpt(cosmo,
|
|||
for i in range(3):
|
||||
# Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
|
||||
# shear_ii = jnp.fft.irfftn(- ki**2 * pot_k)
|
||||
nabla_i_nabla_i = gradient_kernel(kvec, i)**2
|
||||
nabla_i_nabla_i = gradient_kernel(kvec, i, fd=False)**2
|
||||
shear_ii = ifft3d(nabla_i_nabla_i * pot_k)
|
||||
delta2 += shear_ii * shear_acc
|
||||
shear_acc += shear_ii
|
||||
|
@ -104,8 +106,8 @@ def lpt(cosmo,
|
|||
for j in range(i + 1, 3):
|
||||
# Substract squared strict-up-triangle terms
|
||||
# delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
|
||||
nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(
|
||||
kvec, j)
|
||||
nabla_i_nabla_j = gradient_kernel(kvec, i, fd=False) * gradient_kernel(
|
||||
kvec, j, fd=False)
|
||||
delta2 -= ifft3d(nabla_i_nabla_j * pot_k)**2
|
||||
|
||||
delta_k2 = fft3d(delta2)
|
||||
|
@ -113,6 +115,7 @@ def lpt(cosmo,
|
|||
delta=delta_k2,
|
||||
paint_absolute_pos=paint_absolute_pos,
|
||||
halo_size=halo_size,
|
||||
fd=False,
|
||||
sharding=sharding)
|
||||
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
|
||||
dx2 = 3 / 7 * growth_factor_second(cosmo, a) * init_force2
|
||||
|
|
Loading…
Add table
Reference in a new issue