diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 73a8c93..c3306c6 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -83,3 +83,18 @@ def longrange_kernel(kvec, r_split): return np.exp(-kk * r_split**2) else: return 1. + +def cic_compensation(kvec): + """ + Computes cic compensation kernel. + Adapted from https://github.com/bccp/nbodykit/blob/a387cf429d8cb4a07bb19e3b4325ffdf279a131e/nbodykit/source/mesh/catalog.py#L499 + Itself based on equation 18 (with p=2) of + `Jing et al 2005 `_ + Args: + kvec: array of k values in Fourier space + Returns: + v: array of kernel + """ + kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)] + wts = (kwts[0] * kwts[1] * kwts[2])**(-2) + return wts diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 6eb1925..27a1900 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -2,6 +2,8 @@ import jax import jax.numpy as jnp import jax.lax as lax +from jaxpm.kernels import fftk, cic_compensation + def cic_paint(mesh, positions): """ Paints positions onto mesh mesh: [nx, ny, nz] @@ -49,3 +51,18 @@ def cic_read(mesh, positions): return (mesh[neighboor_coords[...,0], neighboor_coords[...,1], neighboor_coords[...,3]]*kernel).sum(axis=-1) + +def compensate_cic(field): + """ + Compensate for CiC painting + Args: + field: input 3D cic-painted field + Returns: + compensated_field + """ + nc = field.shape + kvec = fftk(nc) + + delta_k = jnp.fft.rfftn(field) + delta_k = cic_compensation(kvec) * delta_k + return jnp.fft.irfftn(delta_k) \ No newline at end of file