From 8b9287184ada2f0ee55aed318ce8943b35947b0c Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Mon, 8 Jul 2024 00:23:53 +0200 Subject: [PATCH] Export operators --- jaxpm/ops.py | 34 ++++++++++++ jaxpm/painting.py | 129 +++++++++++++++++----------------------------- 2 files changed, 82 insertions(+), 81 deletions(-) create mode 100644 jaxpm/ops.py diff --git a/jaxpm/ops.py b/jaxpm/ops.py new file mode 100644 index 0000000..3054807 --- /dev/null +++ b/jaxpm/ops.py @@ -0,0 +1,34 @@ +import numpy as np +from _src.spmd_config import pm_operators + + +def fftn(arr): + return pm_operators.fftn(arr) + + +def ifftn(arr): + return pm_operators.ifftn(arr) + + +def halo_exchange(arr): + return pm_operators.halo_exchange(arr) + + +def slice_pad(arr, pad_width): + return pm_operators.slice_pad(arr, pad_width) + + +def slice_unpad(arr, pad_width): + return pm_operators.slice_unpad(arr, pad_width) + + +def normal(shape, key, dtype='float32'): + return pm_operators.normal(shape, key, dtype) + + +def fftk(shape, symmetric=True, finite=False, dtype=np.float32): + return pm_operators.fftk(shape, symmetric, finite, dtype) + + +def generate_initial_positions(shape): + return pm_operators.generate_initial_positions(shape) diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 67d54b0..108838c 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -1,98 +1,65 @@ import jax -import jax.numpy as jnp import jax.lax as lax +import jax.numpy as jnp -from jaxpm.kernels import fftk, cic_compensation +import jaxpm +import jaxpm.ops +from jaxpm.kernels import cic_compensation, fftk -def cic_paint(mesh, positions, weight=None): - """ Paints positions onto mesh - mesh: [nx, ny, nz] - positions: [npart, 3] - """ - positions = jnp.expand_dims(positions, 1) - floor = jnp.floor(positions) - connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], - [0., 0, 1], [1., 1, 0], [1., 0, 1], - [0., 1, 1], [1., 1, 1]]]) - neighboor_coords = floor + connection - kernel = 1. - jnp.abs(positions - neighboor_coords) - kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] - if weight is not None: - kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel) - - neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape)) +def cic_paint(particle_mesh, positions, halo_size=0): + return jaxpm.ops.cic_paint(particle_mesh, positions, halo_size=halo_size) - dnums = jax.lax.ScatterDimensionNumbers( - update_window_dims=(), - inserted_window_dims=(0, 1, 2), - scatter_dims_to_operand_dims=(0, 1, 2)) - mesh = lax.scatter_add(mesh, - neighboor_coords, - kernel.reshape([-1,8]), - dnums) - return mesh -def cic_read(mesh, positions): - """ Paints positions onto mesh - mesh: [nx, ny, nz] - positions: [npart, 3] - """ - positions = jnp.expand_dims(positions, 1) - floor = jnp.floor(positions) - connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], - [0., 0, 1], [1., 1, 0], [1., 0, 1], - [0., 1, 1], [1., 1, 1]]]) +def cic_read(mesh, positions, halo_size=0): + return jaxpm.ops.cic_read(mesh, positions, halo_size=halo_size) - neighboor_coords = floor + connection - kernel = 1. - jnp.abs(positions - neighboor_coords) - kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] - neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), jnp.array(mesh.shape)) +def cic_paint_dx(displacements, halo_size=0): + return jaxpm.ops.cic_paint_dx(displacements, halo_size=halo_size) - return (mesh[neighboor_coords[...,0], - neighboor_coords[...,1], - neighboor_coords[...,3]]*kernel).sum(axis=-1) +# TO REDO def cic_paint_2d(mesh, positions, weight): - """ Paints positions onto a 2d mesh - mesh: [nx, ny] - positions: [npart, 2] - weight: [npart] - """ - positions = jnp.expand_dims(positions, 1) - floor = jnp.floor(positions) - connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) + """ Paints positions onto a 2d mesh + mesh: [nx, ny] + positions: [npart, 2] + weight: [npart] + """ + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) + connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) - neighboor_coords = floor + connection - kernel = 1. - jnp.abs(positions - neighboor_coords) - kernel = kernel[..., 0] * kernel[..., 1] - if weight is not None: - kernel = kernel * weight[...,jnp.newaxis] - - neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape)) + neighboor_coords = floor + connection + kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] + if weight is not None: + kernel = kernel * weight[..., jnp.newaxis] + + neighboor_coords = jnp.mod( + neighboor_coords.reshape([-1, 4, 2]).astype('int32'), + jnp.array(mesh.shape)) + + dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(), + inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, + 1)) + mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 4]), + dnums) + return mesh - dnums = jax.lax.ScatterDimensionNumbers( - update_window_dims=(), - inserted_window_dims=(0, 1), - scatter_dims_to_operand_dims=(0, 1)) - mesh = lax.scatter_add(mesh, - neighboor_coords, - kernel.reshape([-1,4]), - dnums) - return mesh def compensate_cic(field): - """ - Compensate for CiC painting - Args: - field: input 3D cic-painted field - Returns: - compensated_field - """ - nc = field.shape - kvec = fftk(nc) + """ + Compensate for CiC painting + Args: + field: input 3D cic-painted field + Returns: + compensated_field + """ + nc = field.shape + kvec = jaxpm.ops.fftk(nc) - delta_k = jnp.fft.rfftn(field) - delta_k = cic_compensation(kvec) * delta_k - return jnp.fft.irfftn(delta_k) + delta_k = jaxpm.ops.fftn(field) + delta_k = cic_compensation(kvec) * delta_k + return jaxpm.ops.ifftn(delta_k).real