Adding an example of jaxdecomp implementation

This commit is contained in:
EiffL 2022-11-26 17:27:14 +01:00
parent 6644b35d71
commit 6ca4c9191e
5 changed files with 166 additions and 192 deletions

View file

@ -3,19 +3,23 @@ from jax.experimental.maps import xmap
import numpy as np import numpy as np
import jax.numpy as jnp import jax.numpy as jnp
from functools import partial from functools import partial
import jaxdecomp
def fftk(shape, symmetric=False, dtype=np.float32, sharding_info=None):
def fftk(shape, symmetric=False, dtype=np.float32, comms=None):
""" Return k_vector given a shape (nc, nc, nc) """ Return k_vector given a shape (nc, nc, nc)
""" """
k = [] k = []
if comms is not None: if sharding_info is not None:
nx = comms[0].Get_size() nx = sharding_info.pdims[1]
ix = comms[0].Get_rank() ny = sharding_info.pdims[0]
ny = comms[1].Get_size() # nx = sharding_info[0].Get_size()
iy = comms[1].Get_rank() # ix = sharding_info[0].Get_rank()
shape = [shape[0]*nx, shape[1]*ny] + list(shape[2:]) # ny = sharding_info[1].Get_size()
# iy = sharding_info[1].Get_rank()
ix = sharding_info.rank
iy = 0
shape = sharding_info.global_shape
for d in range(len(shape)): for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d]) kd = np.fft.fftfreq(shape[d])
@ -24,10 +28,10 @@ def fftk(shape, symmetric=False, dtype=np.float32, comms=None):
if symmetric and d == len(shape) - 1: if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1] kd = kd[:shape[d] // 2 + 1]
if (comms is not None) and d == 0: if (sharding_info is not None) and d == 0:
kd = kd.reshape([nx, -1])[ix] kd = kd.reshape([nx, -1])[ix]
if (comms is not None) and d == 1: if (sharding_info is not None) and d == 1:
kd = kd.reshape([ny, -1])[iy] kd = kd.reshape([ny, -1])[iy]
k.append(kd.astype(dtype)) k.append(kd.astype(dtype))
@ -42,10 +46,9 @@ def apply_gradient_laplace(kfield, kvec):
kx, ky, kz = kvec kx, ky, kz = kvec
kk = (kx**2 + ky**2 + kz**2) kk = (kx**2 + ky**2 + kz**2)
kernel = jnp.where(kk == 0, 1., 1./kk) kernel = jnp.where(kk == 0, 1., 1./kk)
return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)), return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)),
kfield * kernel * 1j * 1 / 6.0 * kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)),
(8 * jnp.sin(kz) - jnp.sin(2 * kz)), kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky))], axis=-1)
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))], axis=-1)
def cic_compensation(kvec): def cic_compensation(kvec):

View file

