mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-22 17:47:11 +00:00
Adds a trivial jaxpm implementation
This commit is contained in:
parent
d4d0a03c79
commit
3507339663
4 changed files with 147 additions and 0 deletions
0
jaxpm/__init__.py
Normal file
0
jaxpm/__init__.py
Normal file
85
jaxpm/kernels.py
Normal file
85
jaxpm/kernels.py
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
import numpy as np
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
|
||||||
|
""" Return k_vector given a shape (nc, nc, nc) and box_size
|
||||||
|
"""
|
||||||
|
k = []
|
||||||
|
for d in range(len(shape)):
|
||||||
|
kd = np.fft.fftfreq(shape[d])
|
||||||
|
kd *= 2 * np.pi
|
||||||
|
kdshape = np.ones(len(shape), dtype='int')
|
||||||
|
if symmetric and d == len(shape) - 1:
|
||||||
|
kd = kd[:shape[d] // 2 + 1]
|
||||||
|
kdshape[d] = len(kd)
|
||||||
|
kd = kd.reshape(kdshape)
|
||||||
|
|
||||||
|
k.append(kd.astype(dtype))
|
||||||
|
del kd, kdshape
|
||||||
|
return k
|
||||||
|
|
||||||
|
def gradient_kernel(kvec, direction, order=1):
|
||||||
|
"""
|
||||||
|
Computes the gradient kernel in the requested direction
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
kvec: array
|
||||||
|
Array of k values in Fourier space
|
||||||
|
direction: int
|
||||||
|
Index of the direction in which to take the gradient
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
wts: array
|
||||||
|
Complex kernel
|
||||||
|
"""
|
||||||
|
if order == 0:
|
||||||
|
wts = 1j * kvec[direction]
|
||||||
|
wts = jnp.squeeze(wts)
|
||||||
|
wts[len(wts) // 2] = 0
|
||||||
|
wts = wts.reshape(kvec[direction].shape)
|
||||||
|
return wts
|
||||||
|
else:
|
||||||
|
w = kvec[direction]
|
||||||
|
a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
|
||||||
|
wts = a * 1j
|
||||||
|
return wts
|
||||||
|
|
||||||
|
def laplace_kernel(kvec):
|
||||||
|
"""
|
||||||
|
Compute the Laplace kernel from a given K vector
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
kvec: array
|
||||||
|
Array of k values in Fourier space
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
wts: array
|
||||||
|
Complex kernel
|
||||||
|
"""
|
||||||
|
kk = sum(ki**2 for ki in kvec)
|
||||||
|
mask = (kk == 0).nonzero()
|
||||||
|
kk[mask] = 1
|
||||||
|
wts = 1. / kk
|
||||||
|
imask = (~(kk == 0)).astype(int)
|
||||||
|
wts *= imask
|
||||||
|
return wts
|
||||||
|
|
||||||
|
def longrange_kernel(kvec, r_split):
|
||||||
|
"""
|
||||||
|
Computes a long range kernel
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
kvec: array
|
||||||
|
Array of k values in Fourier space
|
||||||
|
r_split: float
|
||||||
|
TODO: @modichirag add documentation
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
wts: array
|
||||||
|
kernel
|
||||||
|
"""
|
||||||
|
if r_split != 0:
|
||||||
|
kk = sum(ki**2 for ki in kvec)
|
||||||
|
return np.exp(-kk * r_split**2)
|
||||||
|
else:
|
||||||
|
return 1.
|
51
jaxpm/painting.py
Normal file
51
jaxpm/painting.py
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jax.lax as lax
|
||||||
|
|
||||||
|
def cic_paint(mesh, positions):
|
||||||
|
""" Paints positions onto mesh
|
||||||
|
mesh: [nx, ny, nz]
|
||||||
|
positions: [npart, 3]
|
||||||
|
"""
|
||||||
|
positions = jnp.expand_dims(positions, 1)
|
||||||
|
floor = jnp.floor(positions)
|
||||||
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
||||||
|
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
||||||
|
[0., 1, 1], [1., 1, 1]]])
|
||||||
|
|
||||||
|
neighboor_coords = floor + connection
|
||||||
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||||
|
|
||||||
|
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape))
|
||||||
|
|
||||||
|
dnums = jax.lax.ScatterDimensionNumbers(
|
||||||
|
update_window_dims=(),
|
||||||
|
inserted_window_dims=(0, 1, 2),
|
||||||
|
scatter_dims_to_operand_dims=(0, 1, 2))
|
||||||
|
mesh = lax.scatter_add(mesh,
|
||||||
|
neighboor_coords,
|
||||||
|
kernel.reshape([-1,8]),
|
||||||
|
dnums)
|
||||||
|
return mesh
|
||||||
|
|
||||||
|
def cic_read(mesh, positions):
|
||||||
|
""" Paints positions onto mesh
|
||||||
|
mesh: [nx, ny, nz]
|
||||||
|
positions: [npart, 3]
|
||||||
|
"""
|
||||||
|
positions = jnp.expand_dims(positions, 1)
|
||||||
|
floor = jnp.floor(positions)
|
||||||
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
||||||
|
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
||||||
|
[0., 1, 1], [1., 1, 1]]])
|
||||||
|
|
||||||
|
neighboor_coords = floor + connection
|
||||||
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||||
|
|
||||||
|
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), jnp.array(mesh.shape))
|
||||||
|
|
||||||
|
return (mesh[neighboor_coords[...,0],
|
||||||
|
neighboor_coords[...,1],
|
||||||
|
neighboor_coords[...,3]]*kernel).sum(axis=-1)
|
11
setup.py
Normal file
11
setup.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='JaxPM',
|
||||||
|
version='0.0.1',
|
||||||
|
url='https://github.com/DifferentiableUniverseInitiative/JaxPM',
|
||||||
|
author='JaxPM developers',
|
||||||
|
description='A dead simple FastPM implementation in JAX',
|
||||||
|
packages=find_packages(),
|
||||||
|
install_requires=['jax', 'jax_cosmo'],
|
||||||
|
)
|
Loading…
Add table
Reference in a new issue