diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index f9d0097..912fe2f 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -38,7 +38,7 @@ def interpolate_power_spectrum(input, k, pk, sharding=None): out_specs=out_specs)(input) -def gradient_kernel(kvec, direction, fd=True): +def gradient_kernel(kvec, direction, order=1): """ Computes the gradient kernel in the requested direction Parameters @@ -53,17 +53,16 @@ def gradient_kernel(kvec, direction, fd=True): wts: array Complex kernel values """ - if fd == False: + if order == 0: wts = 1j * kvec[direction] wts = jnp.squeeze(wts) - wts = wts.at[len(wts) // 2].set(0) + wts[len(wts) // 2] = 0 wts = wts.reshape(kvec[direction].shape) return wts else: w = kvec[direction] - #a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w)) - wts = jnp.sin(w) * 1j - #wts = a * 1j + a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w)) + wts = a * 1j return wts @@ -86,9 +85,7 @@ def invlaplace_kernel(kvec, fd=False): Complex kernel values """ if fd: - #kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec) - print("new kernel") - kk = sum(4*(jnp.sin(ki/2)**2) for ki in kvec) + kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec) else: kk = sum(ki**2 for ki in kvec) kk_nozeros = jnp.where(kk == 0, 1, kk) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index d073637..e34d584 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -15,7 +15,6 @@ def pm_forces(positions, r_split=0, paint_absolute_pos=True, halo_size=0, - fd=False, sharding=None): """ Computes gravitational forces on particles using a PM scheme @@ -49,11 +48,11 @@ def pm_forces(positions, kvec = fftk(delta_k) # Computes gravitational potential - pot_k = delta_k * invlaplace_kernel(kvec, fd=fd) * longrange_kernel( + pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel( kvec, r_split=r_split) # Computes gravitational forces forces = jnp.stack([ - read_fn(ifft3d(-gradient_kernel(kvec, i, fd=fd) * pot_k),positions + read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions ) for i in range(3)], axis=-1) # yapf: disable return forces @@ -82,7 +81,6 @@ def lpt(cosmo, delta=delta_k, paint_absolute_pos=paint_absolute_pos, halo_size=halo_size, - fd=False, sharding=sharding) dx = growth_factor(cosmo, a) * initial_force p = a**2 * growth_rate(cosmo, a) * E * dx @@ -97,7 +95,7 @@ def lpt(cosmo, for i in range(3): # Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)... # shear_ii = jnp.fft.irfftn(- ki**2 * pot_k) - nabla_i_nabla_i = gradient_kernel(kvec, i, fd=False)**2 + nabla_i_nabla_i = gradient_kernel(kvec, i)**2 shear_ii = ifft3d(nabla_i_nabla_i * pot_k) delta2 += shear_ii * shear_acc shear_acc += shear_ii @@ -106,8 +104,8 @@ def lpt(cosmo, for j in range(i + 1, 3): # Substract squared strict-up-triangle terms # delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2 - nabla_i_nabla_j = gradient_kernel(kvec, i, fd=False) * gradient_kernel( - kvec, j, fd=False) + nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel( + kvec, j) delta2 -= ifft3d(nabla_i_nabla_j * pot_k)**2 delta_k2 = fft3d(delta2) @@ -115,7 +113,6 @@ def lpt(cosmo, delta=delta_k2, paint_absolute_pos=paint_absolute_pos, halo_size=halo_size, - fd=False, sharding=sharding) # NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second dx2 = 3 / 7 * growth_factor_second(cosmo, a) * init_force2