mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 12:20:54 +00:00
update compensate CIC
This commit is contained in:
parent
cf799b6520
commit
0bb992fc56
2 changed files with 5 additions and 6 deletions
|
@ -126,7 +126,7 @@ def cic_compensation(kvec):
|
|||
wts: array
|
||||
Complex kernel values
|
||||
"""
|
||||
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
|
||||
kwts = [jnp.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
|
||||
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
|
||||
return wts
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ import jax.numpy as jnp
|
|||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jaxpm.distributed import (autoshmap, get_halo_size, halo_exchange,
|
||||
slice_pad, slice_unpad)
|
||||
slice_pad, slice_unpad, fft3d, ifft3d)
|
||||
from jaxpm.kernels import cic_compensation, fftk
|
||||
from jaxpm.painting_utils import gather, scatter
|
||||
|
||||
|
@ -230,9 +230,8 @@ def compensate_cic(field):
|
|||
Returns:
|
||||
compensated_field
|
||||
"""
|
||||
nc = field.shape
|
||||
kvec = fftk(nc)
|
||||
delta_k = fft3d(field)
|
||||
|
||||
delta_k = jnp.fft.rfftn(field)
|
||||
kvec = fftk(delta_k)
|
||||
delta_k = cic_compensation(kvec) * delta_k
|
||||
return jnp.fft.irfftn(delta_k)
|
||||
return ifft3d(delta_k)
|
||||
|
|
Loading…
Add table
Reference in a new issue