diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index ddd909a..1a9e38f 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -3,19 +3,23 @@ from jax.experimental.maps import xmap import numpy as np import jax.numpy as jnp from functools import partial +import jaxdecomp - -def fftk(shape, symmetric=False, dtype=np.float32, comms=None): +def fftk(shape, symmetric=False, dtype=np.float32, sharding_info=None): """ Return k_vector given a shape (nc, nc, nc) """ k = [] - if comms is not None: - nx = comms[0].Get_size() - ix = comms[0].Get_rank() - ny = comms[1].Get_size() - iy = comms[1].Get_rank() - shape = [shape[0]*nx, shape[1]*ny] + list(shape[2:]) + if sharding_info is not None: + nx = sharding_info.pdims[1] + ny = sharding_info.pdims[0] + # nx = sharding_info[0].Get_size() + # ix = sharding_info[0].Get_rank() + # ny = sharding_info[1].Get_size() + # iy = sharding_info[1].Get_rank() + ix = sharding_info.rank + iy = 0 + shape = sharding_info.global_shape for d in range(len(shape)): kd = np.fft.fftfreq(shape[d]) @@ -24,10 +28,10 @@ def fftk(shape, symmetric=False, dtype=np.float32, comms=None): if symmetric and d == len(shape) - 1: kd = kd[:shape[d] // 2 + 1] - if (comms is not None) and d == 0: + if (sharding_info is not None) and d == 0: kd = kd.reshape([nx, -1])[ix] - if (comms is not None) and d == 1: + if (sharding_info is not None) and d == 1: kd = kd.reshape([ny, -1])[iy] k.append(kd.astype(dtype)) @@ -42,10 +46,9 @@ def apply_gradient_laplace(kfield, kvec): kx, ky, kz = kvec kk = (kx**2 + ky**2 + kz**2) kernel = jnp.where(kk == 0, 1., 1./kk) - return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)), - kfield * kernel * 1j * 1 / 6.0 * - (8 * jnp.sin(kz) - jnp.sin(2 * kz)), - kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))], axis=-1) + return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)), + kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)), + kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky))], axis=-1) def cic_compensation(kvec): diff --git a/jaxpm/ops.py b/jaxpm/ops.py index 789ff69..ca78b82 100644 --- a/jaxpm/ops.py +++ b/jaxpm/ops.py @@ -2,155 +2,91 @@ import jax import jax.numpy as jnp import mpi4jax +import jaxdecomp +from dataclasses import dataclass +from typing import Tuple + +@dataclass +class ShardingInfo: + """Class for keeping track of the distribution strategy""" + global_shape: Tuple[int, int, int] + pdims: Tuple[int, int] + halo_extents: Tuple[int, int, int] + rank: int = 0 -def fft3d(arr, comms=None): +def fft3d(arr, sharding_info=None): """ Computes forward FFT, note that the output is transposed """ - if comms is not None: - shape = list(arr.shape) - nx = comms[0].Get_size() - ny = comms[1].Get_size() - - # First FFT along z - arr = jnp.fft.fft(arr) # [x, y, z] - # Perform single gpu or distributed transpose - if comms == None: - arr = arr.transpose([1, 2, 0]) + if sharding_info is None: + arr = jnp.fft.fftn(arr).transpose([1, 2, 0]) else: - arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx]) - #arr = arr.transpose([2, 1, 3, 0]) # [y, z, x] - arr = jnp.einsum('ij,xyjz->iyzx', jnp.eye(nx), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work - arr, token = mpi4jax.alltoall(arr, comm=comms[0]) - arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [y, z, x] - - # Second FFT along x - arr = jnp.fft.fft(arr) - # Perform single gpu or distributed transpose - if comms == None: - arr = arr.transpose([1, 2, 0]) - 
else: - arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny]) - #arr = arr.transpose([2, 1, 3, 0]) # [z, x, y] - arr = jnp.einsum('ij,yzjx->izxy', jnp.eye(ny), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work - arr, token = mpi4jax.alltoall(arr, comm=comms[1], token=token) - arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z, x, y] - - # Third FFT along y - return jnp.fft.fft(arr) - - -def ifft3d(arr, comms=None): - """ Let's assume that the data is distributed accross x - """ - if comms is not None: - shape = list(arr.shape) - nx = comms[0].Get_size() - ny = comms[1].Get_size() - - # First FFT along y - arr = jnp.fft.ifft(arr) # Now [z, x, y] - # Perform single gpu or distributed transpose - if comms == None: - arr = arr.transpose([0, 2, 1]) - else: - arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny]) - # arr = arr.transpose([2, 0, 3, 1]) # Now [z, y, x] - arr = jnp.einsum('ij,zxjy->izyx', jnp.eye(ny), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work - arr, token = mpi4jax.alltoall(arr, comm=comms[1]) - arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z,y,x] - - # Second FFT along x - arr = jnp.fft.ifft(arr) - # Perform single gpu or distributed transpose - if comms == None: - arr = arr.transpose([2, 1, 0]) - else: - arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx]) - # arr = arr.transpose([2, 3, 1, 0]) # now [x, y, z] - arr = jnp.einsum('ij,zyjx->ixyz', jnp.eye(nx), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work - arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token) - arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [x,y,z] - - # Third FFT along z - return jnp.fft.ifft(arr) - - -def halo_reduce(arr, halo_size, comms=None): - if halo_size <= 0: - return arr - - # Perform halo exchange along x - rank_x = comms[0].Get_rank() - size_x = comms[0].Get_size() - margin = arr[-2*halo_size:] - left, token = mpi4jax.sendrecv(margin, margin, - (rank_x-1) % size_x, - (rank_x+1) % size_x, - comm=comms[0]) - margin = arr[:2*halo_size] - right, token = mpi4jax.sendrecv(margin, margin, - (rank_x+1) % size_x, - (rank_x-1) % size_x, - comm=comms[0], token=token) - - arr = arr.at[:2*halo_size].add(left) - arr = arr.at[-2*halo_size:].add(right) - - # Perform halo exchange along y - rank_y = comms[1].Get_rank() - size_y = comms[1].Get_size() - margin = arr[:, -2*halo_size:] - left, token = mpi4jax.sendrecv(margin, margin, - (rank_y-1) % size_y, - (rank_y+1) % size_y, - comm=comms[1], token=token) - margin = arr[:, :2*halo_size] - right, token = mpi4jax.sendrecv(margin, margin, - (rank_y+1) % size_y, - (rank_y-1) % size_y, - comm=comms[1], token=token) - arr = arr.at[:, :2*halo_size].add(left) - arr = arr.at[:, -2*halo_size:].add(right) - + arr = jaxdecomp.pfft3d(arr, + pdims=sharding_info.pdims, + global_shape=sharding_info.global_shape) return arr -def meshgrid3d(shape, comms=None): - if comms is not None: - nx = comms[0].Get_size() - ny = comms[1].Get_size() +def ifft3d(arr, sharding_info=None): + if sharding_info is None: + arr = jnp.fft.ifftn(arr.transpose([2, 0, 1])) + else: + arr = jaxdecomp.pifft3d(arr, + pdims=sharding_info.pdims, + global_shape=sharding_info.global_shape) + return arr - coords = [jnp.arange(shape[0]//nx), - jnp.arange(shape[1]//ny)] + [jnp.arange(s) for s in shape[2:]] + +def halo_reduce(arr, sharding_info=None): + if sharding_info is None: + return arr + halo_size = sharding_info.halo_extents[0] + global_shape = 
sharding_info.global_shape + arr = jaxdecomp.halo_exchange(arr, + halo_extents=(halo_size//2, halo_size//2, 0), + halo_periods=(True,True,True), + pdims=sharding_info.pdims, + global_shape=(global_shape[0]+2*halo_size, + global_shape[1]+halo_size, + global_shape[2])) + + # Apply correction along x + arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2]) + arr = arr.at[-halo_size - halo_size//2:-halo_size].add(arr[-halo_size//2:]) + + # Apply correction along y + arr = arr.at[:, halo_size:halo_size + halo_size//2].add(arr[:, :halo_size//2][:, :]) + arr = arr.at[:, -halo_size - halo_size//2:-halo_size].add(arr[:, -halo_size//2:][:, :]) + + return arr + + +def meshgrid3d(shape, sharding_info=None): + if sharding_info is not None: + coords = [jnp.arange(sharding_info.global_shape[0]//sharding_info.pdims[1]), + jnp.arange(sharding_info.global_shape[1]//sharding_info.pdims[0]), jnp.arange(sharding_info.global_shape[2])] else: coords = [jnp.arange(s) for s in shape[2:]] return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3]) -def zeros(shape, comms=None): +def zeros(shape, sharding_info=None): """ Initialize an array of given global shape partitionned if need be accross dimensions. """ - if comms is None: + if sharding_info is None: return jnp.zeros(shape) - nx = comms[0].Get_size() - ny = comms[1].Get_size() - - return jnp.zeros([shape[0]//nx, shape[1]//ny]+list(shape[2:])) + return jnp.zeros([sharding_info.global_shape[0]//sharding_info.pdims[1], sharding_info.global_shape[1]//sharding_info.pdims[0]]+list(sharding_info.global_shape[2:])) -def normal(key, shape, comms=None): +def normal(key, shape, sharding_info=None): """ Generates a normal variable for the given global shape. """ - if comms is None: + if sharding_info is None: return jax.random.normal(key, shape) - nx = comms[0].Get_size() - ny = comms[1].Get_size() - return jax.random.normal(key, - [shape[0]//nx, shape[1]//ny]+list(shape[2:])) + [sharding_info.global_shape[0]//sharding_info.pdims[1], sharding_info.global_shape[1]//sharding_info.pdims[0], sharding_info.global_shape[2]]) diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 78a3bd9..71c3671 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -6,12 +6,12 @@ from jaxpm.ops import halo_reduce from jaxpm.kernels import fftk, cic_compensation -def cic_paint(mesh, positions, halo_size=0, comms=None): +def cic_paint(mesh, positions, halo_size=0, sharding_info=None): """ Paints positions onto mesh mesh: [nx, ny, nz] positions: [npart, 3] """ - if comms is not None: + if sharding_info is not None: # Add some padding for the halo exchange mesh = jnp.pad(mesh, [[halo_size, halo_size], [halo_size, halo_size], @@ -40,26 +40,32 @@ def cic_paint(mesh, positions, halo_size=0, comms=None): kernel.reshape([-1, 8]), dnums) - if comms == None: + if sharding_info == None: return mesh else: - mesh = halo_reduce(mesh, halo_size, comms) + mesh = halo_reduce(mesh, sharding_info) return mesh[halo_size:-halo_size, halo_size:-halo_size] -def cic_read(mesh, positions, halo_size=0, comms=None): +def cic_read(mesh, positions, halo_size=0, sharding_info=None): """ Paints positions onto mesh mesh: [nx, ny, nz] positions: [npart, 3] """ - if comms is not None: + if sharding_info is not None: # Add some padding and perfom hao exchange to retrieve # neighboring regions mesh = jnp.pad(mesh, [[halo_size, halo_size], [halo_size, halo_size], [0, 0]]) - mesh = halo_reduce(mesh, halo_size, comms) + # mesh = halo_reduce(mesh, sharding_info) + import jaxdecomp + mesh = 
jaxdecomp.halo_exchange(mesh, + halo_extents=sharding_info.halo_extents, + halo_periods=(True,True,True), + pdims=sharding_info.pdims, + global_shape=sharding_info.global_shape) positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]) positions = jnp.expand_dims(positions, 1) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 8ec910a..672f7b6 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -10,32 +10,32 @@ from jaxpm.painting import cic_paint, cic_read from jaxpm.growth import growth_factor, growth_rate, dGfa -def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, token=None, comms=None): +def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, sharding_info=None): """ Computes gravitational forces on particles using a PM scheme """ if delta_k is None: - delta = cic_paint(zeros(mesh_shape, comms=comms), + delta = cic_paint(zeros(mesh_shape, sharding_info=sharding_info), positions, - halo_size=halo_size, comms=comms) - delta_k = fft3d(delta, comms=comms) + halo_size=halo_size, sharding_info=sharding_info) + delta_k = fft3d(delta, sharding_info=sharding_info) # Computes gravitational forces - kvec = fftk(delta_k.shape, symmetric=False, comms=comms) + kvec = fftk(delta_k.shape, symmetric=False, sharding_info=sharding_info) forces_k = apply_gradient_laplace(delta_k, kvec) # Interpolate forces at the position of particles - return jnp.stack([cic_read(ifft3d(forces_k[..., i], comms=comms).real, - positions, halo_size=halo_size, comms=comms) + return jnp.stack([cic_read(ifft3d(forces_k[..., i], sharding_info=sharding_info).real, + positions, halo_size=halo_size, sharding_info=sharding_info) for i in range(3)], axis=-1) -def lpt(cosmo, positions, initial_conditions, a, halo_size=0, comms=None): +def lpt(cosmo, positions, initial_conditions, a, halo_size=0, sharding_info=None): """ Computes first order LPT displacement """ initial_force = pm_forces( - positions, delta_k=initial_conditions, halo_size=halo_size, comms=comms) + positions, delta_k=initial_conditions, halo_size=halo_size, sharding_info=sharding_info) a = jnp.atleast_1d(a) dx = growth_factor(cosmo, a) * initial_force p = a**2 * growth_rate(cosmo, a) * \ @@ -45,21 +45,21 @@ def lpt(cosmo, positions, initial_conditions, a, halo_size=0, comms=None): return dx, p, f -def linear_field(cosmo, mesh_shape, box_size, key, comms=None): +def linear_field(cosmo, mesh_shape, box_size, key, sharding_info=None): """ Generate initial conditions in Fourier space. """ # Sample normal field - field = normal(key, mesh_shape, comms=comms) + field = normal(key, mesh_shape, sharding_info=sharding_info) # Transform to Fourier space - kfield = fft3d(field, comms=comms) + kfield = fft3d(field, sharding_info=sharding_info) # Rescaling k to physical units kvec = [k / box_size[i] * mesh_shape[i] for i, k in enumerate(fftk(kfield.shape, symmetric=False, - comms=comms))] + sharding_info=sharding_info))] # Evaluating linear matter powerspectrum k = jnp.logspace(-4, 2, 256) @@ -77,7 +77,7 @@ def linear_field(cosmo, mesh_shape, box_size, key, comms=None): return kfield -def make_ode_fn(mesh_shape, halo_size=0, comms=None): +def make_ode_fn(mesh_shape, halo_size=0, sharding_info=None): def nbody_ode(state, a, cosmo): """ @@ -86,7 +86,7 @@ def make_ode_fn(mesh_shape, halo_size=0, comms=None): pos, vel = state forces = pm_forces(pos, mesh_shape=mesh_shape, - halo_size=halo_size, comms=comms) * 1.5 * cosmo.Omega_m + halo_size=halo_size, sharding_info=sharding_info) * 1.5 * cosmo.Omega_m # Computes the update of position (drift) dpos = 1. 
/ (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel diff --git a/scripts/test_nbody.py b/scripts/test_nbody.py index b9860f2..8bf8092 100644 --- a/scripts/test_nbody.py +++ b/scripts/test_nbody.py @@ -1,15 +1,15 @@ -from dataclasses import fields from mpi4py import MPI +import os import jax import jax.numpy as jnp import numpy as onp -import mpi4jax -from jaxpm.ops import fft3d, ifft3d, normal, meshgrid3d, zeros +import jaxdecomp +from jaxpm.ops import fft3d, ifft3d, normal, meshgrid3d, zeros, ShardingInfo from jaxpm.pm import linear_field, lpt, make_ode_fn from jaxpm.painting import cic_paint from jax.experimental.ode import odeint import jax_cosmo as jc - +import time ### Setting up a whole bunch of things ####### # Create communicators @@ -17,10 +17,12 @@ world = MPI.COMM_WORLD rank = world.Get_rank() size = world.Get_size() -cart_comm = MPI.COMM_WORLD.Create_cart(dims=[2, 2], - periods=[True, True]) -comms = [cart_comm.Sub([True, False]), - cart_comm.Sub([False, True])] +# Here we assume clients are on the same node, so we restrict which device +# they can use based on their rank +os.environ["CUDA_VISIBLE_DEVICES"] = "%d" % (rank + 1) + + +jaxdecomp.init() # Setup random keys master_key = jax.random.PRNGKey(42) @@ -29,50 +31,77 @@ key = jax.random.split(master_key, size)[rank] # Size and parameters of the simulation volume N = 256 -mesh_shape = [N, N, N] -box_size = [205, 205, 205] # Mpc/h +mesh_shape = (N, N, N) +box_size = [500, 500, 500] # Mpc/h +halo_size = 32 +sharding_info = ShardingInfo(global_shape=mesh_shape, + pdims=(1,2), + halo_extents=(halo_size, halo_size, 0), + rank=rank) cosmo = jc.Planck15() -halo_size = 16 a = 0.1 @jax.jit def run_sim(cosmo, key): initial_conditions = linear_field(cosmo, mesh_shape, box_size, key, - comms=comms) - init_field = ifft3d(initial_conditions, comms=comms).real + sharding_info=sharding_info) + + init_field = ifft3d(initial_conditions, sharding_info=sharding_info).real # Initialize particles - pos = meshgrid3d(mesh_shape, comms=comms) + pos = meshgrid3d(mesh_shape, sharding_info=sharding_info) # Initial displacement by LPT cosmo = jc.Planck15() - dx, p, f = lpt(cosmo, pos, initial_conditions, a, comms=comms) + dx, p, f = lpt(cosmo, pos, initial_conditions, a, halo_size=halo_size, sharding_info=sharding_info) # And now, we run an actual nbody - res = odeint(make_ode_fn(mesh_shape, halo_size, comms), + res = odeint(make_ode_fn(mesh_shape, halo_size, sharding_info), [pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo, - rtol=1e-5, atol=1e-5) - + rtol=1e-3, atol=1e-3) # Painting on a new mesh - field = cic_paint(zeros(mesh_shape, comms=comms), - res[0][-1], halo_size, comms=comms) - + field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info), + res[0][-1], halo_size, sharding_info=sharding_info) + + # field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info), + # pos+dx, halo_size, sharding_info=sharding_info) return init_field, field +# initial_conditions = linear_field(cosmo, mesh_shape, box_size, key, +# sharding_info=sharding_info) -# Recover the real space initial conditions +# init_field = ifft3d(initial_conditions, sharding_info=sharding_info).real + +# print("hello", init_field.shape) + +# cosmo = jc.Planck15() +# pos = meshgrid3d(mesh_shape, sharding_info=sharding_info) +# dx, p, f = lpt(cosmo, pos, initial_conditions, a, sharding_info=sharding_info) + +# #dx = 3*jax.random.normal(key=key, shape=[1048576, 3]) +# # Initialize particles +# # pos = meshgrid3d(mesh_shape, sharding_info=sharding_info) + +# field = 
cic_paint(zeros(mesh_shape, sharding_info=sharding_info), +# pos+dx, halo_size, sharding_info=sharding_info) + +# # Recover the real space initial conditions init_field, field = run_sim(cosmo, key) -# Testing that the result is actually looking like what we expect -total_array, token = mpi4jax.allgather(field, comm=comms[0]) -total_array = total_array.reshape([N, N//2, N]) -total_array, token = mpi4jax.allgather( - total_array.transpose([1, 0, 2]), comm=comms[1], token=token) -total_array = total_array.reshape([N, N, N]) -total_array = total_array.transpose([1, 0, 2]) +# import jaxdecomp +# field = jaxdecomp.halo_exchange(field, +# halo_extents=sharding_info.halo_extents, +# halo_periods=(True,True,True), +# pdims=sharding_info.pdims, +# global_shape=sharding_info.global_shape) -if rank == 0: - onp.save('simulation.npy', total_array) +# time1 = time.time() +# init_field, field = run_sim(cosmo, key) +# init_field.block_until_ready() +# time2 = time.time() -print('Done !') +# if rank == 0: +onp.save('simulation_%d.npy'%rank, field) + +# print('Done in', time2-time1)
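
Note: with sharding_info=None both fft3d and ifft3d fall back to single-GPU jnp.fft calls, but they keep the transposed axis convention of the distributed path: fft3d applies transpose([1, 2, 0]) after the forward FFT and ifft3d applies the inverse permutation [2, 0, 1] before the inverse FFT (the kz/kx/ky reordering in apply_gradient_laplace appears to follow the same convention). A minimal single-process round-trip sketch, assuming jaxpm and its jaxdecomp/mpi4jax dependencies are importable:

import jax
import jax.numpy as jnp
from jaxpm.ops import fft3d, ifft3d

x = jax.random.normal(jax.random.PRNGKey(0), (8, 8, 8))
kfield = fft3d(x)             # single-process branch: jnp.fft.fftn(x).transpose([1, 2, 0])
x_back = ifft3d(kfield).real  # inverse permutation [2, 0, 1], then jnp.fft.ifftn
assert jnp.allclose(x, x_back, atol=1e-5)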
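
The cic_paint / halo_reduce pattern is: pad the local mesh by halo_size along x and y, paint (so mass deposited near the edges lands in the padding), let jaxdecomp.halo_exchange accumulate the neighbouring ranks' contributions, then crop the padding away. A single-process periodic analogue of that accumulate-and-crop step, using a hypothetical helper name that is not part of jaxpm:

import jax.numpy as jnp

def fold_halo_x(padded, h):
    # After painting onto a mesh padded by h cells along axis 0, wrap the
    # margins back into the periodic interior, then drop the padding.
    padded = padded.at[h:2 * h].add(padded[-h:])    # right margin -> left edge of interior
    padded = padded.at[-2 * h:-h].add(padded[:h])   # left margin -> right edge of interior
    return padded[h:-h]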
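
For reference, the (8 sin(k) - sin(2k)) / 6 factor in apply_gradient_laplace is the fourth-order finite-difference approximation of the spectral gradient: its Taylor expansion is k - k**5 / 30, so it matches k up to O(k**5). A quick numeric check:

import numpy as np

k = np.array([0.01, 0.05, 0.1, 0.2])
approx = (8 * np.sin(k) - np.sin(2 * k)) / 6.0
print(np.max(np.abs(approx - k)))   # ~ k**5 / 30, i.e. about 1e-5 at k = 0.2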