diff --git a/jaxpm/ops.py b/jaxpm/ops.py new file mode 100644 index 0000000..b4c1801 --- /dev/null +++ b/jaxpm/ops.py @@ -0,0 +1,111 @@ +# Module for custom ops, typically mpi4jax +import jax +import jax.numpy as jnp +import mpi4jax + + +def fft3d(arr, token=None, comms=None): + """ Computes forward FFT, note that the output is transposed + """ + if comms is not None: + shape = list(arr.shape) + nx = comms[0].Get_size() + ny = comms[1].Get_size() + + # First FFT along z + arr = jnp.fft.fft(arr) # [x, y, z] + # Perform single gpu or distributed transpose + if comms == None: + arr = arr.transpose([1, 2, 0]) + else: + arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx]) + arr = arr.transpose([2, 1, 3, 0]) # [y, z, x] + arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token) + arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [y, z, x] + + # Second FFT along x + arr = jnp.fft.fft(arr) + # Perform single gpu or distributed transpose + if comms == None: + arr = arr.transpose([1, 2, 0]) + else: + arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny]) + arr = arr.transpose([2, 1, 3, 0]) # [z, x, y] + arr, token = mpi4jax.alltoall(arr, comm=comms[1], token=token) + arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z, x, y] + + # Third FFT along y + arr = jnp.fft.fft(arr) + + if comms == None: + return arr + else: + return arr, token + + +def ifft3d(arr, token=None, comms=None): + """ Let's assume that the data is distributed accross x + """ + if comms is not None: + shape = list(arr.shape) + nx = comms[0].Get_size() + ny = comms[1].Get_size() + + # First FFT along y + arr = jnp.fft.ifft(arr) # Now [z, x, y] + # Perform single gpu or distributed transpose + if comms == None: + arr = arr.transpose([0, 2, 1]) + else: + arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny]) + arr = arr.transpose([2, 0, 3, 1]) # Now [z, y, x] + arr, token = mpi4jax.alltoall(arr, comm=comms[1], token=token) + arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z,y,x] + + # Second FFT along x + arr = jnp.fft.ifft(arr) + # Perform single gpu or distributed transpose + if comms == None: + arr = arr.transpose([2, 1, 0]) + else: + arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx]) + arr = arr.transpose([2, 3, 1, 0]) # now [x, y, z] + arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token) + arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [x,y,z] + + # Third FFT along y + arr = jnp.fft.fft(arr) + + if comms == None: + return arr + else: + return arr, token + + +def halo_reduce(arr, halo_size, token=None, comms=None): + + # Perform halo exchange along x + rank_x = comms[0].Get_rank() + margin = arr[-2*halo_size:] + margin, token = mpi4jax.sendrecv(margin, margin, rank_x-1, rank_x+1, + comm=comms[0], token=token) + arr = arr.at[:2*halo_size].add(margin) + + margin = arr[:2*halo_size] + margin, token = mpi4jax.sendrecv(margin, margin, rank_x+1, rank_x-1, + comm=comms[0], token=token) + arr = arr.at[-2*halo_size:].add(margin) + + # Perform halo exchange along y + rank_y = comms[1].Get_rank() + margin = arr[:, -2*halo_size:] + margin, token = mpi4jax.sendrecv(margin, margin, rank_y-1, rank_y+1, + comm=comms[0], token=token) + arr = arr.at[:, :2*halo_size].add(margin) + + margin = arr[:, :2*halo_size] + margin, token = mpi4jax.sendrecv(margin, margin, rank_y+1, rank_y-1, + comm=comms[0], token=token) + arr = arr.at[:, -2*halo_size:].add(margin) + + return arr, token diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 4237c23..c9b72ca 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -2,95 +2,130 @@ import jax import jax.numpy as jnp import jax.lax as lax +from jaxpm.ops import halo_reduce from jaxpm.kernels import fftk, cic_compensation -def cic_paint(mesh, positions): - """ Paints positions onto mesh - mesh: [nx, ny, nz] - positions: [npart, 3] - """ - positions = jnp.expand_dims(positions, 1) - floor = jnp.floor(positions) - connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], - [0., 0, 1], [1., 1, 0], [1., 0, 1], - [0., 1, 1], [1., 1, 1]]]) - neighboor_coords = floor + connection - kernel = 1. - jnp.abs(positions - neighboor_coords) - kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] +def cic_paint(mesh, positions, halo_size=0, token=None, comms=None): + """ Paints positions onto mesh + mesh: [nx, ny, nz] + positions: [npart, 3] + """ + if comms is not None: + # Add some padding for the halo exchange + mesh = jnp.pad(mesh, [[halo_size, halo_size], + [halo_size, halo_size], + [0, 0]]) + positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]) - neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape)) + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) + connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], + [0., 0, 1], [1., 1, 0], [1., 0, 1], + [0., 1, 1], [1., 1, 1]]]) - dnums = jax.lax.ScatterDimensionNumbers( - update_window_dims=(), - inserted_window_dims=(0, 1, 2), - scatter_dims_to_operand_dims=(0, 1, 2)) - mesh = lax.scatter_add(mesh, - neighboor_coords, - kernel.reshape([-1,8]), - dnums) - return mesh + neighboor_coords = floor + connection + kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] -def cic_read(mesh, positions): - """ Paints positions onto mesh - mesh: [nx, ny, nz] - positions: [npart, 3] - """ - positions = jnp.expand_dims(positions, 1) - floor = jnp.floor(positions) - connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], - [0., 0, 1], [1., 1, 0], [1., 0, 1], - [0., 1, 1], [1., 1, 1]]]) + neighboor_coords = jnp.mod(neighboor_coords.reshape( + [-1, 8, 3]).astype('int32'), jnp.array(mesh.shape)) - neighboor_coords = floor + connection - kernel = 1. - jnp.abs(positions - neighboor_coords) - kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] + dnums = jax.lax.ScatterDimensionNumbers( + update_window_dims=(), + inserted_window_dims=(0, 1, 2), + scatter_dims_to_operand_dims=(0, 1, 2)) + mesh = lax.scatter_add(mesh, + neighboor_coords, + kernel.reshape([-1, 8]), + dnums) - neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), jnp.array(mesh.shape)) + if comms == None: + return mesh + else: + mesh, token = halo_reduce(mesh, halo_size, token, comms) + return mesh[halo_size:-halo_size, halo_size:-halo_size] + + +def cic_read(mesh, positions, halo_size=0, token=None, comms=None): + """ Paints positions onto mesh + mesh: [nx, ny, nz] + positions: [npart, 3] + """ + + if comms is not None: + # Add some padding and perfom hao exchange to retrieve + # neighboring regions + mesh = jnp.pad(mesh, [[halo_size, halo_size], + [halo_size, halo_size], + [0, 0]]) + mesh, token = halo_reduce(mesh, halo_size, token, comms) + positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]) + + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) + connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], + [0., 0, 1], [1., 1, 0], [1., 0, 1], + [0., 1, 1], [1., 1, 1]]]) + + neighboor_coords = floor + connection + kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] + + neighboor_coords = jnp.mod( + neighboor_coords.astype('int32'), jnp.array(mesh.shape)) + + res = (mesh[neighboor_coords[..., 0], + neighboor_coords[..., 1], + neighboor_coords[..., 3]]*kernel).sum(axis=-1) + + if comms is not None: + return res + else: + return res, token - return (mesh[neighboor_coords[...,0], - neighboor_coords[...,1], - neighboor_coords[...,3]]*kernel).sum(axis=-1) def cic_paint_2d(mesh, positions, weight): - """ Paints positions onto a 2d mesh - mesh: [nx, ny] - positions: [npart, 2] - weight: [npart] - """ - positions = jnp.expand_dims(positions, 1) - floor = jnp.floor(positions) - connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) + """ Paints positions onto a 2d mesh + mesh: [nx, ny] + positions: [npart, 2] + weight: [npart] + """ + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) + connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) - neighboor_coords = floor + connection - kernel = 1. - jnp.abs(positions - neighboor_coords) - kernel = kernel[..., 0] * kernel[..., 1] - if weight is not None: - kernel = kernel * weight[...,jnp.newaxis] - - neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape)) + neighboor_coords = floor + connection + kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] + if weight is not None: + kernel = kernel * weight[..., jnp.newaxis] + + neighboor_coords = jnp.mod(neighboor_coords.reshape( + [-1, 4, 2]).astype('int32'), jnp.array(mesh.shape)) + + dnums = jax.lax.ScatterDimensionNumbers( + update_window_dims=(), + inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1)) + mesh = lax.scatter_add(mesh, + neighboor_coords, + kernel.reshape([-1, 4]), + dnums) + return mesh - dnums = jax.lax.ScatterDimensionNumbers( - update_window_dims=(), - inserted_window_dims=(0, 1), - scatter_dims_to_operand_dims=(0, 1)) - mesh = lax.scatter_add(mesh, - neighboor_coords, - kernel.reshape([-1,4]), - dnums) - return mesh def compensate_cic(field): - """ - Compensate for CiC painting - Args: - field: input 3D cic-painted field - Returns: - compensated_field - """ - nc = field.shape - kvec = fftk(nc) + """ + Compensate for CiC painting + Args: + field: input 3D cic-painted field + Returns: + compensated_field + """ + nc = field.shape + kvec = fftk(nc) - delta_k = jnp.fft.rfftn(field) - delta_k = cic_compensation(kvec) * delta_k - return jnp.fft.irfftn(delta_k) \ No newline at end of file + delta_k = jnp.fft.rfftn(field) + delta_k = cic_compensation(kvec) * delta_k + return jnp.fft.irfftn(delta_k)