diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 912fe2f..f9d0097 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -38,7 +38,7 @@ def interpolate_power_spectrum(input, k, pk, sharding=None): out_specs=out_specs)(input) -def gradient_kernel(kvec, direction, order=1): +def gradient_kernel(kvec, direction, fd=True): """ Computes the gradient kernel in the requested direction Parameters @@ -53,16 +53,17 @@ def gradient_kernel(kvec, direction, order=1): wts: array Complex kernel values """ - if order == 0: + if fd == False: wts = 1j * kvec[direction] wts = jnp.squeeze(wts) - wts[len(wts) // 2] = 0 + wts = wts.at[len(wts) // 2].set(0) wts = wts.reshape(kvec[direction].shape) return wts else: w = kvec[direction] - a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w)) - wts = a * 1j + #a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w)) + wts = jnp.sin(w) * 1j + #wts = a * 1j return wts @@ -85,7 +86,9 @@ def invlaplace_kernel(kvec, fd=False): Complex kernel values """ if fd: - kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec) + #kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec) + print("new kernel") + kk = sum(4*(jnp.sin(ki/2)**2) for ki in kvec) else: kk = sum(ki**2 for ki in kvec) kk_nozeros = jnp.where(kk == 0, 1, kk) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index e34d584..d073637 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -15,6 +15,7 @@ def pm_forces(positions, r_split=0, paint_absolute_pos=True, halo_size=0, + fd=False, sharding=None): """ Computes gravitational forces on particles using a PM scheme @@ -48,11 +49,11 @@ def pm_forces(positions, kvec = fftk(delta_k) # Computes gravitational potential - pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel( + pot_k = delta_k * invlaplace_kernel(kvec, fd=fd) * longrange_kernel( kvec, r_split=r_split) # Computes gravitational forces forces = jnp.stack([ - read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions + read_fn(ifft3d(-gradient_kernel(kvec, i, fd=fd) * pot_k),positions ) for i in range(3)], axis=-1) # yapf: disable return forces @@ -81,6 +82,7 @@ def lpt(cosmo, delta=delta_k, paint_absolute_pos=paint_absolute_pos, halo_size=halo_size, + fd=False, sharding=sharding) dx = growth_factor(cosmo, a) * initial_force p = a**2 * growth_rate(cosmo, a) * E * dx @@ -95,7 +97,7 @@ def lpt(cosmo, for i in range(3): # Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)... # shear_ii = jnp.fft.irfftn(- ki**2 * pot_k) - nabla_i_nabla_i = gradient_kernel(kvec, i)**2 + nabla_i_nabla_i = gradient_kernel(kvec, i, fd=False)**2 shear_ii = ifft3d(nabla_i_nabla_i * pot_k) delta2 += shear_ii * shear_acc shear_acc += shear_ii @@ -104,8 +106,8 @@ def lpt(cosmo, for j in range(i + 1, 3): # Substract squared strict-up-triangle terms # delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2 - nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel( - kvec, j) + nabla_i_nabla_j = gradient_kernel(kvec, i, fd=False) * gradient_kernel( + kvec, j, fd=False) delta2 -= ifft3d(nabla_i_nabla_j * pot_k)**2 delta_k2 = fft3d(delta2) @@ -113,6 +115,7 @@ def lpt(cosmo, delta=delta_k2, paint_absolute_pos=paint_absolute_pos, halo_size=halo_size, + fd=False, sharding=sharding) # NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second dx2 = 3 / 7 * growth_factor_second(cosmo, a) * init_force2