mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-15 04:21:12 +00:00
update compensate CIC
This commit is contained in:
parent
cf799b6520
commit
0bb992fc56
2 changed files with 5 additions and 6 deletions
|
@ -6,7 +6,7 @@ import jax.numpy as jnp
|
|||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jaxpm.distributed import (autoshmap, get_halo_size, halo_exchange,
|
||||
slice_pad, slice_unpad)
|
||||
slice_pad, slice_unpad, fft3d, ifft3d)
|
||||
from jaxpm.kernels import cic_compensation, fftk
|
||||
from jaxpm.painting_utils import gather, scatter
|
||||
|
||||
|
@ -230,9 +230,8 @@ def compensate_cic(field):
|
|||
Returns:
|
||||
compensated_field
|
||||
"""
|
||||
nc = field.shape
|
||||
kvec = fftk(nc)
|
||||
delta_k = fft3d(field)
|
||||
|
||||
delta_k = jnp.fft.rfftn(field)
|
||||
kvec = fftk(delta_k)
|
||||
delta_k = cic_compensation(kvec) * delta_k
|
||||
return jnp.fft.irfftn(delta_k)
|
||||
return ifft3d(delta_k)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue