kernels now accept pytrees

This commit is contained in:
Wassim Kabalan 2025-01-18 01:13:39 +01:00
parent f5755b4b5d
commit 204a9526ec

View file

@ -3,7 +3,7 @@ import numpy as np
from jax.lax import FftType
from jax.sharding import PartitionSpec as P
from jaxdecomp import fftfreq3d, get_output_specs
import jax
from jaxpm.distributed import autoshmap
@ -19,13 +19,13 @@ def fftk(k_array):
the order [kx, ky, kz].
"""
kx, ky, kz = fftfreq3d(k_array)
# to the order of dimensions in the transposed FFT
return kx, ky, kz
def interpolate_power_spectrum(input, k, pk, sharding=None):
pk_fn = lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape)
def pk_fn(input):
return jax.tree.map(lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape), input)
gpu_mesh = sharding.mesh if sharding is not None else None
specs = sharding.spec if sharding is not None else P()
@ -55,13 +55,13 @@ def gradient_kernel(kvec, direction, order=1):
"""
if order == 0:
wts = 1j * kvec[direction]
wts = jnp.squeeze(wts)
wts = jax.tree.map(jnp.squeeze, wts)
wts[len(wts) // 2] = 0
wts = wts.reshape(kvec[direction].shape)
return wts
else:
w = kvec[direction]
a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
a = jax.tree.map(lambda x: 1 / 6.0 * (8 * jnp.sin(x) - jnp.sin(2 * x)), w)
wts = a * 1j
return wts
@ -85,11 +85,11 @@ def invlaplace_kernel(kvec, fd=False):
Complex kernel values
"""
if fd:
kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
kk = sum(jax.tree.map(lambda x: (x * jnp.sinc(x / (2 * jnp.pi)))**2, ki) for ki in kvec)
else:
kk = sum(ki**2 for ki in kvec)
kk_nozeros = jnp.where(kk == 0, 1, kk)
return -jnp.where(kk == 0, 0, 1 / kk_nozeros)
kk = sum(jax.tree.map(lambda x: x**2, ki) for ki in kvec)
kk_nozeros = jax.tree.map(lambda x: jnp.where(x == 0, 1, x), kk)
return jax.tree.map(lambda x , y : -jnp.where(y == 0, 0, 1 / x), kk_nozeros, kk)
def longrange_kernel(kvec, r_split):
@ -110,7 +110,7 @@ def longrange_kernel(kvec, r_split):
"""
if r_split != 0:
kk = sum(ki**2 for ki in kvec)
return np.exp(-kk * r_split**2)
return jax.tree.map(lambda x: np.exp(-x * r_split**2), kk)
else:
return 1.
@ -131,7 +131,7 @@ def cic_compensation(kvec):
wts: array
Complex kernel values
"""
kwts = [jnp.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
kwts = [jax.tree.map(lambda x: jnp.sinc(x / (2 * np.pi)), kvec[i]) for i in range(3)]
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
return wts
@ -159,7 +159,7 @@ def PGD_kernel(kvec, kl, ks):
ks4 = ks**4
mask = (kk == 0).nonzero()
kk[mask] = 1
v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4)
imask = (~(kk == 0)).astype(int)
v = jax.tree.map(lambda x: jnp.exp(-kl2 / x) * jnp.exp(-x**2 / ks4), kk)
imask = jax.tree.map(lambda x: (~(x == 0)).astype(int), kk)
v *= imask
return v