mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
adding example of distributed solution
This commit is contained in:
parent
a2811c0606
commit
a742065ffd
5 changed files with 192 additions and 62 deletions
|
@ -1,24 +1,33 @@
|
|||
from jaxpm.distributed import autoshmap
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from functools import partial
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
|
||||
def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
|
||||
""" Return k_vector given a shape (nc, nc, nc) and box_size
|
||||
def fftk(shape, dtype=np.float32):
|
||||
"""
|
||||
k = []
|
||||
for d in range(len(shape)):
|
||||
kd = np.fft.fftfreq(shape[d])
|
||||
kd *= 2 * np.pi
|
||||
kdshape = np.ones(len(shape), dtype='int')
|
||||
if symmetric and d == len(shape) - 1:
|
||||
kd = kd[:shape[d] // 2 + 1]
|
||||
kdshape[d] = len(kd)
|
||||
kd = kd.reshape(kdshape)
|
||||
Generate Fourier transform wave numbers for a given mesh.
|
||||
|
||||
k.append(kd.astype(dtype))
|
||||
del kd, kdshape
|
||||
return k
|
||||
Args:
|
||||
nc (int): Shape of the mesh grid.
|
||||
|
||||
Returns:
|
||||
list: List of wave number arrays for each dimension in
|
||||
the order [kx, ky, kz].
|
||||
"""
|
||||
kx, ky, kz = [jnp.fft.fftfreq(s, dtype=dtype) * 2 * np.pi for s in shape]
|
||||
@partial(
|
||||
autoshmap,
|
||||
in_specs=(P('x'), P('y'), P(None)),
|
||||
out_specs=(P('x'), P(None, 'y'), P(None)))
|
||||
def get_kvec(ky, kz, kx):
|
||||
return (ky.reshape([-1, 1, 1]),
|
||||
kz.reshape([1, -1, 1]),
|
||||
kx.reshape([1, 1, -1])) # yapf: disable
|
||||
ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds
|
||||
# to the order of dimensions in the transposed FFT
|
||||
return kx, ky, kz
|
||||
|
||||
def gradient_kernel(kvec, direction, order=1):
|
||||
"""
|
||||
|
@ -60,11 +69,7 @@ def laplace_kernel(kvec):
|
|||
Complex kernel
|
||||
"""
|
||||
kk = sum(ki**2 for ki in kvec)
|
||||
mask = (kk == 0).nonzero()
|
||||
kk[mask] = 1
|
||||
wts = 1. / kk
|
||||
imask = (~(kk == 0)).astype(int)
|
||||
wts *= imask
|
||||
wts = jnp.where(kk == 0, 1., 1. / kk)
|
||||
return wts
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue