diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 912fe2f..ef4b5c5 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -3,7 +3,7 @@ import numpy as np from jax.lax import FftType from jax.sharding import PartitionSpec as P from jaxdecomp import fftfreq3d, get_output_specs - +import jax from jaxpm.distributed import autoshmap @@ -19,13 +19,13 @@ def fftk(k_array): the order [kx, ky, kz]. """ kx, ky, kz = fftfreq3d(k_array) - # to the order of dimensions in the transposed FFT return kx, ky, kz def interpolate_power_spectrum(input, k, pk, sharding=None): - pk_fn = lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape) + def pk_fn(input): + return jax.tree.map(lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape), input) gpu_mesh = sharding.mesh if sharding is not None else None specs = sharding.spec if sharding is not None else P() @@ -55,13 +55,13 @@ def gradient_kernel(kvec, direction, order=1): """ if order == 0: wts = 1j * kvec[direction] - wts = jnp.squeeze(wts) + wts = jax.tree.map(jnp.squeeze, wts) wts[len(wts) // 2] = 0 wts = wts.reshape(kvec[direction].shape) return wts else: w = kvec[direction] - a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w)) + a = jax.tree.map(lambda x: 1 / 6.0 * (8 * jnp.sin(x) - jnp.sin(2 * x)), w) wts = a * 1j return wts @@ -85,11 +85,11 @@ def invlaplace_kernel(kvec, fd=False): Complex kernel values """ if fd: - kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec) + kk = sum(jax.tree.map(lambda x: (x * jnp.sinc(x / (2 * jnp.pi)))**2, ki) for ki in kvec) else: - kk = sum(ki**2 for ki in kvec) - kk_nozeros = jnp.where(kk == 0, 1, kk) - return -jnp.where(kk == 0, 0, 1 / kk_nozeros) + kk = sum(jax.tree.map(lambda x: x**2, ki) for ki in kvec) + kk_nozeros = jax.tree.map(lambda x: jnp.where(x == 0, 1, x), kk) + return jax.tree.map(lambda x , y : -jnp.where(y == 0, 0, 1 / x), kk_nozeros, kk) def longrange_kernel(kvec, r_split): @@ -110,7 +110,7 @@ def longrange_kernel(kvec, r_split): """ if r_split != 0: kk = sum(ki**2 for ki in kvec) - return np.exp(-kk * r_split**2) + return jax.tree.map(lambda x: np.exp(-x * r_split**2), kk) else: return 1. @@ -131,7 +131,7 @@ def cic_compensation(kvec): wts: array Complex kernel values """ - kwts = [jnp.sinc(kvec[i] / (2 * np.pi)) for i in range(3)] + kwts = [jax.tree.map(lambda x: jnp.sinc(x / (2 * np.pi)), kvec[i]) for i in range(3)] wts = (kwts[0] * kwts[1] * kwts[2])**(-2) return wts @@ -159,7 +159,7 @@ def PGD_kernel(kvec, kl, ks): ks4 = ks**4 mask = (kk == 0).nonzero() kk[mask] = 1 - v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4) - imask = (~(kk == 0)).astype(int) + v = jax.tree.map(lambda x: jnp.exp(-kl2 / x) * jnp.exp(-x**2 / ks4), kk) + imask = jax.tree.map(lambda x: (~(x == 0)).astype(int), kk) v *= imask return v