adding example of distributed solution

This commit is contained in:
EiffL 2024-07-09 17:45:28 -04:00
parent a2811c0606
commit a742065ffd
5 changed files with 192 additions and 62 deletions

View file

@ -1,24 +1,33 @@
from jaxpm.distributed import autoshmap
from jax.sharding import PartitionSpec as P
from functools import partial
import jax.numpy as jnp
import numpy as np
def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
""" Return k_vector given a shape (nc, nc, nc) and box_size
def fftk(shape, dtype=np.float32):
"""
k = []
for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d])
kd *= 2 * np.pi
kdshape = np.ones(len(shape), dtype='int')
if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1]
kdshape[d] = len(kd)
kd = kd.reshape(kdshape)
Generate Fourier transform wave numbers for a given mesh.
k.append(kd.astype(dtype))
del kd, kdshape
return k
Args:
nc (int): Shape of the mesh grid.
Returns:
list: List of wave number arrays for each dimension in
the order [kx, ky, kz].
"""
kx, ky, kz = [jnp.fft.fftfreq(s, dtype=dtype) * 2 * np.pi for s in shape]
@partial(
autoshmap,
in_specs=(P('x'), P('y'), P(None)),
out_specs=(P('x'), P(None, 'y'), P(None)))
def get_kvec(ky, kz, kx):
return (ky.reshape([-1, 1, 1]),
kz.reshape([1, -1, 1]),
kx.reshape([1, 1, -1])) # yapf: disable
ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds
# to the order of dimensions in the transposed FFT
return kx, ky, kz
def gradient_kernel(kvec, direction, order=1):
"""
@ -60,11 +69,7 @@ def laplace_kernel(kvec):
Complex kernel
"""
kk = sum(ki**2 for ki in kvec)
mask = (kk == 0).nonzero()
kk[mask] = 1
wts = 1. / kk
imask = (~(kk == 0)).astype(int)
wts *= imask
wts = jnp.where(kk == 0, 1., 1. / kk)
return wts