mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
Adding an example of jaxdecomp implementation
This commit is contained in:
parent
6644b35d71
commit
6ca4c9191e
5 changed files with 166 additions and 192 deletions
184
jaxpm/ops.py
184
jaxpm/ops.py
|
@ -2,155 +2,91 @@
|
|||
import jax
|
||||
import jax.numpy as jnp
|
||||
import mpi4jax
|
||||
import jaxdecomp
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
@dataclass
|
||||
class ShardingInfo:
|
||||
"""Class for keeping track of the distribution strategy"""
|
||||
global_shape: Tuple[int, int, int]
|
||||
pdims: Tuple[int, int]
|
||||
halo_extents: Tuple[int, int, int]
|
||||
rank: int = 0
|
||||
|
||||
|
||||
def fft3d(arr, comms=None):
|
||||
def fft3d(arr, sharding_info=None):
|
||||
""" Computes forward FFT, note that the output is transposed
|
||||
"""
|
||||
if comms is not None:
|
||||
shape = list(arr.shape)
|
||||
nx = comms[0].Get_size()
|
||||
ny = comms[1].Get_size()
|
||||
|
||||
# First FFT along z
|
||||
arr = jnp.fft.fft(arr) # [x, y, z]
|
||||
# Perform single gpu or distributed transpose
|
||||
if comms == None:
|
||||
arr = arr.transpose([1, 2, 0])
|
||||
if sharding_info is None:
|
||||
arr = jnp.fft.fftn(arr).transpose([1, 2, 0])
|
||||
else:
|
||||
arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx])
|
||||
#arr = arr.transpose([2, 1, 3, 0]) # [y, z, x]
|
||||
arr = jnp.einsum('ij,xyjz->iyzx', jnp.eye(nx), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work
|
||||
arr, token = mpi4jax.alltoall(arr, comm=comms[0])
|
||||
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [y, z, x]
|
||||
|
||||
# Second FFT along x
|
||||
arr = jnp.fft.fft(arr)
|
||||
# Perform single gpu or distributed transpose
|
||||
if comms == None:
|
||||
arr = arr.transpose([1, 2, 0])
|
||||
else:
|
||||
arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
|
||||
#arr = arr.transpose([2, 1, 3, 0]) # [z, x, y]
|
||||
arr = jnp.einsum('ij,yzjx->izxy', jnp.eye(ny), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work
|
||||
arr, token = mpi4jax.alltoall(arr, comm=comms[1], token=token)
|
||||
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z, x, y]
|
||||
|
||||
# Third FFT along y
|
||||
return jnp.fft.fft(arr)
|
||||
|
||||
|
||||
def ifft3d(arr, comms=None):
|
||||
""" Let's assume that the data is distributed accross x
|
||||
"""
|
||||
if comms is not None:
|
||||
shape = list(arr.shape)
|
||||
nx = comms[0].Get_size()
|
||||
ny = comms[1].Get_size()
|
||||
|
||||
# First FFT along y
|
||||
arr = jnp.fft.ifft(arr) # Now [z, x, y]
|
||||
# Perform single gpu or distributed transpose
|
||||
if comms == None:
|
||||
arr = arr.transpose([0, 2, 1])
|
||||
else:
|
||||
arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
|
||||
# arr = arr.transpose([2, 0, 3, 1]) # Now [z, y, x]
|
||||
arr = jnp.einsum('ij,zxjy->izyx', jnp.eye(ny), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work
|
||||
arr, token = mpi4jax.alltoall(arr, comm=comms[1])
|
||||
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z,y,x]
|
||||
|
||||
# Second FFT along x
|
||||
arr = jnp.fft.ifft(arr)
|
||||
# Perform single gpu or distributed transpose
|
||||
if comms == None:
|
||||
arr = arr.transpose([2, 1, 0])
|
||||
else:
|
||||
arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx])
|
||||
# arr = arr.transpose([2, 3, 1, 0]) # now [x, y, z]
|
||||
arr = jnp.einsum('ij,zyjx->ixyz', jnp.eye(nx), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work
|
||||
arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token)
|
||||
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [x,y,z]
|
||||
|
||||
# Third FFT along z
|
||||
return jnp.fft.ifft(arr)
|
||||
|
||||
|
||||
def halo_reduce(arr, halo_size, comms=None):
|
||||
if halo_size <= 0:
|
||||
return arr
|
||||
|
||||
# Perform halo exchange along x
|
||||
rank_x = comms[0].Get_rank()
|
||||
size_x = comms[0].Get_size()
|
||||
margin = arr[-2*halo_size:]
|
||||
left, token = mpi4jax.sendrecv(margin, margin,
|
||||
(rank_x-1) % size_x,
|
||||
(rank_x+1) % size_x,
|
||||
comm=comms[0])
|
||||
margin = arr[:2*halo_size]
|
||||
right, token = mpi4jax.sendrecv(margin, margin,
|
||||
(rank_x+1) % size_x,
|
||||
(rank_x-1) % size_x,
|
||||
comm=comms[0], token=token)
|
||||
|
||||
arr = arr.at[:2*halo_size].add(left)
|
||||
arr = arr.at[-2*halo_size:].add(right)
|
||||
|
||||
# Perform halo exchange along y
|
||||
rank_y = comms[1].Get_rank()
|
||||
size_y = comms[1].Get_size()
|
||||
margin = arr[:, -2*halo_size:]
|
||||
left, token = mpi4jax.sendrecv(margin, margin,
|
||||
(rank_y-1) % size_y,
|
||||
(rank_y+1) % size_y,
|
||||
comm=comms[1], token=token)
|
||||
margin = arr[:, :2*halo_size]
|
||||
right, token = mpi4jax.sendrecv(margin, margin,
|
||||
(rank_y+1) % size_y,
|
||||
(rank_y-1) % size_y,
|
||||
comm=comms[1], token=token)
|
||||
arr = arr.at[:, :2*halo_size].add(left)
|
||||
arr = arr.at[:, -2*halo_size:].add(right)
|
||||
|
||||
arr = jaxdecomp.pfft3d(arr,
|
||||
pdims=sharding_info.pdims,
|
||||
global_shape=sharding_info.global_shape)
|
||||
return arr
|
||||
|
||||
|
||||
def meshgrid3d(shape, comms=None):
|
||||
if comms is not None:
|
||||
nx = comms[0].Get_size()
|
||||
ny = comms[1].Get_size()
|
||||
def ifft3d(arr, sharding_info=None):
|
||||
if sharding_info is None:
|
||||
arr = jnp.fft.ifftn(arr.transpose([2, 0, 1]))
|
||||
else:
|
||||
arr = jaxdecomp.pifft3d(arr,
|
||||
pdims=sharding_info.pdims,
|
||||
global_shape=sharding_info.global_shape)
|
||||
return arr
|
||||
|
||||
coords = [jnp.arange(shape[0]//nx),
|
||||
jnp.arange(shape[1]//ny)] + [jnp.arange(s) for s in shape[2:]]
|
||||
|
||||
def halo_reduce(arr, sharding_info=None):
|
||||
if sharding_info is None:
|
||||
return arr
|
||||
halo_size = sharding_info.halo_extents[0]
|
||||
global_shape = sharding_info.global_shape
|
||||
arr = jaxdecomp.halo_exchange(arr,
|
||||
halo_extents=(halo_size//2, halo_size//2, 0),
|
||||
halo_periods=(True,True,True),
|
||||
pdims=sharding_info.pdims,
|
||||
global_shape=(global_shape[0]+2*halo_size,
|
||||
global_shape[1]+halo_size,
|
||||
global_shape[2]))
|
||||
|
||||
# Apply correction along x
|
||||
arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2])
|
||||
arr = arr.at[-halo_size - halo_size//2:-halo_size].add(arr[-halo_size//2:])
|
||||
|
||||
# Apply correction along y
|
||||
arr = arr.at[:, halo_size:halo_size + halo_size//2].add(arr[:, :halo_size//2][:, :])
|
||||
arr = arr.at[:, -halo_size - halo_size//2:-halo_size].add(arr[:, -halo_size//2:][:, :])
|
||||
|
||||
return arr
|
||||
|
||||
|
||||
def meshgrid3d(shape, sharding_info=None):
|
||||
if sharding_info is not None:
|
||||
coords = [jnp.arange(sharding_info.global_shape[0]//sharding_info.pdims[1]),
|
||||
jnp.arange(sharding_info.global_shape[1]//sharding_info.pdims[0]), jnp.arange(sharding_info.global_shape[2])]
|
||||
else:
|
||||
coords = [jnp.arange(s) for s in shape[2:]]
|
||||
|
||||
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
|
||||
|
||||
|
||||
def zeros(shape, comms=None):
|
||||
def zeros(shape, sharding_info=None):
|
||||
""" Initialize an array of given global shape
|
||||
partitionned if need be accross dimensions.
|
||||
"""
|
||||
if comms is None:
|
||||
if sharding_info is None:
|
||||
return jnp.zeros(shape)
|
||||
|
||||
nx = comms[0].Get_size()
|
||||
ny = comms[1].Get_size()
|
||||
|
||||
return jnp.zeros([shape[0]//nx, shape[1]//ny]+list(shape[2:]))
|
||||
return jnp.zeros([sharding_info.global_shape[0]//sharding_info.pdims[1], sharding_info.global_shape[1]//sharding_info.pdims[0]]+list(sharding_info.global_shape[2:]))
|
||||
|
||||
|
||||
def normal(key, shape, comms=None):
|
||||
def normal(key, shape, sharding_info=None):
|
||||
""" Generates a normal variable for the given
|
||||
global shape.
|
||||
"""
|
||||
if comms is None:
|
||||
if sharding_info is None:
|
||||
return jax.random.normal(key, shape)
|
||||
|
||||
nx = comms[0].Get_size()
|
||||
ny = comms[1].Get_size()
|
||||
|
||||
return jax.random.normal(key,
|
||||
[shape[0]//nx, shape[1]//ny]+list(shape[2:]))
|
||||
[sharding_info.global_shape[0]//sharding_info.pdims[1], sharding_info.global_shape[1]//sharding_info.pdims[0], sharding_info.global_shape[2]])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue