mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
Adding begnning of implem
This commit is contained in:
parent
3c1abbafcd
commit
1948eae9ed
3 changed files with 68 additions and 87 deletions
100
jaxpm/kernels.py
100
jaxpm/kernels.py
|
@ -1,88 +1,46 @@
|
|||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
|
||||
def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
|
||||
""" Return k_vector given a shape (nc, nc, nc) and box_size
|
||||
def fftk(shape, symmetric=True, dtype=np.float32, comms=None):
|
||||
""" Return k_vector given a shape (nc, nc, nc)
|
||||
"""
|
||||
k = []
|
||||
|
||||
if comms is not None:
|
||||
nx = comms[0].Get_size()
|
||||
ix = comms[0].Get_rank()
|
||||
ny = comms[1].Get_size()
|
||||
iy = comms[1].Get_rank()
|
||||
shape = [shape[0]*nx, shape[1]*ny] + list(shape[2:])
|
||||
|
||||
for d in range(len(shape)):
|
||||
kd = np.fft.fftfreq(shape[d])
|
||||
kd *= 2 * np.pi
|
||||
kdshape = np.ones(len(shape), dtype='int')
|
||||
|
||||
if symmetric and d == len(shape) - 1:
|
||||
kd = kd[:shape[d] // 2 + 1]
|
||||
kdshape[d] = len(kd)
|
||||
kd = kd.reshape(kdshape)
|
||||
|
||||
if (comms is not None) and d==0:
|
||||
kd = kd.reshape([nx, -1])[ix]
|
||||
|
||||
if (comms is not None) and d==1:
|
||||
kd = kd.reshape([ny, -1])[iy]
|
||||
|
||||
k.append(kd.astype(dtype))
|
||||
del kd, kdshape
|
||||
return k
|
||||
|
||||
def gradient_kernel(kvec, direction, order=1):
|
||||
"""
|
||||
Computes the gradient kernel in the requested direction
|
||||
Parameters:
|
||||
-----------
|
||||
kvec: array
|
||||
Array of k values in Fourier space
|
||||
direction: int
|
||||
Index of the direction in which to take the gradient
|
||||
Returns:
|
||||
--------
|
||||
wts: array
|
||||
Complex kernel
|
||||
"""
|
||||
if order == 0:
|
||||
wts = 1j * kvec[direction]
|
||||
wts = jnp.squeeze(wts)
|
||||
wts[len(wts) // 2] = 0
|
||||
wts = wts.reshape(kvec[direction].shape)
|
||||
return wts
|
||||
else:
|
||||
w = kvec[direction]
|
||||
a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
|
||||
wts = a * 1j
|
||||
return wts
|
||||
|
||||
def laplace_kernel(kvec):
|
||||
"""
|
||||
Compute the Laplace kernel from a given K vector
|
||||
Parameters:
|
||||
-----------
|
||||
kvec: array
|
||||
Array of k values in Fourier space
|
||||
Returns:
|
||||
--------
|
||||
wts: array
|
||||
Complex kernel
|
||||
"""
|
||||
kk = sum(ki**2 for ki in kvec)
|
||||
mask = (kk == 0).nonzero()
|
||||
kk[mask] = 1
|
||||
wts = 1. / kk
|
||||
imask = (~(kk == 0)).astype(int)
|
||||
wts *= imask
|
||||
return wts
|
||||
|
||||
def longrange_kernel(kvec, r_split):
|
||||
"""
|
||||
Computes a long range kernel
|
||||
Parameters:
|
||||
-----------
|
||||
kvec: array
|
||||
Array of k values in Fourier space
|
||||
r_split: float
|
||||
TODO: @modichirag add documentation
|
||||
Returns:
|
||||
--------
|
||||
wts: array
|
||||
kernel
|
||||
"""
|
||||
if r_split != 0:
|
||||
kk = sum(ki**2 for ki in kvec)
|
||||
return np.exp(-kk * r_split**2)
|
||||
else:
|
||||
return 1.
|
||||
@partial(jax.pmap,
|
||||
in_axes=[['x','y','z'],
|
||||
['x'],['y'],['z']],
|
||||
out_axes=['x','y','z',...])
|
||||
def apply_gradient_laplace(kfield, kvec):
|
||||
kx, ky, kz = kvec
|
||||
kk = (kx**2 + ky**2 + kz**2)
|
||||
kernel = jnp.where(kk == 0, 1., 1./kk)
|
||||
return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
|
||||
kfield * kernel * 1j * 1 / 6.0 *
|
||||
(8 * jnp.sin(kz) - jnp.sin(2 * kz)),
|
||||
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))],axis=-1)
|
||||
|
||||
def cic_compensation(kvec):
|
||||
"""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue