From 350733966331b4df988e873c43132871e995201e Mon Sep 17 00:00:00 2001 From: EiffL Date: Sun, 13 Feb 2022 21:36:03 +0100 Subject: [PATCH] Adds a trivial jaxpm implementation --- jaxpm/__init__.py | 0 jaxpm/kernels.py | 85 +++++++++++++++++++++++++++++++++++++++++++++++ jaxpm/painting.py | 51 ++++++++++++++++++++++++++++ setup.py | 11 ++++++ 4 files changed, 147 insertions(+) create mode 100644 jaxpm/__init__.py create mode 100644 jaxpm/kernels.py create mode 100644 jaxpm/painting.py create mode 100644 setup.py diff --git a/jaxpm/__init__.py b/jaxpm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py new file mode 100644 index 0000000..73a8c93 --- /dev/null +++ b/jaxpm/kernels.py @@ -0,0 +1,85 @@ +import numpy as np +import jax.numpy as jnp + +def fftk(shape, symmetric=True, finite=False, dtype=np.float32): + """ Return k_vector given a shape (nc, nc, nc) and box_size + """ + k = [] + for d in range(len(shape)): + kd = np.fft.fftfreq(shape[d]) + kd *= 2 * np.pi + kdshape = np.ones(len(shape), dtype='int') + if symmetric and d == len(shape) - 1: + kd = kd[:shape[d] // 2 + 1] + kdshape[d] = len(kd) + kd = kd.reshape(kdshape) + + k.append(kd.astype(dtype)) + del kd, kdshape + return k + +def gradient_kernel(kvec, direction, order=1): + """ + Computes the gradient kernel in the requested direction + Parameters: + ----------- + kvec: array + Array of k values in Fourier space + direction: int + Index of the direction in which to take the gradient + Returns: + -------- + wts: array + Complex kernel + """ + if order == 0: + wts = 1j * kvec[direction] + wts = jnp.squeeze(wts) + wts[len(wts) // 2] = 0 + wts = wts.reshape(kvec[direction].shape) + return wts + else: + w = kvec[direction] + a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w)) + wts = a * 1j + return wts + +def laplace_kernel(kvec): + """ + Compute the Laplace kernel from a given K vector + Parameters: + ----------- + kvec: array + Array of k values in Fourier space + Returns: + -------- + wts: array + Complex kernel + """ + kk = sum(ki**2 for ki in kvec) + mask = (kk == 0).nonzero() + kk[mask] = 1 + wts = 1. / kk + imask = (~(kk == 0)).astype(int) + wts *= imask + return wts + +def longrange_kernel(kvec, r_split): + """ + Computes a long range kernel + Parameters: + ----------- + kvec: array + Array of k values in Fourier space + r_split: float + TODO: @modichirag add documentation + Returns: + -------- + wts: array + kernel + """ + if r_split != 0: + kk = sum(ki**2 for ki in kvec) + return np.exp(-kk * r_split**2) + else: + return 1. diff --git a/jaxpm/painting.py b/jaxpm/painting.py new file mode 100644 index 0000000..6eb1925 --- /dev/null +++ b/jaxpm/painting.py @@ -0,0 +1,51 @@ +import jax +import jax.numpy as jnp +import jax.lax as lax + +def cic_paint(mesh, positions): + """ Paints positions onto mesh + mesh: [nx, ny, nz] + positions: [npart, 3] + """ + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) + connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], + [0., 0, 1], [1., 1, 0], [1., 0, 1], + [0., 1, 1], [1., 1, 1]]]) + + neighboor_coords = floor + connection + kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] + + neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape)) + + dnums = jax.lax.ScatterDimensionNumbers( + update_window_dims=(), + inserted_window_dims=(0, 1, 2), + scatter_dims_to_operand_dims=(0, 1, 2)) + mesh = lax.scatter_add(mesh, + neighboor_coords, + kernel.reshape([-1,8]), + dnums) + return mesh + +def cic_read(mesh, positions): + """ Paints positions onto mesh + mesh: [nx, ny, nz] + positions: [npart, 3] + """ + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) + connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], + [0., 0, 1], [1., 1, 0], [1., 0, 1], + [0., 1, 1], [1., 1, 1]]]) + + neighboor_coords = floor + connection + kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] + + neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), jnp.array(mesh.shape)) + + return (mesh[neighboor_coords[...,0], + neighboor_coords[...,1], + neighboor_coords[...,3]]*kernel).sum(axis=-1) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..44be5a1 --- /dev/null +++ b/setup.py @@ -0,0 +1,11 @@ +from setuptools import setup, find_packages + +setup( + name='JaxPM', + version='0.0.1', + url='https://github.com/DifferentiableUniverseInitiative/JaxPM', + author='JaxPM developers', + description='A dead simple FastPM implementation in JAX', + packages=find_packages(), + install_requires=['jax', 'jax_cosmo'], +) \ No newline at end of file