mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 10:00:54 +00:00
kernels now accept pytrees
This commit is contained in:
parent
f5755b4b5d
commit
204a9526ec
1 changed files with 13 additions and 13 deletions
|
@ -3,7 +3,7 @@ import numpy as np
|
||||||
from jax.lax import FftType
|
from jax.lax import FftType
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
from jaxdecomp import fftfreq3d, get_output_specs
|
from jaxdecomp import fftfreq3d, get_output_specs
|
||||||
|
import jax
|
||||||
from jaxpm.distributed import autoshmap
|
from jaxpm.distributed import autoshmap
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,13 +19,13 @@ def fftk(k_array):
|
||||||
the order [kx, ky, kz].
|
the order [kx, ky, kz].
|
||||||
"""
|
"""
|
||||||
kx, ky, kz = fftfreq3d(k_array)
|
kx, ky, kz = fftfreq3d(k_array)
|
||||||
# to the order of dimensions in the transposed FFT
|
|
||||||
return kx, ky, kz
|
return kx, ky, kz
|
||||||
|
|
||||||
|
|
||||||
def interpolate_power_spectrum(input, k, pk, sharding=None):
|
def interpolate_power_spectrum(input, k, pk, sharding=None):
|
||||||
|
|
||||||
pk_fn = lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape)
|
def pk_fn(input):
|
||||||
|
return jax.tree.map(lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape), input)
|
||||||
|
|
||||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
specs = sharding.spec if sharding is not None else P()
|
specs = sharding.spec if sharding is not None else P()
|
||||||
|
@ -55,13 +55,13 @@ def gradient_kernel(kvec, direction, order=1):
|
||||||
"""
|
"""
|
||||||
if order == 0:
|
if order == 0:
|
||||||
wts = 1j * kvec[direction]
|
wts = 1j * kvec[direction]
|
||||||
wts = jnp.squeeze(wts)
|
wts = jax.tree.map(jnp.squeeze, wts)
|
||||||
wts[len(wts) // 2] = 0
|
wts[len(wts) // 2] = 0
|
||||||
wts = wts.reshape(kvec[direction].shape)
|
wts = wts.reshape(kvec[direction].shape)
|
||||||
return wts
|
return wts
|
||||||
else:
|
else:
|
||||||
w = kvec[direction]
|
w = kvec[direction]
|
||||||
a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
|
a = jax.tree.map(lambda x: 1 / 6.0 * (8 * jnp.sin(x) - jnp.sin(2 * x)), w)
|
||||||
wts = a * 1j
|
wts = a * 1j
|
||||||
return wts
|
return wts
|
||||||
|
|
||||||
|
@ -85,11 +85,11 @@ def invlaplace_kernel(kvec, fd=False):
|
||||||
Complex kernel values
|
Complex kernel values
|
||||||
"""
|
"""
|
||||||
if fd:
|
if fd:
|
||||||
kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
|
kk = sum(jax.tree.map(lambda x: (x * jnp.sinc(x / (2 * jnp.pi)))**2, ki) for ki in kvec)
|
||||||
else:
|
else:
|
||||||
kk = sum(ki**2 for ki in kvec)
|
kk = sum(jax.tree.map(lambda x: x**2, ki) for ki in kvec)
|
||||||
kk_nozeros = jnp.where(kk == 0, 1, kk)
|
kk_nozeros = jax.tree.map(lambda x: jnp.where(x == 0, 1, x), kk)
|
||||||
return -jnp.where(kk == 0, 0, 1 / kk_nozeros)
|
return jax.tree.map(lambda x , y : -jnp.where(y == 0, 0, 1 / x), kk_nozeros, kk)
|
||||||
|
|
||||||
|
|
||||||
def longrange_kernel(kvec, r_split):
|
def longrange_kernel(kvec, r_split):
|
||||||
|
@ -110,7 +110,7 @@ def longrange_kernel(kvec, r_split):
|
||||||
"""
|
"""
|
||||||
if r_split != 0:
|
if r_split != 0:
|
||||||
kk = sum(ki**2 for ki in kvec)
|
kk = sum(ki**2 for ki in kvec)
|
||||||
return np.exp(-kk * r_split**2)
|
return jax.tree.map(lambda x: np.exp(-x * r_split**2), kk)
|
||||||
else:
|
else:
|
||||||
return 1.
|
return 1.
|
||||||
|
|
||||||
|
@ -131,7 +131,7 @@ def cic_compensation(kvec):
|
||||||
wts: array
|
wts: array
|
||||||
Complex kernel values
|
Complex kernel values
|
||||||
"""
|
"""
|
||||||
kwts = [jnp.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
|
kwts = [jax.tree.map(lambda x: jnp.sinc(x / (2 * np.pi)), kvec[i]) for i in range(3)]
|
||||||
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
|
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
|
||||||
return wts
|
return wts
|
||||||
|
|
||||||
|
@ -159,7 +159,7 @@ def PGD_kernel(kvec, kl, ks):
|
||||||
ks4 = ks**4
|
ks4 = ks**4
|
||||||
mask = (kk == 0).nonzero()
|
mask = (kk == 0).nonzero()
|
||||||
kk[mask] = 1
|
kk[mask] = 1
|
||||||
v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4)
|
v = jax.tree.map(lambda x: jnp.exp(-kl2 / x) * jnp.exp(-x**2 / ks4), kk)
|
||||||
imask = (~(kk == 0)).astype(int)
|
imask = jax.tree.map(lambda x: (~(x == 0)).astype(int), kk)
|
||||||
v *= imask
|
v *= imask
|
||||||
return v
|
return v
|
||||||
|
|
Loading…
Add table
Reference in a new issue