update formatting

This commit is contained in:
EiffL 2024-07-09 18:02:57 -04:00
parent 6408aff1de
commit 319942a6bc
5 changed files with 113 additions and 96 deletions

View file

@ -1,12 +1,14 @@
from jaxpm.distributed import autoshmap
from jax.sharding import PartitionSpec as P
from functools import partial
import jax.numpy as jnp
import numpy as np
from jax.sharding import PartitionSpec as P
from jaxpm.distributed import autoshmap
def fftk(shape, dtype=np.float32):
"""
"""
Generate Fourier transform wave numbers for a given mesh.
Args:
@ -16,18 +18,19 @@ def fftk(shape, dtype=np.float32):
list: List of wave number arrays for each dimension in
the order [kx, ky, kz].
"""
kx, ky, kz = [jnp.fft.fftfreq(s, dtype=dtype) * 2 * np.pi for s in shape]
@partial(
autoshmap,
in_specs=(P('x'), P('y'), P(None)),
out_specs=(P('x'), P(None, 'y'), P(None)))
def get_kvec(ky, kz, kx):
return (ky.reshape([-1, 1, 1]),
kz.reshape([1, -1, 1]),
kx.reshape([1, 1, -1])) # yapf: disable
ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds
# to the order of dimensions in the transposed FFT
return kx, ky, kz
kx, ky, kz = [jnp.fft.fftfreq(s, dtype=dtype) * 2 * np.pi for s in shape]
@partial(autoshmap,
in_specs=(P('x'), P('y'), P(None)),
out_specs=(P('x'), P(None, 'y'), P(None)))
def get_kvec(ky, kz, kx):
return (ky.reshape([-1, 1, 1]),
kz.reshape([1, -1, 1]),
kx.reshape([1, 1, -1])) # yapf: disable
ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds
# to the order of dimensions in the transposed FFT
return kx, ky, kz
def gradient_kernel(kvec, direction, order=1):
"""