fix: use proper kernels for PM integration

This commit is contained in:
Guilhem Lavaux 2025-02-10 07:49:34 +01:00
parent cb2a7ab17f
commit 5868c71522
2 changed files with 17 additions and 11 deletions

View file

@ -38,7 +38,7 @@ def interpolate_power_spectrum(input, k, pk, sharding=None):
out_specs=out_specs)(input)
def gradient_kernel(kvec, direction, order=1):
def gradient_kernel(kvec, direction, fd=True):
"""
Computes the gradient kernel in the requested direction
Parameters
@ -53,16 +53,17 @@ def gradient_kernel(kvec, direction, order=1):
wts: array
Complex kernel values
"""
if order == 0:
if fd == False:
wts = 1j * kvec[direction]
wts = jnp.squeeze(wts)
wts[len(wts) // 2] = 0
wts = wts.at[len(wts) // 2].set(0)
wts = wts.reshape(kvec[direction].shape)
return wts
else:
w = kvec[direction]
a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
wts = a * 1j
#a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
wts = jnp.sin(w) * 1j
#wts = a * 1j
return wts
@ -85,7 +86,9 @@ def invlaplace_kernel(kvec, fd=False):
Complex kernel values
"""
if fd:
kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
#kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
print("new kernel")
kk = sum(4*(jnp.sin(ki/2)**2) for ki in kvec)
else:
kk = sum(ki**2 for ki in kvec)
kk_nozeros = jnp.where(kk == 0, 1, kk)

View file

@ -15,6 +15,7 @@ def pm_forces(positions,
r_split=0,
paint_absolute_pos=True,
halo_size=0,
fd=False,
sharding=None):
"""
Computes gravitational forces on particles using a PM scheme
@ -48,11 +49,11 @@ def pm_forces(positions,
kvec = fftk(delta_k)
# Computes gravitational potential
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(
pot_k = delta_k * invlaplace_kernel(kvec, fd=fd) * longrange_kernel(
kvec, r_split=r_split)
# Computes gravitational forces
forces = jnp.stack([
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions
read_fn(ifft3d(-gradient_kernel(kvec, i, fd=fd) * pot_k),positions
) for i in range(3)], axis=-1) # yapf: disable
return forces
@ -81,6 +82,7 @@ def lpt(cosmo,
delta=delta_k,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
fd=False,
sharding=sharding)
dx = growth_factor(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * E * dx
@ -95,7 +97,7 @@ def lpt(cosmo,
for i in range(3):
# Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
# shear_ii = jnp.fft.irfftn(- ki**2 * pot_k)
nabla_i_nabla_i = gradient_kernel(kvec, i)**2
nabla_i_nabla_i = gradient_kernel(kvec, i, fd=False)**2
shear_ii = ifft3d(nabla_i_nabla_i * pot_k)
delta2 += shear_ii * shear_acc
shear_acc += shear_ii
@ -104,8 +106,8 @@ def lpt(cosmo,
for j in range(i + 1, 3):
# Substract squared strict-up-triangle terms
# delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(
kvec, j)
nabla_i_nabla_j = gradient_kernel(kvec, i, fd=False) * gradient_kernel(
kvec, j, fd=False)
delta2 -= ifft3d(nabla_i_nabla_j * pot_k)**2
delta_k2 = fft3d(delta2)
@ -113,6 +115,7 @@ def lpt(cosmo,
delta=delta_k2,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
fd=False,
sharding=sharding)
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
dx2 = 3 / 7 * growth_factor_second(cosmo, a) * init_force2