update compensate CIC

This commit is contained in:
Wassim KABALAN 2024-10-22 17:30:20 -04:00
parent cf799b6520
commit 0bb992fc56
2 changed files with 5 additions and 6 deletions

View file

@ -6,7 +6,7 @@ import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jaxpm.distributed import (autoshmap, get_halo_size, halo_exchange,
slice_pad, slice_unpad)
slice_pad, slice_unpad, fft3d, ifft3d)
from jaxpm.kernels import cic_compensation, fftk
from jaxpm.painting_utils import gather, scatter
@ -230,9 +230,8 @@ def compensate_cic(field):
Returns:
compensated_field
"""
nc = field.shape
kvec = fftk(nc)
delta_k = fft3d(field)
delta_k = jnp.fft.rfftn(field)
kvec = fftk(delta_k)
delta_k = cic_compensation(kvec) * delta_k
return jnp.fft.irfftn(delta_k)
return ifft3d(delta_k)