@ -2,155 +2,91 @@
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import mpi4jax import mpi4jax
import jaxdecomp
from dataclasses import dataclass
from typing import Tuple
@dataclass
class ShardingInfo:
"""Class for keeping track of the distribution strategy"""
global_shape: Tuple[int, int, int]
pdims: Tuple[int, int]
halo_extents: Tuple[int, int, int]
rank: int = 0
def fft3d(arr, comms=None): def fft3d(arr, sharding_info=None):
""" Computes forward FFT, note that the output is transposed """ Computes forward FFT, note that the output is transposed
""" """
if comms is not None: if sharding_info is None:
shape = list(arr.shape) arr = jnp.fft.fftn(arr).transpose([1, 2, 0])
nx = comms[0].Get_size()
ny = comms[1].Get_size()
# First FFT along z
arr = jnp.fft.fft(arr) # [x, y, z]
# Perform single gpu or distributed transpose
if comms == None:
arr = arr.transpose([1, 2, 0])
else: else:
arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx]) arr = jaxdecomp.pfft3d(arr,
#arr = arr.transpose([2, 1, 3, 0]) # [y, z, x] pdims=sharding_info.pdims,
arr = jnp.einsum('ij,xyjz->iyzx', jnp.eye(nx), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work global_shape=sharding_info.global_shape)
arr, token = mpi4jax.alltoall(arr, comm=comms[0])
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [y, z, x]
# Second FFT along x
arr = jnp.fft.fft(arr)
# Perform single gpu or distributed transpose
if comms == None:
arr = arr.transpose([1, 2, 0])
else:
arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
#arr = arr.transpose([2, 1, 3, 0]) # [z, x, y]
arr = jnp.einsum('ij,yzjx->izxy', jnp.eye(ny), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work
arr, token = mpi4jax.alltoall(arr, comm=comms[1], token=token)
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z, x, y]
# Third FFT along y
return jnp.fft.fft(arr)
def ifft3d(arr, comms=None):
""" Let's assume that the data is distributed accross x
"""
if comms is not None:
shape = list(arr.shape)
nx = comms[0].Get_size()
ny = comms[1].Get_size()
# First FFT along y
arr = jnp.fft.ifft(arr) # Now [z, x, y]
# Perform single gpu or distributed transpose
if comms == None:
arr = arr.transpose([0, 2, 1])
else:
arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
# arr = arr.transpose([2, 0, 3, 1]) # Now [z, y, x]
arr = jnp.einsum('ij,zxjy->izyx', jnp.eye(ny), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work
arr, token = mpi4jax.alltoall(arr, comm=comms[1])
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z,y,x]
# Second FFT along x
arr = jnp.fft.ifft(arr)
# Perform single gpu or distributed transpose
if comms == None:
arr = arr.transpose([2, 1, 0])
else:
arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx])
# arr = arr.transpose([2, 3, 1, 0]) # now [x, y, z]
arr = jnp.einsum('ij,zyjx->ixyz', jnp.eye(nx), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work
arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token)
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [x,y,z]
# Third FFT along z
return jnp.fft.ifft(arr)
def halo_reduce(arr, halo_size, comms=None):
if halo_size <= 0:
return arr return arr
# Perform halo exchange along x
rank_x = comms[0].Get_rank()
size_x = comms[0].Get_size()
margin = arr[-2*halo_size:]
left, token = mpi4jax.sendrecv(margin, margin,
(rank_x-1) % size_x,
(rank_x+1) % size_x,
comm=comms[0])
margin = arr[:2*halo_size]
right, token = mpi4jax.sendrecv(margin, margin,
(rank_x+1) % size_x,
(rank_x-1) % size_x,
comm=comms[0], token=token)
arr = arr.at[:2*halo_size].add(left) def ifft3d(arr, sharding_info=None):
arr = arr.at[-2*halo_size:].add(right) if sharding_info is None:
arr = jnp.fft.ifftn(arr.transpose([2, 0, 1]))
else:
arr = jaxdecomp.pifft3d(arr,
pdims=sharding_info.pdims,
global_shape=sharding_info.global_shape)
return arr
# Perform halo exchange along y
rank_y = comms[1].Get_rank() def halo_reduce(arr, sharding_info=None):
size_y = comms[1].Get_size() if sharding_info is None:
margin = arr[:, -2*halo_size:] return arr
left, token = mpi4jax.sendrecv(margin, margin, halo_size = sharding_info.halo_extents[0]
(rank_y-1) % size_y, global_shape = sharding_info.global_shape
(rank_y+1) % size_y, arr = jaxdecomp.halo_exchange(arr,
comm=comms[1], token=token) halo_extents=(halo_size//2, halo_size//2, 0),
margin = arr[:, :2*halo_size] halo_periods=(True,True,True),
right, token = mpi4jax.sendrecv(margin, margin, pdims=sharding_info.pdims,
(rank_y+1) % size_y, global_shape=(global_shape[0]+2*halo_size,
(rank_y-1) % size_y, global_shape[1]+halo_size,
comm=comms[1], token=token) global_shape[2]))
arr = arr.at[:, :2*halo_size].add(left)
arr = arr.at[:, -2*halo_size:].add(right) # Apply correction along x
arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2])
arr = arr.at[-halo_size - halo_size//2:-halo_size].add(arr[-halo_size//2:])
# Apply correction along y
arr = arr.at[:, halo_size:halo_size + halo_size//2].add(arr[:, :halo_size//2][:, :])
arr = arr.at[:, -halo_size - halo_size//2:-halo_size].add(arr[:, -halo_size//2:][:, :])
return arr return arr
def meshgrid3d(shape, comms=None): def meshgrid3d(shape, sharding_info=None):
if comms is not None: if sharding_info is not None:
nx = comms[0].Get_size() coords = [jnp.arange(sharding_info.global_shape[0]//sharding_info.pdims[1]),
ny = comms[1].Get_size() jnp.arange(sharding_info.global_shape[1]//sharding_info.pdims[0]), jnp.arange(sharding_info.global_shape[2])]
coords = [jnp.arange(shape[0]//nx),
jnp.arange(shape[1]//ny)] + [jnp.arange(s) for s in shape[2:]]
else: else:
coords = [jnp.arange(s) for s in shape[2:]] coords = [jnp.arange(s) for s in shape[2:]]
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3]) return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
def zeros(shape, comms=None): def zeros(shape, sharding_info=None):
""" Initialize an array of given global shape """ Initialize an array of given global shape
partitionned if need be accross dimensions. partitionned if need be accross dimensions.
""" """
if comms is None: if sharding_info is None:
return jnp.zeros(shape) return jnp.zeros(shape)
nx = comms[0].Get_size() return jnp.zeros([sharding_info.global_shape[0]//sharding_info.pdims[1], sharding_info.global_shape[1]//sharding_info.pdims[0]]+list(sharding_info.global_shape[2:]))
ny = comms[1].Get_size()
return jnp.zeros([shape[0]//nx, shape[1]//ny]+list(shape[2:]))
def normal(key, shape, comms=None): def normal(key, shape, sharding_info=None):
""" Generates a normal variable for the given """ Generates a normal variable for the given
global shape. global shape.
""" """
if comms is None: if sharding_info is None:
return jax.random.normal(key, shape) return jax.random.normal(key, shape)
nx = comms[0].Get_size()
ny = comms[1].Get_size()
return jax.random.normal(key, return jax.random.normal(key,
[shape[0]//nx, shape[1]//ny]+list(shape[2:])) [sharding_info.global_shape[0]//sharding_info.pdims[1], sharding_info.global_shape[1]//sharding_info.pdims[0], sharding_info.global_shape[2]])

View file

@ -6,12 +6,12 @@ from jaxpm.ops import halo_reduce
from jaxpm.kernels import fftk, cic_compensation from jaxpm.kernels import fftk, cic_compensation
def cic_paint(mesh, positions, halo_size=0, comms=None): def cic_paint(mesh, positions, halo_size=0, sharding_info=None):
""" Paints positions onto mesh """ Paints positions onto mesh
mesh: [nx, ny, nz] mesh: [nx, ny, nz]
positions: [npart, 3] positions: [npart, 3]
""" """
if comms is not None: if sharding_info is not None:
# Add some padding for the halo exchange # Add some padding for the halo exchange
mesh = jnp.pad(mesh, [[halo_size, halo_size], mesh = jnp.pad(mesh, [[halo_size, halo_size],
[halo_size, halo_size], [halo_size, halo_size],
@ -40,26 +40,32 @@ def cic_paint(mesh, positions, halo_size=0, comms=None):
kernel.reshape([-1, 8]), kernel.reshape([-1, 8]),
dnums) dnums)
if comms == None: if sharding_info == None:
return mesh return mesh
else: else:
mesh = halo_reduce(mesh, halo_size, comms) mesh = halo_reduce(mesh, sharding_info)
return mesh[halo_size:-halo_size, halo_size:-halo_size] return mesh[halo_size:-halo_size, halo_size:-halo_size]
def cic_read(mesh, positions, halo_size=0, comms=None): def cic_read(mesh, positions, halo_size=0, sharding_info=None):
""" Paints positions onto mesh """ Paints positions onto mesh
mesh: [nx, ny, nz] mesh: [nx, ny, nz]
positions: [npart, 3] positions: [npart, 3]
""" """
if comms is not None: if sharding_info is not None:
# Add some padding and perfom hao exchange to retrieve # Add some padding and perfom hao exchange to retrieve
# neighboring regions # neighboring regions
mesh = jnp.pad(mesh, [[halo_size, halo_size], mesh = jnp.pad(mesh, [[halo_size, halo_size],
[halo_size, halo_size], [halo_size, halo_size],
[0, 0]]) [0, 0]])
mesh = halo_reduce(mesh, halo_size, comms) # mesh = halo_reduce(mesh, sharding_info)
import jaxdecomp
mesh = jaxdecomp.halo_exchange(mesh,
halo_extents=sharding_info.halo_extents,
halo_periods=(True,True,True),
pdims=sharding_info.pdims,
global_shape=sharding_info.global_shape)
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]) positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
positions = jnp.expand_dims(positions, 1) positions = jnp.expand_dims(positions, 1)

View file

@ -10,32 +10,32 @@ from jaxpm.painting import cic_paint, cic_read
from jaxpm.growth import growth_factor, growth_rate, dGfa from jaxpm.growth import growth_factor, growth_rate, dGfa
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, token=None, comms=None): def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, sharding_info=None):
""" """
Computes gravitational forces on particles using a PM scheme Computes gravitational forces on particles using a PM scheme
""" """
if delta_k is None: if delta_k is None:
delta = cic_paint(zeros(mesh_shape, comms=comms), delta = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
positions, positions,
halo_size=halo_size, comms=comms) halo_size=halo_size, sharding_info=sharding_info)
delta_k = fft3d(delta, comms=comms) delta_k = fft3d(delta, sharding_info=sharding_info)
# Computes gravitational forces # Computes gravitational forces
kvec = fftk(delta_k.shape, symmetric=False, comms=comms) kvec = fftk(delta_k.shape, symmetric=False, sharding_info=sharding_info)
forces_k = apply_gradient_laplace(delta_k, kvec) forces_k = apply_gradient_laplace(delta_k, kvec)
# Interpolate forces at the position of particles # Interpolate forces at the position of particles
return jnp.stack([cic_read(ifft3d(forces_k[..., i], comms=comms).real, return jnp.stack([cic_read(ifft3d(forces_k[..., i], sharding_info=sharding_info).real,
positions, halo_size=halo_size, comms=comms) positions, halo_size=halo_size, sharding_info=sharding_info)
for i in range(3)], axis=-1) for i in range(3)], axis=-1)
def lpt(cosmo, positions, initial_conditions, a, halo_size=0, comms=None): def lpt(cosmo, positions, initial_conditions, a, halo_size=0, sharding_info=None):
""" """
Computes first order LPT displacement Computes first order LPT displacement
""" """
initial_force = pm_forces( initial_force = pm_forces(
positions, delta_k=initial_conditions, halo_size=halo_size, comms=comms) positions, delta_k=initial_conditions, halo_size=halo_size, sharding_info=sharding_info)
a = jnp.atleast_1d(a) a = jnp.atleast_1d(a)
dx = growth_factor(cosmo, a) * initial_force dx = growth_factor(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * \ p = a**2 * growth_rate(cosmo, a) * \
@ -45,21 +45,21 @@ def lpt(cosmo, positions, initial_conditions, a, halo_size=0, comms=None):
return dx, p, f return dx, p, f
def linear_field(cosmo, mesh_shape, box_size, key, comms=None): def linear_field(cosmo, mesh_shape, box_size, key, sharding_info=None):
""" """
Generate initial conditions in Fourier space. Generate initial conditions in Fourier space.
""" """
# Sample normal field # Sample normal field
field = normal(key, mesh_shape, comms=comms) field = normal(key, mesh_shape, sharding_info=sharding_info)
# Transform to Fourier space # Transform to Fourier space
kfield = fft3d(field, comms=comms) kfield = fft3d(field, sharding_info=sharding_info)
# Rescaling k to physical units # Rescaling k to physical units
kvec = [k / box_size[i] * mesh_shape[i] kvec = [k / box_size[i] * mesh_shape[i]
for i, k in enumerate(fftk(kfield.shape, for i, k in enumerate(fftk(kfield.shape,
symmetric=False, symmetric=False,
comms=comms))] sharding_info=sharding_info))]
# Evaluating linear matter powerspectrum # Evaluating linear matter powerspectrum
k = jnp.logspace(-4, 2, 256) k = jnp.logspace(-4, 2, 256)
@ -77,7 +77,7 @@ def linear_field(cosmo, mesh_shape, box_size, key, comms=None):
return kfield return kfield
def make_ode_fn(mesh_shape, halo_size=0, comms=None): def make_ode_fn(mesh_shape, halo_size=0, sharding_info=None):
def nbody_ode(state, a, cosmo): def nbody_ode(state, a, cosmo):
""" """
@ -86,7 +86,7 @@ def make_ode_fn(mesh_shape, halo_size=0, comms=None):
pos, vel = state pos, vel = state
forces = pm_forces(pos, mesh_shape=mesh_shape, forces = pm_forces(pos, mesh_shape=mesh_shape,
halo_size=halo_size, comms=comms) * 1.5 * cosmo.Omega_m halo_size=halo_size, sharding_info=sharding_info) * 1.5 * cosmo.Omega_m
# Computes the update of position (drift) # Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel

View file

@ -1,15 +1,15 @@
from dataclasses import fields
from mpi4py import MPI from mpi4py import MPI
import os
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as onp import numpy as onp
import mpi4jax import jaxdecomp
from jaxpm.ops import fft3d, ifft3d, normal, meshgrid3d, zeros from jaxpm.ops import fft3d, ifft3d, normal, meshgrid3d, zeros, ShardingInfo
from jaxpm.pm import linear_field, lpt, make_ode_fn from jaxpm.pm import linear_field, lpt, make_ode_fn
from jaxpm.painting import cic_paint from jaxpm.painting import cic_paint
from jax.experimental.ode import odeint from jax.experimental.ode import odeint
import jax_cosmo as jc import jax_cosmo as jc
import time
### Setting up a whole bunch of things ####### ### Setting up a whole bunch of things #######
# Create communicators # Create communicators
@ -17,10 +17,12 @@ world = MPI.COMM_WORLD
rank = world.Get_rank() rank = world.Get_rank()
size = world.Get_size() size = world.Get_size()
cart_comm = MPI.COMM_WORLD.Create_cart(dims=[2, 2], # Here we assume clients are on the same node, so we restrict which device
periods=[True, True]) # they can use based on their rank
comms = [cart_comm.Sub([True, False]), os.environ["CUDA_VISIBLE_DEVICES"] = "%d" % (rank + 1)
cart_comm.Sub([False, True])]
jaxdecomp.init()
# Setup random keys # Setup random keys
master_key = jax.random.PRNGKey(42) master_key = jax.random.PRNGKey(42)
@ -29,50 +31,77 @@ key = jax.random.split(master_key, size)[rank]
# Size and parameters of the simulation volume # Size and parameters of the simulation volume
N = 256 N = 256
mesh_shape = [N, N, N] mesh_shape = (N, N, N)
box_size = [205, 205, 205] # Mpc/h box_size = [500, 500, 500] # Mpc/h
halo_size = 32
sharding_info = ShardingInfo(global_shape=mesh_shape,
pdims=(1,2),
halo_extents=(halo_size, halo_size, 0),
rank=rank)
cosmo = jc.Planck15() cosmo = jc.Planck15()
halo_size = 16
a = 0.1 a = 0.1
@jax.jit @jax.jit
def run_sim(cosmo, key): def run_sim(cosmo, key):
initial_conditions = linear_field(cosmo, mesh_shape, box_size, key, initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
comms=comms) sharding_info=sharding_info)
init_field = ifft3d(initial_conditions, comms=comms).real
init_field = ifft3d(initial_conditions, sharding_info=sharding_info).real
# Initialize particles # Initialize particles
pos = meshgrid3d(mesh_shape, comms=comms) pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
# Initial displacement by LPT # Initial displacement by LPT
cosmo = jc.Planck15() cosmo = jc.Planck15()
dx, p, f = lpt(cosmo, pos, initial_conditions, a, comms=comms) dx, p, f = lpt(cosmo, pos, initial_conditions, a, halo_size=halo_size, sharding_info=sharding_info)
# And now, we run an actual nbody # And now, we run an actual nbody
res = odeint(make_ode_fn(mesh_shape, halo_size, comms), res = odeint(make_ode_fn(mesh_shape, halo_size, sharding_info),
[pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo, [pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
rtol=1e-5, atol=1e-5) rtol=1e-3, atol=1e-3)
# Painting on a new mesh # Painting on a new mesh
field = cic_paint(zeros(mesh_shape, comms=comms), field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
res[0][-1], halo_size, comms=comms) res[0][-1], halo_size, sharding_info=sharding_info)
# field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
# pos+dx, halo_size, sharding_info=sharding_info)
return init_field, field return init_field, field
# initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
# sharding_info=sharding_info)
# Recover the real space initial conditions # init_field = ifft3d(initial_conditions, sharding_info=sharding_info).real
# print("hello", init_field.shape)
# cosmo = jc.Planck15()
# pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
# dx, p, f = lpt(cosmo, pos, initial_conditions, a, sharding_info=sharding_info)
# #dx = 3*jax.random.normal(key=key, shape=[1048576, 3])
# # Initialize particles
# # pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
# field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
# pos+dx, halo_size, sharding_info=sharding_info)
# # Recover the real space initial conditions
init_field, field = run_sim(cosmo, key) init_field, field = run_sim(cosmo, key)
# Testing that the result is actually looking like what we expect # import jaxdecomp
total_array, token = mpi4jax.allgather(field, comm=comms[0]) # field = jaxdecomp.halo_exchange(field,
total_array = total_array.reshape([N, N//2, N]) # halo_extents=sharding_info.halo_extents,
total_array, token = mpi4jax.allgather( # halo_periods=(True,True,True),
total_array.transpose([1, 0, 2]), comm=comms[1], token=token) # pdims=sharding_info.pdims,
total_array = total_array.reshape([N, N, N]) # global_shape=sharding_info.global_shape)
total_array = total_array.transpose([1, 0, 2])
if rank == 0: # time1 = time.time()
onp.save('simulation.npy', total_array) # init_field, field = run_sim(cosmo, key)
# init_field.block_until_ready()
# time2 = time.time()
print('Done !') # if rank == 0:
onp.save('simulation_%d.npy'%rank, field)
# print('Done in', time2-time1)