mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 11:50:53 +00:00
adds mpi tools
This commit is contained in:
parent
2db7c43ab8
commit
3c1abbafcd
2 changed files with 222 additions and 76 deletions
111
jaxpm/ops.py
Normal file
111
jaxpm/ops.py
Normal file
|
@ -0,0 +1,111 @@
|
||||||
|
# Module for custom ops, typically mpi4jax
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import mpi4jax
|
||||||
|
|
||||||
|
|
||||||
|
def fft3d(arr, token=None, comms=None):
|
||||||
|
""" Computes forward FFT, note that the output is transposed
|
||||||
|
"""
|
||||||
|
if comms is not None:
|
||||||
|
shape = list(arr.shape)
|
||||||
|
nx = comms[0].Get_size()
|
||||||
|
ny = comms[1].Get_size()
|
||||||
|
|
||||||
|
# First FFT along z
|
||||||
|
arr = jnp.fft.fft(arr) # [x, y, z]
|
||||||
|
# Perform single gpu or distributed transpose
|
||||||
|
if comms == None:
|
||||||
|
arr = arr.transpose([1, 2, 0])
|
||||||
|
else:
|
||||||
|
arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx])
|
||||||
|
arr = arr.transpose([2, 1, 3, 0]) # [y, z, x]
|
||||||
|
arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token)
|
||||||
|
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [y, z, x]
|
||||||
|
|
||||||
|
# Second FFT along x
|
||||||
|
arr = jnp.fft.fft(arr)
|
||||||
|
# Perform single gpu or distributed transpose
|
||||||
|
if comms == None:
|
||||||
|
arr = arr.transpose([1, 2, 0])
|
||||||
|
else:
|
||||||
|
arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
|
||||||
|
arr = arr.transpose([2, 1, 3, 0]) # [z, x, y]
|
||||||
|
arr, token = mpi4jax.alltoall(arr, comm=comms[1], token=token)
|
||||||
|
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z, x, y]
|
||||||
|
|
||||||
|
# Third FFT along y
|
||||||
|
arr = jnp.fft.fft(arr)
|
||||||
|
|
||||||
|
if comms == None:
|
||||||
|
return arr
|
||||||
|
else:
|
||||||
|
return arr, token
|
||||||
|
|
||||||
|
|
||||||
|
def ifft3d(arr, token=None, comms=None):
|
||||||
|
""" Let's assume that the data is distributed accross x
|
||||||
|
"""
|
||||||
|
if comms is not None:
|
||||||
|
shape = list(arr.shape)
|
||||||
|
nx = comms[0].Get_size()
|
||||||
|
ny = comms[1].Get_size()
|
||||||
|
|
||||||
|
# First FFT along y
|
||||||
|
arr = jnp.fft.ifft(arr) # Now [z, x, y]
|
||||||
|
# Perform single gpu or distributed transpose
|
||||||
|
if comms == None:
|
||||||
|
arr = arr.transpose([0, 2, 1])
|
||||||
|
else:
|
||||||
|
arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
|
||||||
|
arr = arr.transpose([2, 0, 3, 1]) # Now [z, y, x]
|
||||||
|
arr, token = mpi4jax.alltoall(arr, comm=comms[1], token=token)
|
||||||
|
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z,y,x]
|
||||||
|
|
||||||
|
# Second FFT along x
|
||||||
|
arr = jnp.fft.ifft(arr)
|
||||||
|
# Perform single gpu or distributed transpose
|
||||||
|
if comms == None:
|
||||||
|
arr = arr.transpose([2, 1, 0])
|
||||||
|
else:
|
||||||
|
arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx])
|
||||||
|
arr = arr.transpose([2, 3, 1, 0]) # now [x, y, z]
|
||||||
|
arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token)
|
||||||
|
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [x,y,z]
|
||||||
|
|
||||||
|
# Third FFT along y
|
||||||
|
arr = jnp.fft.fft(arr)
|
||||||
|
|
||||||
|
if comms == None:
|
||||||
|
return arr
|
||||||
|
else:
|
||||||
|
return arr, token
|
||||||
|
|
||||||
|
|
||||||
|
def halo_reduce(arr, halo_size, token=None, comms=None):
|
||||||
|
|
||||||
|
# Perform halo exchange along x
|
||||||
|
rank_x = comms[0].Get_rank()
|
||||||
|
margin = arr[-2*halo_size:]
|
||||||
|
margin, token = mpi4jax.sendrecv(margin, margin, rank_x-1, rank_x+1,
|
||||||
|
comm=comms[0], token=token)
|
||||||
|
arr = arr.at[:2*halo_size].add(margin)
|
||||||
|
|
||||||
|
margin = arr[:2*halo_size]
|
||||||
|
margin, token = mpi4jax.sendrecv(margin, margin, rank_x+1, rank_x-1,
|
||||||
|
comm=comms[0], token=token)
|
||||||
|
arr = arr.at[-2*halo_size:].add(margin)
|
||||||
|
|
||||||
|
# Perform halo exchange along y
|
||||||
|
rank_y = comms[1].Get_rank()
|
||||||
|
margin = arr[:, -2*halo_size:]
|
||||||
|
margin, token = mpi4jax.sendrecv(margin, margin, rank_y-1, rank_y+1,
|
||||||
|
comm=comms[0], token=token)
|
||||||
|
arr = arr.at[:, :2*halo_size].add(margin)
|
||||||
|
|
||||||
|
margin = arr[:, :2*halo_size]
|
||||||
|
margin, token = mpi4jax.sendrecv(margin, margin, rank_y+1, rank_y-1,
|
||||||
|
comm=comms[0], token=token)
|
||||||
|
arr = arr.at[:, -2*halo_size:].add(margin)
|
||||||
|
|
||||||
|
return arr, token
|
|
@ -2,95 +2,130 @@ import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax.lax as lax
|
import jax.lax as lax
|
||||||
|
|
||||||
|
from jaxpm.ops import halo_reduce
|
||||||
from jaxpm.kernels import fftk, cic_compensation
|
from jaxpm.kernels import fftk, cic_compensation
|
||||||
|
|
||||||
def cic_paint(mesh, positions):
|
|
||||||
""" Paints positions onto mesh
|
|
||||||
mesh: [nx, ny, nz]
|
|
||||||
positions: [npart, 3]
|
|
||||||
"""
|
|
||||||
positions = jnp.expand_dims(positions, 1)
|
|
||||||
floor = jnp.floor(positions)
|
|
||||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
|
||||||
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
|
||||||
[0., 1, 1], [1., 1, 1]]])
|
|
||||||
|
|
||||||
neighboor_coords = floor + connection
|
def cic_paint(mesh, positions, halo_size=0, token=None, comms=None):
|
||||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
""" Paints positions onto mesh
|
||||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
mesh: [nx, ny, nz]
|
||||||
|
positions: [npart, 3]
|
||||||
|
"""
|
||||||
|
if comms is not None:
|
||||||
|
# Add some padding for the halo exchange
|
||||||
|
mesh = jnp.pad(mesh, [[halo_size, halo_size],
|
||||||
|
[halo_size, halo_size],
|
||||||
|
[0, 0]])
|
||||||
|
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
|
||||||
|
|
||||||
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape))
|
positions = jnp.expand_dims(positions, 1)
|
||||||
|
floor = jnp.floor(positions)
|
||||||
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
||||||
|
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
||||||
|
[0., 1, 1], [1., 1, 1]]])
|
||||||
|
|
||||||
dnums = jax.lax.ScatterDimensionNumbers(
|
neighboor_coords = floor + connection
|
||||||
update_window_dims=(),
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
inserted_window_dims=(0, 1, 2),
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||||
scatter_dims_to_operand_dims=(0, 1, 2))
|
|
||||||
mesh = lax.scatter_add(mesh,
|
|
||||||
neighboor_coords,
|
|
||||||
kernel.reshape([-1,8]),
|
|
||||||
dnums)
|
|
||||||
return mesh
|
|
||||||
|
|
||||||
def cic_read(mesh, positions):
|
neighboor_coords = jnp.mod(neighboor_coords.reshape(
|
||||||
""" Paints positions onto mesh
|
[-1, 8, 3]).astype('int32'), jnp.array(mesh.shape))
|
||||||
mesh: [nx, ny, nz]
|
|
||||||
positions: [npart, 3]
|
|
||||||
"""
|
|
||||||
positions = jnp.expand_dims(positions, 1)
|
|
||||||
floor = jnp.floor(positions)
|
|
||||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
|
||||||
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
|
||||||
[0., 1, 1], [1., 1, 1]]])
|
|
||||||
|
|
||||||
neighboor_coords = floor + connection
|
dnums = jax.lax.ScatterDimensionNumbers(
|
||||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
update_window_dims=(),
|
||||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
inserted_window_dims=(0, 1, 2),
|
||||||
|
scatter_dims_to_operand_dims=(0, 1, 2))
|
||||||
|
mesh = lax.scatter_add(mesh,
|
||||||
|
neighboor_coords,
|
||||||
|
kernel.reshape([-1, 8]),
|
||||||
|
dnums)
|
||||||
|
|
||||||
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), jnp.array(mesh.shape))
|
if comms == None:
|
||||||
|
return mesh
|
||||||
|
else:
|
||||||
|
mesh, token = halo_reduce(mesh, halo_size, token, comms)
|
||||||
|
return mesh[halo_size:-halo_size, halo_size:-halo_size]
|
||||||
|
|
||||||
|
|
||||||
|
def cic_read(mesh, positions, halo_size=0, token=None, comms=None):
|
||||||
|
""" Paints positions onto mesh
|
||||||
|
mesh: [nx, ny, nz]
|
||||||
|
positions: [npart, 3]
|
||||||
|
"""
|
||||||
|
|
||||||
|
if comms is not None:
|
||||||
|
# Add some padding and perfom hao exchange to retrieve
|
||||||
|
# neighboring regions
|
||||||
|
mesh = jnp.pad(mesh, [[halo_size, halo_size],
|
||||||
|
[halo_size, halo_size],
|
||||||
|
[0, 0]])
|
||||||
|
mesh, token = halo_reduce(mesh, halo_size, token, comms)
|
||||||
|
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
|
||||||
|
|
||||||
|
positions = jnp.expand_dims(positions, 1)
|
||||||
|
floor = jnp.floor(positions)
|
||||||
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
||||||
|
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
||||||
|
[0., 1, 1], [1., 1, 1]]])
|
||||||
|
|
||||||
|
neighboor_coords = floor + connection
|
||||||
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||||
|
|
||||||
|
neighboor_coords = jnp.mod(
|
||||||
|
neighboor_coords.astype('int32'), jnp.array(mesh.shape))
|
||||||
|
|
||||||
|
res = (mesh[neighboor_coords[..., 0],
|
||||||
|
neighboor_coords[..., 1],
|
||||||
|
neighboor_coords[..., 3]]*kernel).sum(axis=-1)
|
||||||
|
|
||||||
|
if comms is not None:
|
||||||
|
return res
|
||||||
|
else:
|
||||||
|
return res, token
|
||||||
|
|
||||||
return (mesh[neighboor_coords[...,0],
|
|
||||||
neighboor_coords[...,1],
|
|
||||||
neighboor_coords[...,3]]*kernel).sum(axis=-1)
|
|
||||||
|
|
||||||
def cic_paint_2d(mesh, positions, weight):
|
def cic_paint_2d(mesh, positions, weight):
|
||||||
""" Paints positions onto a 2d mesh
|
""" Paints positions onto a 2d mesh
|
||||||
mesh: [nx, ny]
|
mesh: [nx, ny]
|
||||||
positions: [npart, 2]
|
positions: [npart, 2]
|
||||||
weight: [npart]
|
weight: [npart]
|
||||||
"""
|
"""
|
||||||
positions = jnp.expand_dims(positions, 1)
|
positions = jnp.expand_dims(positions, 1)
|
||||||
floor = jnp.floor(positions)
|
floor = jnp.floor(positions)
|
||||||
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
|
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
|
||||||
|
|
||||||
neighboor_coords = floor + connection
|
neighboor_coords = floor + connection
|
||||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
kernel = kernel[..., 0] * kernel[..., 1]
|
kernel = kernel[..., 0] * kernel[..., 1]
|
||||||
if weight is not None:
|
if weight is not None:
|
||||||
kernel = kernel * weight[...,jnp.newaxis]
|
kernel = kernel * weight[..., jnp.newaxis]
|
||||||
|
|
||||||
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape))
|
neighboor_coords = jnp.mod(neighboor_coords.reshape(
|
||||||
|
[-1, 4, 2]).astype('int32'), jnp.array(mesh.shape))
|
||||||
|
|
||||||
|
dnums = jax.lax.ScatterDimensionNumbers(
|
||||||
|
update_window_dims=(),
|
||||||
|
inserted_window_dims=(0, 1),
|
||||||
|
scatter_dims_to_operand_dims=(0, 1))
|
||||||
|
mesh = lax.scatter_add(mesh,
|
||||||
|
neighboor_coords,
|
||||||
|
kernel.reshape([-1, 4]),
|
||||||
|
dnums)
|
||||||
|
return mesh
|
||||||
|
|
||||||
dnums = jax.lax.ScatterDimensionNumbers(
|
|
||||||
update_window_dims=(),
|
|
||||||
inserted_window_dims=(0, 1),
|
|
||||||
scatter_dims_to_operand_dims=(0, 1))
|
|
||||||
mesh = lax.scatter_add(mesh,
|
|
||||||
neighboor_coords,
|
|
||||||
kernel.reshape([-1,4]),
|
|
||||||
dnums)
|
|
||||||
return mesh
|
|
||||||
|
|
||||||
def compensate_cic(field):
|
def compensate_cic(field):
|
||||||
"""
|
"""
|
||||||
Compensate for CiC painting
|
Compensate for CiC painting
|
||||||
Args:
|
Args:
|
||||||
field: input 3D cic-painted field
|
field: input 3D cic-painted field
|
||||||
Returns:
|
Returns:
|
||||||
compensated_field
|
compensated_field
|
||||||
"""
|
"""
|
||||||
nc = field.shape
|
nc = field.shape
|
||||||
kvec = fftk(nc)
|
kvec = fftk(nc)
|
||||||
|
|
||||||
delta_k = jnp.fft.rfftn(field)
|
delta_k = jnp.fft.rfftn(field)
|
||||||
delta_k = cic_compensation(kvec) * delta_k
|
delta_k = cic_compensation(kvec) * delta_k
|
||||||
return jnp.fft.irfftn(delta_k)
|
return jnp.fft.irfftn(delta_k)
|
||||||
|
|
Loading…
Add table
Reference in a new issue