forked from Aquila-Consortium/JaxPM_highres
fix: use proper kernels for PM integration
This commit is contained in:
parent
cb2a7ab17f
commit
5868c71522
2 changed files with 17 additions and 11 deletions
|
@ -38,7 +38,7 @@ def interpolate_power_spectrum(input, k, pk, sharding=None):
|
||||||
out_specs=out_specs)(input)
|
out_specs=out_specs)(input)
|
||||||
|
|
||||||
|
|
||||||
def gradient_kernel(kvec, direction, order=1):
|
def gradient_kernel(kvec, direction, fd=True):
|
||||||
"""
|
"""
|
||||||
Computes the gradient kernel in the requested direction
|
Computes the gradient kernel in the requested direction
|
||||||
Parameters
|
Parameters
|
||||||
|
@ -53,16 +53,17 @@ def gradient_kernel(kvec, direction, order=1):
|
||||||
wts: array
|
wts: array
|
||||||
Complex kernel values
|
Complex kernel values
|
||||||
"""
|
"""
|
||||||
if order == 0:
|
if fd == False:
|
||||||
wts = 1j * kvec[direction]
|
wts = 1j * kvec[direction]
|
||||||
wts = jnp.squeeze(wts)
|
wts = jnp.squeeze(wts)
|
||||||
wts[len(wts) // 2] = 0
|
wts = wts.at[len(wts) // 2].set(0)
|
||||||
wts = wts.reshape(kvec[direction].shape)
|
wts = wts.reshape(kvec[direction].shape)
|
||||||
return wts
|
return wts
|
||||||
else:
|
else:
|
||||||
w = kvec[direction]
|
w = kvec[direction]
|
||||||
a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
|
#a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
|
||||||
wts = a * 1j
|
wts = jnp.sin(w) * 1j
|
||||||
|
#wts = a * 1j
|
||||||
return wts
|
return wts
|
||||||
|
|
||||||
|
|
||||||
|
@ -85,7 +86,9 @@ def invlaplace_kernel(kvec, fd=False):
|
||||||
Complex kernel values
|
Complex kernel values
|
||||||
"""
|
"""
|
||||||
if fd:
|
if fd:
|
||||||
kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
|
#kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
|
||||||
|
print("new kernel")
|
||||||
|
kk = sum(4*(jnp.sin(ki/2)**2) for ki in kvec)
|
||||||
else:
|
else:
|
||||||
kk = sum(ki**2 for ki in kvec)
|
kk = sum(ki**2 for ki in kvec)
|
||||||
kk_nozeros = jnp.where(kk == 0, 1, kk)
|
kk_nozeros = jnp.where(kk == 0, 1, kk)
|
||||||
|
|
13
jaxpm/pm.py
13
jaxpm/pm.py
|
@ -15,6 +15,7 @@ def pm_forces(positions,
|
||||||
r_split=0,
|
r_split=0,
|
||||||
paint_absolute_pos=True,
|
paint_absolute_pos=True,
|
||||||
halo_size=0,
|
halo_size=0,
|
||||||
|
fd=False,
|
||||||
sharding=None):
|
sharding=None):
|
||||||
"""
|
"""
|
||||||
Computes gravitational forces on particles using a PM scheme
|
Computes gravitational forces on particles using a PM scheme
|
||||||
|
@ -48,11 +49,11 @@ def pm_forces(positions,
|
||||||
|
|
||||||
kvec = fftk(delta_k)
|
kvec = fftk(delta_k)
|
||||||
# Computes gravitational potential
|
# Computes gravitational potential
|
||||||
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(
|
pot_k = delta_k * invlaplace_kernel(kvec, fd=fd) * longrange_kernel(
|
||||||
kvec, r_split=r_split)
|
kvec, r_split=r_split)
|
||||||
# Computes gravitational forces
|
# Computes gravitational forces
|
||||||
forces = jnp.stack([
|
forces = jnp.stack([
|
||||||
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions
|
read_fn(ifft3d(-gradient_kernel(kvec, i, fd=fd) * pot_k),positions
|
||||||
) for i in range(3)], axis=-1) # yapf: disable
|
) for i in range(3)], axis=-1) # yapf: disable
|
||||||
|
|
||||||
return forces
|
return forces
|
||||||
|
@ -81,6 +82,7 @@ def lpt(cosmo,
|
||||||
delta=delta_k,
|
delta=delta_k,
|
||||||
paint_absolute_pos=paint_absolute_pos,
|
paint_absolute_pos=paint_absolute_pos,
|
||||||
halo_size=halo_size,
|
halo_size=halo_size,
|
||||||
|
fd=False,
|
||||||
sharding=sharding)
|
sharding=sharding)
|
||||||
dx = growth_factor(cosmo, a) * initial_force
|
dx = growth_factor(cosmo, a) * initial_force
|
||||||
p = a**2 * growth_rate(cosmo, a) * E * dx
|
p = a**2 * growth_rate(cosmo, a) * E * dx
|
||||||
|
@ -95,7 +97,7 @@ def lpt(cosmo,
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
# Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
|
# Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
|
||||||
# shear_ii = jnp.fft.irfftn(- ki**2 * pot_k)
|
# shear_ii = jnp.fft.irfftn(- ki**2 * pot_k)
|
||||||
nabla_i_nabla_i = gradient_kernel(kvec, i)**2
|
nabla_i_nabla_i = gradient_kernel(kvec, i, fd=False)**2
|
||||||
shear_ii = ifft3d(nabla_i_nabla_i * pot_k)
|
shear_ii = ifft3d(nabla_i_nabla_i * pot_k)
|
||||||
delta2 += shear_ii * shear_acc
|
delta2 += shear_ii * shear_acc
|
||||||
shear_acc += shear_ii
|
shear_acc += shear_ii
|
||||||
|
@ -104,8 +106,8 @@ def lpt(cosmo,
|
||||||
for j in range(i + 1, 3):
|
for j in range(i + 1, 3):
|
||||||
# Substract squared strict-up-triangle terms
|
# Substract squared strict-up-triangle terms
|
||||||
# delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
|
# delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
|
||||||
nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(
|
nabla_i_nabla_j = gradient_kernel(kvec, i, fd=False) * gradient_kernel(
|
||||||
kvec, j)
|
kvec, j, fd=False)
|
||||||
delta2 -= ifft3d(nabla_i_nabla_j * pot_k)**2
|
delta2 -= ifft3d(nabla_i_nabla_j * pot_k)**2
|
||||||
|
|
||||||
delta_k2 = fft3d(delta2)
|
delta_k2 = fft3d(delta2)
|
||||||
|
@ -113,6 +115,7 @@ def lpt(cosmo,
|
||||||
delta=delta_k2,
|
delta=delta_k2,
|
||||||
paint_absolute_pos=paint_absolute_pos,
|
paint_absolute_pos=paint_absolute_pos,
|
||||||
halo_size=halo_size,
|
halo_size=halo_size,
|
||||||
|
fd=False,
|
||||||
sharding=sharding)
|
sharding=sharding)
|
||||||
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
|
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
|
||||||
dx2 = 3 / 7 * growth_factor_second(cosmo, a) * init_force2
|
dx2 = 3 / 7 * growth_factor_second(cosmo, a) * init_force2
|
||||||
|
|
Loading…
Add table
Reference in a new issue