Adding an example of jaxdecomp implementation

EiffL 2022-11-26 17:27:14 +01:00
parent 6644b35d71
commit 6ca4c9191e
5 changed files with 166 additions and 192 deletions

View file

@@ -3,19 +3,23 @@ from jax.experimental.maps import xmap
 import numpy as np
 import jax.numpy as jnp
 from functools import partial
+import jaxdecomp
-def fftk(shape, symmetric=False, dtype=np.float32, comms=None):
+def fftk(shape, symmetric=False, dtype=np.float32, sharding_info=None):
     """ Return k_vector given a shape (nc, nc, nc)
     """
     k = []
-    if comms is not None:
-        nx = comms[0].Get_size()
-        ix = comms[0].Get_rank()
-        ny = comms[1].Get_size()
-        iy = comms[1].Get_rank()
-        shape = [shape[0]*nx, shape[1]*ny] + list(shape[2:])
+    if sharding_info is not None:
+        nx = sharding_info.pdims[1]
+        ny = sharding_info.pdims[0]
+        # nx = sharding_info[0].Get_size()
+        # ix = sharding_info[0].Get_rank()
+        # ny = sharding_info[1].Get_size()
+        # iy = sharding_info[1].Get_rank()
+        ix = sharding_info.rank
+        iy = 0
+        shape = sharding_info.global_shape
     for d in range(len(shape)):
         kd = np.fft.fftfreq(shape[d])
@@ -24,10 +28,10 @@ def fftk(shape, symmetric=False, dtype=np.float32, comms=None):
         if symmetric and d == len(shape) - 1:
             kd = kd[:shape[d] // 2 + 1]
-        if (comms is not None) and d == 0:
+        if (sharding_info is not None) and d == 0:
             kd = kd.reshape([nx, -1])[ix]
-        if (comms is not None) and d == 1:
+        if (sharding_info is not None) and d == 1:
             kd = kd.reshape([ny, -1])[iy]
         k.append(kd.astype(dtype))
@@ -42,10 +46,9 @@ def apply_gradient_laplace(kfield, kvec):
     kx, ky, kz = kvec
     kk = (kx**2 + ky**2 + kz**2)
     kernel = jnp.where(kk == 0, 1., 1./kk)
-    return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
-                      kfield * kernel * 1j * 1 / 6.0 *
-                      (8 * jnp.sin(kz) - jnp.sin(2 * kz)),
-                      kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))], axis=-1)
+    return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)),
+                      kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)),
+                      kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky))], axis=-1)
 def cic_compensation(kvec):
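
A rough, self-contained sketch of what the sharding_info branch of fftk computes: each process keeps only its own slice of the first two frequency axes, indexed through pdims and rank. The SimpleNamespace object below is a hypothetical stand-in used purely for illustration; in the actual code this information comes from the ShardingInfo dataclass added in the ops module of this commit.

import numpy as np
from types import SimpleNamespace

# Hypothetical stand-in for the ShardingInfo dataclass (illustration only).
info = SimpleNamespace(global_shape=(64, 64, 64), pdims=(1, 2), rank=0)

nx, ny = info.pdims[1], info.pdims[0]  # process grid, as read by fftk
ix, iy = info.rank, 0                  # this process's slot along each axis

# Global frequency vector along the first axis, then the local slice kept
# by this rank: this mirrors kd.reshape([nx, -1])[ix] in fftk above.
kx_global = np.fft.fftfreq(info.global_shape[0])
kx_local = kx_global.reshape([nx, -1])[ix]
print(kx_local.shape)  # (global_shape[0] // nx,) = (32,)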

View file

@@ -2,155 +2,91 @@
 import jax
 import jax.numpy as jnp
-import mpi4jax
+import jaxdecomp
+from dataclasses import dataclass
+from typing import Tuple
+@dataclass
+class ShardingInfo:
+    """Class for keeping track of the distribution strategy"""
+    global_shape: Tuple[int, int, int]
+    pdims: Tuple[int, int]
+    halo_extents: Tuple[int, int, int]
+    rank: int = 0
-def fft3d(arr, comms=None):
+def fft3d(arr, sharding_info=None):
     """ Computes forward FFT, note that the output is transposed
     """
-    if comms is not None:
-        shape = list(arr.shape)
-        nx = comms[0].Get_size()
-        ny = comms[1].Get_size()
-    # First FFT along z
-    arr = jnp.fft.fft(arr) # [x, y, z]
-    # Perform single gpu or distributed transpose
-    if comms == None:
-        arr = arr.transpose([1, 2, 0])
+    if sharding_info is None:
+        arr = jnp.fft.fftn(arr).transpose([1, 2, 0])
     else:
-        arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx])
-        #arr = arr.transpose([2, 1, 3, 0]) # [y, z, x]
-        arr = jnp.einsum('ij,xyjz->iyzx', jnp.eye(nx), arr) # TODO: remove this hack when we understand why transpose before alltoall doesn't work
-        arr, token = mpi4jax.alltoall(arr, comm=comms[0])
-        arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [y, z, x]
-    # Second FFT along x
-    arr = jnp.fft.fft(arr)
-    # Perform single gpu or distributed transpose
-    if comms == None:
-        arr = arr.transpose([1, 2, 0])
-    else:
-        arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
-        #arr = arr.transpose([2, 1, 3, 0]) # [z, x, y]
-        arr = jnp.einsum('ij,yzjx->izxy', jnp.eye(ny), arr) # TODO: remove this hack when we understand why transpose before alltoall doesn't work
-        arr, token = mpi4jax.alltoall(arr, comm=comms[1], token=token)
-        arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z, x, y]
-    # Third FFT along y
-    return jnp.fft.fft(arr)
-def ifft3d(arr, comms=None):
-    """ Let's assume that the data is distributed across x
-    """
-    if comms is not None:
-        shape = list(arr.shape)
-        nx = comms[0].Get_size()
-        ny = comms[1].Get_size()
-    # First FFT along y
-    arr = jnp.fft.ifft(arr) # Now [z, x, y]
-    # Perform single gpu or distributed transpose
-    if comms == None:
-        arr = arr.transpose([0, 2, 1])
-    else:
-        arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
-        # arr = arr.transpose([2, 0, 3, 1]) # Now [z, y, x]
-        arr = jnp.einsum('ij,zxjy->izyx', jnp.eye(ny), arr) # TODO: remove this hack when we understand why transpose before alltoall doesn't work
-        arr, token = mpi4jax.alltoall(arr, comm=comms[1])
-        arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z,y,x]
-    # Second FFT along x
-    arr = jnp.fft.ifft(arr)
-    # Perform single gpu or distributed transpose
-    if comms == None:
-        arr = arr.transpose([2, 1, 0])
-    else:
-        arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx])
-        # arr = arr.transpose([2, 3, 1, 0]) # now [x, y, z]
-        arr = jnp.einsum('ij,zyjx->ixyz', jnp.eye(nx), arr) # TODO: remove this hack when we understand why transpose before alltoall doesn't work
-        arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token)
-        arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [x,y,z]
-    # Third FFT along z
-    return jnp.fft.ifft(arr)
-def halo_reduce(arr, halo_size, comms=None):
-    if halo_size <= 0:
-        return arr
-    # Perform halo exchange along x
-    rank_x = comms[0].Get_rank()
-    size_x = comms[0].Get_size()
-    margin = arr[-2*halo_size:]
-    left, token = mpi4jax.sendrecv(margin, margin,
-                                   (rank_x-1) % size_x,
-                                   (rank_x+1) % size_x,
-                                   comm=comms[0])
-    margin = arr[:2*halo_size]
-    right, token = mpi4jax.sendrecv(margin, margin,
-                                    (rank_x+1) % size_x,
-                                    (rank_x-1) % size_x,
-                                    comm=comms[0], token=token)
-    arr = arr.at[:2*halo_size].add(left)
-    arr = arr.at[-2*halo_size:].add(right)
-    # Perform halo exchange along y
-    rank_y = comms[1].Get_rank()
-    size_y = comms[1].Get_size()
-    margin = arr[:, -2*halo_size:]
-    left, token = mpi4jax.sendrecv(margin, margin,
-                                   (rank_y-1) % size_y,
-                                   (rank_y+1) % size_y,
-                                   comm=comms[1], token=token)
-    margin = arr[:, :2*halo_size]
-    right, token = mpi4jax.sendrecv(margin, margin,
-                                    (rank_y+1) % size_y,
-                                    (rank_y-1) % size_y,
-                                    comm=comms[1], token=token)
-    arr = arr.at[:, :2*halo_size].add(left)
-    arr = arr.at[:, -2*halo_size:].add(right)
+        arr = jaxdecomp.pfft3d(arr,
+                               pdims=sharding_info.pdims,
+                               global_shape=sharding_info.global_shape)
+    return arr
-def meshgrid3d(shape, comms=None):
-    if comms is not None:
-        nx = comms[0].Get_size()
-        ny = comms[1].Get_size()
+def ifft3d(arr, sharding_info=None):
+    if sharding_info is None:
+        arr = jnp.fft.ifftn(arr.transpose([2, 0, 1]))
+    else:
+        arr = jaxdecomp.pifft3d(arr,
+                                pdims=sharding_info.pdims,
+                                global_shape=sharding_info.global_shape)
+    return arr
-        coords = [jnp.arange(shape[0]//nx),
-                  jnp.arange(shape[1]//ny)] + [jnp.arange(s) for s in shape[2:]]
+def halo_reduce(arr, sharding_info=None):
+    if sharding_info is None:
+        return arr
+    halo_size = sharding_info.halo_extents[0]
+    global_shape = sharding_info.global_shape
+    arr = jaxdecomp.halo_exchange(arr,
+                                  halo_extents=(halo_size//2, halo_size//2, 0),
+                                  halo_periods=(True, True, True),
+                                  pdims=sharding_info.pdims,
+                                  global_shape=(global_shape[0]+2*halo_size,
+                                                global_shape[1]+halo_size,
+                                                global_shape[2]))
+    # Apply correction along x
+    arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[:halo_size//2])
+    arr = arr.at[-halo_size - halo_size//2:-halo_size].add(arr[-halo_size//2:])
+    # Apply correction along y
+    arr = arr.at[:, halo_size:halo_size + halo_size//2].add(arr[:, :halo_size//2][:, :])
+    arr = arr.at[:, -halo_size - halo_size//2:-halo_size].add(arr[:, -halo_size//2:][:, :])
+    return arr
+def meshgrid3d(shape, sharding_info=None):
+    if sharding_info is not None:
+        coords = [jnp.arange(sharding_info.global_shape[0]//sharding_info.pdims[1]),
+                  jnp.arange(sharding_info.global_shape[1]//sharding_info.pdims[0]), jnp.arange(sharding_info.global_shape[2])]
     else:
         coords = [jnp.arange(s) for s in shape[2:]]
     return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
-def zeros(shape, comms=None):
+def zeros(shape, sharding_info=None):
     """ Initialize an array of given global shape
     partitioned if need be across dimensions.
     """
-    if comms is None:
+    if sharding_info is None:
         return jnp.zeros(shape)
-    nx = comms[0].Get_size()
-    ny = comms[1].Get_size()
-    return jnp.zeros([shape[0]//nx, shape[1]//ny]+list(shape[2:]))
+    return jnp.zeros([sharding_info.global_shape[0]//sharding_info.pdims[1], sharding_info.global_shape[1]//sharding_info.pdims[0]]+list(sharding_info.global_shape[2:]))
-def normal(key, shape, comms=None):
+def normal(key, shape, sharding_info=None):
     """ Generates a normal variable for the given
     global shape.
     """
-    if comms is None:
+    if sharding_info is None:
         return jax.random.normal(key, shape)
-    nx = comms[0].Get_size()
-    ny = comms[1].Get_size()
     return jax.random.normal(key,
-                             [shape[0]//nx, shape[1]//ny]+list(shape[2:]))
+                             [sharding_info.global_shape[0]//sharding_info.pdims[1], sharding_info.global_shape[1]//sharding_info.pdims[0], sharding_info.global_shape[2]])
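
A minimal usage sketch for the new ShardingInfo-based helpers, assuming the module above is importable as jaxpm.ops. The single-process branch (sharding_info=None) runs as written; the distributed FFT call itself is left as a comment because it needs a multi-process MPI launch and jaxDecomp setup that this diff does not show.

import jax
import jax.numpy as jnp

from jaxpm.ops import ShardingInfo, fft3d, ifft3d, normal

# Single-process path: sharding_info=None falls back to jnp.fft.fftn/ifftn.
field = normal(jax.random.PRNGKey(0), (32, 32, 32))
kfield = fft3d(field)            # Fourier-space array, axes transposed
recovered = ifft3d(kfield).real  # round trip back to configuration space
print(jnp.allclose(field, recovered, atol=1e-4))

# Distributed path (sketch): with two processes arranged as pdims=(1, 2),
# each rank would hold a (32, 64, 64) slab of the (64, 64, 64) global mesh,
# and fft3d would dispatch to jaxdecomp.pfft3d internally. The per-rank
# MPI setup needed to build this object is assumed, not shown.
info = ShardingInfo(global_shape=(64, 64, 64), pdims=(1, 2),
                    halo_extents=(16, 16, 0), rank=0)
local_field = normal(jax.random.PRNGKey(0), None, sharding_info=info)
# local_kfield = fft3d(local_field, sharding_info=info)  # needs 2 processes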

View file

@@ -6,12 +6,12 @@ from jaxpm.ops import halo_reduce
 from jaxpm.kernels import fftk, cic_compensation
-def cic_paint(mesh, positions, halo_size=0, comms=None):
+def cic_paint(mesh, positions, halo_size=0, sharding_info=None):
     """ Paints positions onto mesh
     mesh: [nx, ny, nz]
     positions: [npart, 3]
     """
-    if comms is not None:
+    if sharding_info is not None:
         # Add some padding for the halo exchange
         mesh = jnp.pad(mesh, [[halo_size, halo_size],
                               [halo_size, halo_size],
@@ -40,26 +40,32 @@ def cic_paint(mesh, positions, halo_size=0, comms=None):
                            kernel.reshape([-1, 8]),
                            dnums)
-    if comms == None:
+    if sharding_info == None:
         return mesh
     else:
-        mesh = halo_reduce(mesh, halo_size, comms)
+        mesh = halo_reduce(mesh, sharding_info)
         return mesh[halo_size:-halo_size, halo_size:-halo_size]
-def cic_read(mesh, positions, halo_size=0, comms=None):
+def cic_read(mesh, positions, halo_size=0, sharding_info=None):
     """ Paints positions onto mesh
     mesh: [nx, ny, nz]
     positions: [npart, 3]
     """
-    if comms is not None:
+    if sharding_info is not None:
         # Add some padding and perform halo exchange to retrieve
         # neighboring regions
         mesh = jnp.pad(mesh, [[halo_size, halo_size],
                               [halo_size, halo_size],
                               [0, 0]])
-        mesh = halo_reduce(mesh, halo_size, comms)
+        # mesh = halo_reduce(mesh, sharding_info)
+        import jaxdecomp
+        mesh = jaxdecomp.halo_exchange(mesh,
+                                       halo_extents=sharding_info.halo_extents,
+                                       halo_periods=(True, True, True),
+                                       pdims=sharding_info.pdims,
+                                       global_shape=sharding_info.global_shape)
         positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
     positions = jnp.expand_dims(positions, 1)
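
A short usage sketch for the sharding_info-aware painting functions, assuming the module above is importable as jaxpm.painting. Only the single-process path is exercised; the distributed call is indicated as a comment because it requires the multi-process jaxDecomp setup that this diff leaves to the launcher.

import jax
import jax.numpy as jnp

from jaxpm.ops import zeros
from jaxpm.painting import cic_paint, cic_read

# Single-process path: no halo padding, no sharding_info.
positions = jax.random.uniform(jax.random.PRNGKey(1), (1000, 3)) * 32
mesh = cic_paint(zeros((32, 32, 32)), positions)
values = cic_read(mesh, positions)
print(mesh.shape, values.shape)  # painted mesh, one value per particle

# Distributed path (sketch): each process paints its local slab with extra
# halo padding, and halo_reduce / jaxdecomp.halo_exchange folds back the
# mass that particles deposit into the padded regions, e.g.
# mesh = cic_paint(zeros(mesh_shape, sharding_info=info), local_positions,
#                  halo_size=32, sharding_info=info)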

View file

@@ -10,32 +10,32 @@ from jaxpm.painting import cic_paint, cic_read
 from jaxpm.growth import growth_factor, growth_rate, dGfa
-def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, token=None, comms=None):
+def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, sharding_info=None):
     """
     Computes gravitational forces on particles using a PM scheme
     """
     if delta_k is None:
-        delta = cic_paint(zeros(mesh_shape, comms=comms),
+        delta = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
                           positions,
-                          halo_size=halo_size, comms=comms)
-        delta_k = fft3d(delta, comms=comms)
+                          halo_size=halo_size, sharding_info=sharding_info)
+        delta_k = fft3d(delta, sharding_info=sharding_info)
     # Computes gravitational forces
-    kvec = fftk(delta_k.shape, symmetric=False, comms=comms)
+    kvec = fftk(delta_k.shape, symmetric=False, sharding_info=sharding_info)
     forces_k = apply_gradient_laplace(delta_k, kvec)
     # Interpolate forces at the position of particles
-    return jnp.stack([cic_read(ifft3d(forces_k[..., i], comms=comms).real,
-                               positions, halo_size=halo_size, comms=comms)
+    return jnp.stack([cic_read(ifft3d(forces_k[..., i], sharding_info=sharding_info).real,
+                               positions, halo_size=halo_size, sharding_info=sharding_info)
                       for i in range(3)], axis=-1)
-def lpt(cosmo, positions, initial_conditions, a, halo_size=0, comms=None):
+def lpt(cosmo, positions, initial_conditions, a, halo_size=0, sharding_info=None):
     """
     Computes first order LPT displacement
     """
     initial_force = pm_forces(
-        positions, delta_k=initial_conditions, halo_size=halo_size, comms=comms)
+        positions, delta_k=initial_conditions, halo_size=halo_size, sharding_info=sharding_info)
     a = jnp.atleast_1d(a)
     dx = growth_factor(cosmo, a) * initial_force
     p = a**2 * growth_rate(cosmo, a) * \
@@ -45,21 +45,21 @@ def lpt(cosmo, positions, initial_conditions, a, halo_size=0, comms=None):
     return dx, p, f
-def linear_field(cosmo, mesh_shape, box_size, key, comms=None):
+def linear_field(cosmo, mesh_shape, box_size, key, sharding_info=None):
     """
     Generate initial conditions in Fourier space.
     """
     # Sample normal field
-    field = normal(key, mesh_shape, comms=comms)
+    field = normal(key, mesh_shape, sharding_info=sharding_info)
     # Transform to Fourier space
-    kfield = fft3d(field, comms=comms)
+    kfield = fft3d(field, sharding_info=sharding_info)
     # Rescaling k to physical units
     kvec = [k / box_size[i] * mesh_shape[i]
             for i, k in enumerate(fftk(kfield.shape,
                                        symmetric=False,
-                                       comms=comms))]
+                                       sharding_info=sharding_info))]
     # Evaluating linear matter powerspectrum
     k = jnp.logspace(-4, 2, 256)
@@ -77,7 +77,7 @@ def linear_field(cosmo, mesh_shape, box_size, key, comms=None):
     return kfield
-def make_ode_fn(mesh_shape, halo_size=0, comms=None):
+def make_ode_fn(mesh_shape, halo_size=0, sharding_info=None):
     def nbody_ode(state, a, cosmo):
         """
@@ -86,7 +86,7 @@ def make_ode_fn(mesh_shape, halo_size=0, comms=None):
         pos, vel = state
         forces = pm_forces(pos, mesh_shape=mesh_shape,
-                           halo_size=halo_size, comms=comms) * 1.5 * cosmo.Omega_m
+                           halo_size=halo_size, sharding_info=sharding_info) * 1.5 * cosmo.Omega_m
         # Computes the update of position (drift)
         dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
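
To close the loop, a hedged end-to-end sketch of how these sharding_info-aware entry points fit together on a single process (sharding_info=None everywhere), assuming the module above is importable as jaxpm.pm. A multi-process run would instead build a ShardingInfo per rank from the jaxDecomp/MPI setup and pass a nonzero halo_size, neither of which appears in this diff. The particle grid is built inline rather than through meshgrid3d, and the ODE integration is only indicated in a comment.

import jax
import jax.numpy as jnp
import jax_cosmo as jc

from jaxpm.pm import linear_field, lpt, make_ode_fn

mesh_shape = (64, 64, 64)
box_size = (200., 200., 200.)  # illustrative box size in Mpc/h
cosmo = jc.Planck15()
key = jax.random.PRNGKey(0)

# Linear initial conditions in Fourier space, single-process fallback.
kic = linear_field(cosmo, mesh_shape, box_size, key, sharding_info=None)

# Uniform grid of particle positions in mesh units (one particle per cell).
positions = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape],
                                   indexing='ij'), axis=-1).reshape([-1, 3])

# First-order LPT displacement, momentum and force at scale factor a = 0.1.
dx, p, f = lpt(cosmo, positions, kic, a=0.1, sharding_info=None)

# The right-hand side returned by make_ode_fn(mesh_shape, sharding_info=None)
# can then be integrated from (positions + dx, p), for instance with
# jax.experimental.ode.odeint, to evolve the particles to a = 1.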