Adding an example of jaxdecomp implementation

This commit is contained in:
EiffL 2022-11-26 17:27:14 +01:00
parent 6644b35d71
commit 6ca4c9191e
5 changed files with 166 additions and 192 deletions

View file

@ -3,19 +3,23 @@ from jax.experimental.maps import xmap
import numpy as np
import jax.numpy as jnp
from functools import partial
import jaxdecomp
def fftk(shape, symmetric=False, dtype=np.float32, comms=None):
def fftk(shape, symmetric=False, dtype=np.float32, sharding_info=None):
""" Return k_vector given a shape (nc, nc, nc)
"""
k = []
if comms is not None:
nx = comms[0].Get_size()
ix = comms[0].Get_rank()
ny = comms[1].Get_size()
iy = comms[1].Get_rank()
shape = [shape[0]*nx, shape[1]*ny] + list(shape[2:])
if sharding_info is not None:
nx = sharding_info.pdims[1]
ny = sharding_info.pdims[0]
# nx = sharding_info[0].Get_size()
# ix = sharding_info[0].Get_rank()
# ny = sharding_info[1].Get_size()
# iy = sharding_info[1].Get_rank()
ix = sharding_info.rank
iy = 0
shape = sharding_info.global_shape
for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d])
@ -24,10 +28,10 @@ def fftk(shape, symmetric=False, dtype=np.float32, comms=None):
if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1]
if (comms is not None) and d == 0:
if (sharding_info is not None) and d == 0:
kd = kd.reshape([nx, -1])[ix]
if (comms is not None) and d == 1:
if (sharding_info is not None) and d == 1:
kd = kd.reshape([ny, -1])[iy]
k.append(kd.astype(dtype))
@ -42,10 +46,9 @@ def apply_gradient_laplace(kfield, kvec):
kx, ky, kz = kvec
kk = (kx**2 + ky**2 + kz**2)
kernel = jnp.where(kk == 0, 1., 1./kk)
return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
kfield * kernel * 1j * 1 / 6.0 *
(8 * jnp.sin(kz) - jnp.sin(2 * kz)),
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))], axis=-1)
return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)),
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)),
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky))], axis=-1)
def cic_compensation(kvec):

View file

@ -2,155 +2,91 @@
import jax
import jax.numpy as jnp
import mpi4jax
import jaxdecomp
from dataclasses import dataclass
from typing import Tuple
@dataclass
class ShardingInfo:
    """Bundle of settings describing how a 3D mesh is distributed."""
    # Shape of the full, undistributed mesh.
    global_shape: Tuple[int, int, int]
    # Dimensions of the 2D process grid.
    pdims: Tuple[int, int]
    # Halo widths, one entry per mesh axis.
    halo_extents: Tuple[int, int, int]
    # Rank of the current process; 0 when not specified.
    rank: int = 0
def fft3d(arr, sharding_info=None):
    """Compute the forward 3D FFT of ``arr``.

    Without ``sharding_info`` the transform is done locally with
    ``jnp.fft.fftn`` and the axes of the result are permuted with
    ``transpose([1, 2, 0])``, so the output is transposed with respect
    to the input. With ``sharding_info`` the transform is delegated to
    ``jaxdecomp.pfft3d`` using the stored ``pdims`` and ``global_shape``.
    """
    if sharding_info is not None:
        # Distributed path.
        # NOTE(review): assumes pfft3d returns the same transposed layout
        # as the local path -- confirm against jaxdecomp.
        return jaxdecomp.pfft3d(arr,
                                pdims=sharding_info.pdims,
                                global_shape=sharding_info.global_shape)

    # Single-device path: full transform, then rotate the axes.
    spectrum = jnp.fft.fftn(arr)
    return spectrum.transpose([1, 2, 0])
def ifft3d(arr, sharding_info=None):
    """Compute the inverse 3D FFT of ``arr``.

    Without ``sharding_info`` the axes are first permuted with
    ``transpose([2, 0, 1])`` (the inverse of the permutation
    ``[1, 2, 0]``) and the result is passed to ``jnp.fft.ifftn``.
    With ``sharding_info`` the transform is delegated to
    ``jaxdecomp.pifft3d`` using the stored ``pdims`` and ``global_shape``.
    """
    if sharding_info is not None:
        # Distributed path.
        return jaxdecomp.pifft3d(arr,
                                 pdims=sharding_info.pdims,
                                 global_shape=sharding_info.global_shape)

    # Single-device path: restore the axis order, then invert.
    untransposed = arr.transpose([2, 0, 1])
    return jnp.fft.ifftn(untransposed)
