update compensate CIC

This commit is contained in:
Wassim KABALAN 2024-10-22 17:30:20 -04:00
parent cf799b6520
commit 0bb992fc56
2 changed files with 5 additions and 6 deletions

View file

@ -126,7 +126,7 @@ def cic_compensation(kvec):
wts: array
Complex kernel values
"""
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
kwts = [jnp.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
return wts

View file

@ -6,7 +6,7 @@ import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jaxpm.distributed import (autoshmap, get_halo_size, halo_exchange,
slice_pad, slice_unpad)
slice_pad, slice_unpad, fft3d, ifft3d)
from jaxpm.kernels import cic_compensation, fftk
from jaxpm.painting_utils import gather, scatter
@ -230,9 +230,8 @@ def compensate_cic(field):
Returns:
compensated_field
"""
nc = field.shape
kvec = fftk(nc)
delta_k = fft3d(field)
delta_k = jnp.fft.rfftn(field)
kvec = fftk(delta_k)
delta_k = cic_compensation(kvec) * delta_k
return jnp.fft.irfftn(delta_k)
return ifft3d(delta_k)