From 0bb992fc56ffaff88fdd7aab6e0eb521d5a5c616 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Tue, 22 Oct 2024 17:30:20 -0400 Subject: [PATCH] update compensate CIC --- jaxpm/kernels.py | 2 +- jaxpm/painting.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 9a12773..d333b8c 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -126,7 +126,7 @@ def cic_compensation(kvec): wts: array Complex kernel values """ - kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)] + kwts = [jnp.sinc(kvec[i] / (2 * np.pi)) for i in range(3)] wts = (kwts[0] * kwts[1] * kwts[2])**(-2) return wts diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 643e85f..e1f52ed 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -6,7 +6,7 @@ import jax.numpy as jnp from jax.sharding import PartitionSpec as P from jaxpm.distributed import (autoshmap, get_halo_size, halo_exchange, - slice_pad, slice_unpad) + slice_pad, slice_unpad, fft3d, ifft3d) from jaxpm.kernels import cic_compensation, fftk from jaxpm.painting_utils import gather, scatter @@ -230,9 +230,8 @@ def compensate_cic(field): Returns: compensated_field """ - nc = field.shape - kvec = fftk(nc) + delta_k = fft3d(field) - delta_k = jnp.fft.rfftn(field) + kvec = fftk(delta_k) delta_k = cic_compensation(kvec) * delta_k - return jnp.fft.irfftn(delta_k) + return ifft3d(delta_k)