Adding an example of jaxdecomp implementation

This commit is contained in:
EiffL 2022-11-26 17:27:14 +01:00
parent 6644b35d71
commit 6ca4c9191e
5 changed files with 166 additions and 192 deletions

View file

@ -1,15 +1,15 @@
from dataclasses import fields
from mpi4py import MPI
import os
import jax
import jax.numpy as jnp
import numpy as onp
import mpi4jax
from jaxpm.ops import fft3d, ifft3d, normal, meshgrid3d, zeros
import jaxdecomp
from jaxpm.ops import fft3d, ifft3d, normal, meshgrid3d, zeros, ShardingInfo
from jaxpm.pm import linear_field, lpt, make_ode_fn
from jaxpm.painting import cic_paint
from jax.experimental.ode import odeint
import jax_cosmo as jc
import time
### Setting up a whole bunch of things #######
# Create communicators
@ -17,10 +17,12 @@ world = MPI.COMM_WORLD
rank = world.Get_rank()
size = world.Get_size()
cart_comm = MPI.COMM_WORLD.Create_cart(dims=[2, 2],
periods=[True, True])
comms = [cart_comm.Sub([True, False]),
cart_comm.Sub([False, True])]
# Here we assume clients are on the same node, so we restrict which device
# they can use based on their rank
os.environ["CUDA_VISIBLE_DEVICES"] = "%d" % (rank + 1)
jaxdecomp.init()
# Setup random keys
master_key = jax.random.PRNGKey(42)
@ -29,50 +31,77 @@ key = jax.random.split(master_key, size)[rank]
# Size and parameters of the simulation volume
N = 256
mesh_shape = [N, N, N]
box_size = [205, 205, 205] # Mpc/h
mesh_shape = (N, N, N)
box_size = [500, 500, 500] # Mpc/h
halo_size = 32
sharding_info = ShardingInfo(global_shape=mesh_shape,
pdims=(1,2),
halo_extents=(halo_size, halo_size, 0),
rank=rank)
cosmo = jc.Planck15()
halo_size = 16
a = 0.1
@jax.jit
def run_sim(cosmo, key):
initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
comms=comms)
init_field = ifft3d(initial_conditions, comms=comms).real
sharding_info=sharding_info)
init_field = ifft3d(initial_conditions, sharding_info=sharding_info).real
# Initialize particles
pos = meshgrid3d(mesh_shape, comms=comms)
pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
# Initial displacement by LPT
cosmo = jc.Planck15()
dx, p, f = lpt(cosmo, pos, initial_conditions, a, comms=comms)
dx, p, f = lpt(cosmo, pos, initial_conditions, a, halo_size=halo_size, sharding_info=sharding_info)
# And now, we run an actual nbody
res = odeint(make_ode_fn(mesh_shape, halo_size, comms),
res = odeint(make_ode_fn(mesh_shape, halo_size, sharding_info),
[pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
rtol=1e-5, atol=1e-5)
rtol=1e-3, atol=1e-3)
# Painting on a new mesh
field = cic_paint(zeros(mesh_shape, comms=comms),
res[0][-1], halo_size, comms=comms)
field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
res[0][-1], halo_size, sharding_info=sharding_info)
# field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
# pos+dx, halo_size, sharding_info=sharding_info)
return init_field, field
# initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
# sharding_info=sharding_info)
# Recover the real space initial conditions
# init_field = ifft3d(initial_conditions, sharding_info=sharding_info).real
# print("hello", init_field.shape)
# cosmo = jc.Planck15()
# pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
# dx, p, f = lpt(cosmo, pos, initial_conditions, a, sharding_info=sharding_info)
# #dx = 3*jax.random.normal(key=key, shape=[1048576, 3])
# # Initialize particles
# # pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
# field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
# pos+dx, halo_size, sharding_info=sharding_info)
# # Recover the real space initial conditions
init_field, field = run_sim(cosmo, key)
# Testing that the result is actually looking like what we expect
total_array, token = mpi4jax.allgather(field, comm=comms[0])
total_array = total_array.reshape([N, N//2, N])
total_array, token = mpi4jax.allgather(
total_array.transpose([1, 0, 2]), comm=comms[1], token=token)
total_array = total_array.reshape([N, N, N])
total_array = total_array.transpose([1, 0, 2])
# import jaxdecomp
# field = jaxdecomp.halo_exchange(field,
# halo_extents=sharding_info.halo_extents,
# halo_periods=(True,True,True),
# pdims=sharding_info.pdims,
# global_shape=sharding_info.global_shape)
if rank == 0:
onp.save('simulation.npy', total_array)
# time1 = time.time()
# init_field, field = run_sim(cosmo, key)
# init_field.block_until_ready()
# time2 = time.time()
print('Done !')
# if rank == 0:
onp.save('simulation_%d.npy'%rank, field)
# print('Done in', time2-time1)