mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
update compensate CIC
This commit is contained in:
parent
cf799b6520
commit
0bb992fc56
2 changed files with 5 additions and 6 deletions
|
@ -126,7 +126,7 @@ def cic_compensation(kvec):
|
||||||
wts: array
|
wts: array
|
||||||
Complex kernel values
|
Complex kernel values
|
||||||
"""
|
"""
|
||||||
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
|
kwts = [jnp.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
|
||||||
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
|
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
|
||||||
return wts
|
return wts
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ import jax.numpy as jnp
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
|
|
||||||
from jaxpm.distributed import (autoshmap, get_halo_size, halo_exchange,
|
from jaxpm.distributed import (autoshmap, get_halo_size, halo_exchange,
|
||||||
slice_pad, slice_unpad)
|
slice_pad, slice_unpad, fft3d, ifft3d)
|
||||||
from jaxpm.kernels import cic_compensation, fftk
|
from jaxpm.kernels import cic_compensation, fftk
|
||||||
from jaxpm.painting_utils import gather, scatter
|
from jaxpm.painting_utils import gather, scatter
|
||||||
|
|
||||||
|
@ -230,9 +230,8 @@ def compensate_cic(field):
|
||||||
Returns:
|
Returns:
|
||||||
compensated_field
|
compensated_field
|
||||||
"""
|
"""
|
||||||
nc = field.shape
|
delta_k = fft3d(field)
|
||||||
kvec = fftk(nc)
|
|
||||||
|
|
||||||
delta_k = jnp.fft.rfftn(field)
|
kvec = fftk(delta_k)
|
||||||
delta_k = cic_compensation(kvec) * delta_k
|
delta_k = cic_compensation(kvec) * delta_k
|
||||||
return jnp.fft.irfftn(delta_k)
|
return ifft3d(delta_k)
|
||||||
|
|
Loading…
Add table
Reference in a new issue