def halo_reduce(arr, sharding_info=None):
    """Exchange halo regions of a padded mesh and fold them back in.

    Returns ``arr`` untouched when ``sharding_info`` is None. Otherwise
    the halo width is read from ``sharding_info.halo_extents[0]``, halos
    are exchanged through ``jaxdecomp.halo_exchange`` and the outermost
    ``halo_size // 2`` cells on each side of the first two axes are added
    onto the cells located ``halo_size`` further inside.
    """
    if sharding_info is None:
        return arr

    halo_size = sharding_info.halo_extents[0]
    half = halo_size // 2
    mesh_shape = sharding_info.global_shape

    # NOTE(review): the padded global shape grows by 2*halo_size along x but
    # only by halo_size along y -- presumably tied to the process grid in
    # use; confirm against the jaxdecomp conventions.
    padded_shape = (mesh_shape[0] + 2 * halo_size,
                    mesh_shape[1] + halo_size,
                    mesh_shape[2])
    arr = jaxdecomp.halo_exchange(arr,
                                  halo_extents=(half, half, 0),
                                  halo_periods=(True, True, True),
                                  pdims=sharding_info.pdims,
                                  global_shape=padded_shape)

    # Apply correction along x
    arr = arr.at[halo_size:halo_size + half].add(arr[:half])
    arr = arr.at[-halo_size - half:-halo_size].add(arr[-half:])

    # Apply correction along y
    arr = arr.at[:, halo_size:halo_size + half].add(arr[:, :half])
    arr = arr.at[:, -halo_size - half:-halo_size].add(arr[:, -half:])

    return arr
def meshgrid3d(shape, sharding_info=None):
    """Return the integer grid coordinates of a mesh as a [npoints, 3] array.

    shape: extents of the mesh along each of its three axes; only used
        when ``sharding_info`` is None.
    sharding_info: optional distribution settings. When given, the extents
        are taken from ``sharding_info.global_shape``, with the first axis
        divided by ``pdims[1]`` and the second by ``pdims[0]``.

    The coordinates are built with ``jnp.meshgrid`` (default indexing) and
    flattened so that each row holds one (x, y, z) triplet.
    """
    if sharding_info is not None:
        global_shape = sharding_info.global_shape
        pdims = sharding_info.pdims
        extents = [global_shape[0] // pdims[1],
                   global_shape[1] // pdims[0],
                   global_shape[2]]
    else:
        # Use every axis of the mesh: slicing with shape[2:] kept only the
        # last axis, which cannot be reshaped into rows of 3 coordinates.
        extents = list(shape)

    coords = [jnp.arange(n) for n in extents]
    return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
