fix: use proper kernels for PM integration

This commit is contained in:
Guilhem Lavaux 2025-02-10 07:49:34 +01:00
parent cb2a7ab17f
commit 5868c71522
2 changed files with 17 additions and 11 deletions

View file

@ -38,7 +38,7 @@ def interpolate_power_spectrum(input, k, pk, sharding=None):
out_specs=out_specs)(input) out_specs=out_specs)(input)
def gradient_kernel(kvec, direction, order=1): def gradient_kernel(kvec, direction, fd=True):
""" """
Computes the gradient kernel in the requested direction Computes the gradient kernel in the requested direction
Parameters Parameters
@ -53,16 +53,17 @@ def gradient_kernel(kvec, direction, order=1):
wts: array wts: array
Complex kernel values Complex kernel values
""" """
if order == 0: if fd == False:
wts = 1j * kvec[direction] wts = 1j * kvec[direction]
wts = jnp.squeeze(wts) wts = jnp.squeeze(wts)
wts[len(wts) // 2] = 0 wts = wts.at[len(wts) // 2].set(0)
wts = wts.reshape(kvec[direction].shape) wts = wts.reshape(kvec[direction].shape)
return wts return wts
else: else:
w = kvec[direction] w = kvec[direction]
a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w)) #a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
wts = a * 1j wts = jnp.sin(w) * 1j
#wts = a * 1j
return wts return wts
@ -85,7 +86,9 @@ def invlaplace_kernel(kvec, fd=False):
Complex kernel values Complex kernel values
""" """
if fd: if fd:
kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec) #kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
print("new kernel")
kk = sum(4*(jnp.sin(ki/2)**2) for ki in kvec)
else: else:
kk = sum(ki**2 for ki in kvec) kk = sum(ki**2 for ki in kvec)
kk_nozeros = jnp.where(kk == 0, 1, kk) kk_nozeros = jnp.where(kk == 0, 1, kk)

View file

@ -15,6 +15,7 @@ def pm_forces(positions,
r_split=0, r_split=0,
paint_absolute_pos=True, paint_absolute_pos=True,
halo_size=0, halo_size=0,
fd=False,
sharding=None): sharding=None):
""" """
Computes gravitational forces on particles using a PM scheme Computes gravitational forces on particles using a PM scheme
@ -48,11 +49,11 @@ def pm_forces(positions,
kvec = fftk(delta_k) kvec = fftk(delta_k)
# Computes gravitational potential # Computes gravitational potential
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel( pot_k = delta_k * invlaplace_kernel(kvec, fd=fd) * longrange_kernel(
kvec, r_split=r_split) kvec, r_split=r_split)
# Computes gravitational forces # Computes gravitational forces
forces = jnp.stack([ forces = jnp.stack([
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions read_fn(ifft3d(-gradient_kernel(kvec, i, fd=fd) * pot_k),positions
) for i in range(3)], axis=-1) # yapf: disable ) for i in range(3)], axis=-1) # yapf: disable
return forces return forces
@ -81,6 +82,7 @@ def lpt(cosmo,
delta=delta_k, delta=delta_k,
paint_absolute_pos=paint_absolute_pos, paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size, halo_size=halo_size,
fd=False,
sharding=sharding) sharding=sharding)
dx = growth_factor(cosmo, a) * initial_force dx = growth_factor(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * E * dx p = a**2 * growth_rate(cosmo, a) * E * dx
@ -95,7 +97,7 @@ def lpt(cosmo,
for i in range(3): for i in range(3):
# Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)... # Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
# shear_ii = jnp.fft.irfftn(- ki**2 * pot_k) # shear_ii = jnp.fft.irfftn(- ki**2 * pot_k)
nabla_i_nabla_i = gradient_kernel(kvec, i)**2 nabla_i_nabla_i = gradient_kernel(kvec, i, fd=False)**2
shear_ii = ifft3d(nabla_i_nabla_i * pot_k) shear_ii = ifft3d(nabla_i_nabla_i * pot_k)
delta2 += shear_ii * shear_acc delta2 += shear_ii * shear_acc
shear_acc += shear_ii shear_acc += shear_ii
@ -104,8 +106,8 @@ def lpt(cosmo,
for j in range(i + 1, 3): for j in range(i + 1, 3):
# Substract squared strict-up-triangle terms # Substract squared strict-up-triangle terms
# delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2 # delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel( nabla_i_nabla_j = gradient_kernel(kvec, i, fd=False) * gradient_kernel(
kvec, j) kvec, j, fd=False)
delta2 -= ifft3d(nabla_i_nabla_j * pot_k)**2 delta2 -= ifft3d(nabla_i_nabla_j * pot_k)**2
delta_k2 = fft3d(delta2) delta_k2 = fft3d(delta2)
@ -113,6 +115,7 @@ def lpt(cosmo,
delta=delta_k2, delta=delta_k2,
paint_absolute_pos=paint_absolute_pos, paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size, halo_size=halo_size,
fd=False,
sharding=sharding) sharding=sharding)
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second # NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
dx2 = 3 / 7 * growth_factor_second(cosmo, a) * init_force2 dx2 = 3 / 7 * growth_factor_second(cosmo, a) * init_force2