def zeros(shape, sharding_info=None):
    """ Create a zero-filled array for a mesh of the given global shape,
    partitioned across the first two dimensions if need be.

    With ``sharding_info`` the ``shape`` argument is not used: the local
    shape is derived from ``sharding_info.global_shape``, dividing the
    first axis by ``pdims[1]`` and the second by ``pdims[0]``.
    """
    if sharding_info is not None:
        mesh = sharding_info.global_shape
        grid = sharding_info.pdims
        local_shape = [mesh[0] // grid[1], mesh[1] // grid[0]] + list(mesh[2:])
        return jnp.zeros(local_shape)

    return jnp.zeros(shape)
def normal(key, shape, sharding_info=None):
    """ Draw standard normal samples for a mesh of the given global shape.

    With ``sharding_info`` the ``shape`` argument is not used: the local
    shape is derived from ``sharding_info.global_shape``, dividing the
    first axis by ``pdims[1]`` and the second by ``pdims[0]``.
    """
    if sharding_info is not None:
        mesh = sharding_info.global_shape
        grid = sharding_info.pdims
        local_shape = [mesh[0] // grid[1], mesh[1] // grid[0], mesh[2]]
        return jax.random.normal(key, local_shape)

    return jax.random.normal(key, shape)

View file

@ -6,12 +6,12 @@ from jaxpm.ops import halo_reduce
from jaxpm.kernels import fftk, cic_compensation
def cic_paint(mesh, positions, halo_size=0, comms=None):
def cic_paint(mesh, positions, halo_size=0, sharding_info=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
if comms is not None:
if sharding_info is not None:
# Add some padding for the halo exchange
mesh = jnp.pad(mesh, [[halo_size, halo_size],
[halo_size, halo_size],
@ -40,26 +40,32 @@ def cic_paint(mesh, positions, halo_size=0, comms=None):
kernel.reshape([-1, 8]),
dnums)
if comms == None:
if sharding_info == None:
return mesh
else:
mesh = halo_reduce(mesh, halo_size, comms)
mesh = halo_reduce(mesh, sharding_info)
return mesh[halo_size:-halo_size, halo_size:-halo_size]
def cic_read(mesh, positions, halo_size=0, comms=None):
def cic_read(mesh, positions, halo_size=0, sharding_info=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
if comms is not None:
if sharding_info is not None:
# Add some padding and perform halo exchange to retrieve
# neighboring regions
mesh = jnp.pad(mesh, [[halo_size, halo_size],
[halo_size, halo_size],
[0, 0]])
mesh = halo_reduce(mesh, halo_size, comms)
# mesh = halo_reduce(mesh, sharding_info)
import jaxdecomp
mesh = jaxdecomp.halo_exchange(mesh,
halo_extents=sharding_info.halo_extents,
halo_periods=(True,True,True),
pdims=sharding_info.pdims,
global_shape=sharding_info.global_shape)
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
positions = jnp.expand_dims(positions, 1)

View file

@ -10,32 +10,32 @@ from jaxpm.painting import cic_paint, cic_read
from jaxpm.growth import growth_factor, growth_rate, dGfa
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, sharding_info=None):
    """
    Computes gravitational forces on particles using a PM scheme

    When ``delta_k`` is not supplied, the particles are first painted onto
    an empty mesh of shape ``mesh_shape`` and transformed with ``fft3d``.
    Returns the three force components stacked along the last axis.
    """
    if delta_k is None:
        empty_mesh = zeros(mesh_shape, sharding_info=sharding_info)
        delta = cic_paint(empty_mesh,
                          positions,
                          halo_size=halo_size,
                          sharding_info=sharding_info)
        delta_k = fft3d(delta, sharding_info=sharding_info)

    # Force field in Fourier space, one component per entry of the last axis
    kvec = fftk(delta_k.shape, symmetric=False, sharding_info=sharding_info)
    forces_k = apply_gradient_laplace(delta_k, kvec)

    # Bring each component back to real space and read it at the particles
    components = []
    for axis in range(3):
        force_mesh = ifft3d(forces_k[..., axis], sharding_info=sharding_info).real
        components.append(cic_read(force_mesh,
                                   positions,
                                   halo_size=halo_size,
                                   sharding_info=sharding_info))
    return jnp.stack(components, axis=-1)
def lpt(cosmo, positions, initial_conditions, a, halo_size=0, comms=None):
def lpt(cosmo, positions, initial_conditions, a, halo_size=0, sharding_info=None):
"""
Computes first order LPT displacement
"""
initial_force = pm_forces(
positions, delta_k=initial_conditions, halo_size=halo_size, comms=comms)
positions, delta_k=initial_conditions, halo_size=halo_size, sharding_info=sharding_info)
a = jnp.atleast_1d(a)
dx = growth_factor(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * \
@ -45,21 +45,21 @@ def lpt(cosmo, positions, initial_conditions, a, halo_size=0, comms=None):
return dx, p, f
def linear_field(cosmo, mesh_shape, box_size, key, comms=None):
def linear_field(cosmo, mesh_shape, box_size, key, sharding_info=None):
"""
Generate initial conditions in Fourier space.
"""
# Sample normal field
field = normal(key, mesh_shape, comms=comms)
field = normal(key, mesh_shape, sharding_info=sharding_info)
# Transform to Fourier space
kfield = fft3d(field, comms=comms)
kfield = fft3d(field, sharding_info=sharding_info)
# Rescaling k to physical units
kvec = [k / box_size[i] * mesh_shape[i]
for i, k in enumerate(fftk(kfield.shape,
symmetric=False,
comms=comms))]
sharding_info=sharding_info))]
# Evaluating linear matter powerspectrum
k = jnp.logspace(-4, 2, 256)
@ -77,7 +77,7 @@ def linear_field(cosmo, mesh_shape, box_size, key, comms=None):
return kfield
def make_ode_fn(mesh_shape, halo_size=0, comms=None):
def make_ode_fn(mesh_shape, halo_size=0, sharding_info=None):
def nbody_ode(state, a, cosmo):
"""
@ -86,7 +86,7 @@ def make_ode_fn(mesh_shape, halo_size=0, comms=None):
pos, vel = state
forces = pm_forces(pos, mesh_shape=mesh_shape,
halo_size=halo_size, comms=comms) * 1.5 * cosmo.Omega_m
halo_size=halo_size, sharding_info=sharding_info) * 1.5 * cosmo.Omega_m
# Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel

View file

@ -1,15 +1,15 @@
from dataclasses import fields
from mpi4py import MPI
import os
import jax
import jax.numpy as jnp
import numpy as onp
import mpi4jax
from jaxpm.ops import fft3d, ifft3d, normal, meshgrid3d, zeros
import jaxdecomp
from jaxpm.ops import fft3d, ifft3d, normal, meshgrid3d, zeros, ShardingInfo
from jaxpm.pm import linear_field, lpt, make_ode_fn
from jaxpm.painting import cic_paint
from jax.experimental.ode import odeint
import jax_cosmo as jc
import time
### Setting up a whole bunch of things #######
# Create communicators
@ -17,10 +17,12 @@ world = MPI.COMM_WORLD
rank = world.Get_rank()
size = world.Get_size()
cart_comm = MPI.COMM_WORLD.Create_cart(dims=[2, 2],
periods=[True, True])
comms = [cart_comm.Sub([True, False]),
cart_comm.Sub([False, True])]
# Here we assume clients are on the same node, so we restrict which device
# they can use based on their rank
os.environ["CUDA_VISIBLE_DEVICES"] = "%d" % (rank + 1)
jaxdecomp.init()
# Setup random keys
master_key = jax.random.PRNGKey(42)
@ -29,50 +31,77 @@ key = jax.random.split(master_key, size)[rank]
# Size and parameters of the simulation volume
N = 256
mesh_shape = [N, N, N]
box_size = [205, 205, 205] # Mpc/h
mesh_shape = (N, N, N)
box_size = [500, 500, 500] # Mpc/h
halo_size = 32
sharding_info = ShardingInfo(global_shape=mesh_shape,
pdims=(1,2),
halo_extents=(halo_size, halo_size, 0),
rank=rank)
cosmo = jc.Planck15()
halo_size = 16
a = 0.1
@jax.jit
def run_sim(cosmo, key):
initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
comms=comms)
init_field = ifft3d(initial_conditions, comms=comms).real
sharding_info=sharding_info)
init_field = ifft3d(initial_conditions, sharding_info=sharding_info).real
# Initialize particles
pos = meshgrid3d(mesh_shape, comms=comms)
pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
# Initial displacement by LPT
cosmo = jc.Planck15()
dx, p, f = lpt(cosmo, pos, initial_conditions, a, comms=comms)
dx, p, f = lpt(cosmo, pos, initial_conditions, a, halo_size=halo_size, sharding_info=sharding_info)
# And now, we run an actual nbody
res = odeint(make_ode_fn(mesh_shape, halo_size, comms),
res = odeint(make_ode_fn(mesh_shape, halo_size, sharding_info),
[pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
rtol=1e-5, atol=1e-5)
rtol=1e-3, atol=1e-3)
# Painting on a new mesh
field = cic_paint(zeros(mesh_shape, comms=comms),
res[0][-1], halo_size, comms=comms)
field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
res[0][-1], halo_size, sharding_info=sharding_info)
# field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
# pos+dx, halo_size, sharding_info=sharding_info)
return init_field, field
# initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
# sharding_info=sharding_info)
# Recover the real space initial conditions
# init_field = ifft3d(initial_conditions, sharding_info=sharding_info).real
# print("hello", init_field.shape)
# cosmo = jc.Planck15()
# pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
# dx, p, f = lpt(cosmo, pos, initial_conditions, a, sharding_info=sharding_info)
# #dx = 3*jax.random.normal(key=key, shape=[1048576, 3])
# # Initialize particles
# # pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
# field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
# pos+dx, halo_size, sharding_info=sharding_info)
# # Recover the real space initial conditions
init_field, field = run_sim(cosmo, key)
# Testing that the result is actually looking like what we expect
total_array, token = mpi4jax.allgather(field, comm=comms[0])
total_array = total_array.reshape([N, N//2, N])
total_array, token = mpi4jax.allgather(
total_array.transpose([1, 0, 2]), comm=comms[1], token=token)
total_array = total_array.reshape([N, N, N])
total_array = total_array.transpose([1, 0, 2])
# import jaxdecomp
# field = jaxdecomp.halo_exchange(field,
# halo_extents=sharding_info.halo_extents,
# halo_periods=(True,True,True),
# pdims=sharding_info.pdims,
# global_shape=sharding_info.global_shape)
if rank == 0:
onp.save('simulation.npy', total_array)
# time1 = time.time()
# init_field, field = run_sim(cosmo, key)
# init_field.block_until_ready()
# time2 = time.time()
print('Done !')
# if rank == 0:
onp.save('simulation_%d.npy'%rank, field)
# print('Done in', time2-time1)