mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 08:31:11 +00:00
Adding an example of jaxdecomp implementation
This commit is contained in:
parent
6644b35d71
commit
6ca4c9191e
5 changed files with 166 additions and 192 deletions
|
@ -1,15 +1,15 @@
|
|||
from dataclasses import fields
|
||||
from mpi4py import MPI
|
||||
import os
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as onp
|
||||
import mpi4jax
|
||||
from jaxpm.ops import fft3d, ifft3d, normal, meshgrid3d, zeros
|
||||
import jaxdecomp
|
||||
from jaxpm.ops import fft3d, ifft3d, normal, meshgrid3d, zeros, ShardingInfo
|
||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||
from jaxpm.painting import cic_paint
|
||||
from jax.experimental.ode import odeint
|
||||
import jax_cosmo as jc
|
||||
|
||||
import time
|
||||
|
||||
### Setting up a whole bunch of things #######
|
||||
# Create communicators
|
||||
|
@ -17,10 +17,12 @@ world = MPI.COMM_WORLD
|
|||
rank = world.Get_rank()
|
||||
size = world.Get_size()
|
||||
|
||||
cart_comm = MPI.COMM_WORLD.Create_cart(dims=[2, 2],
|
||||
periods=[True, True])
|
||||
comms = [cart_comm.Sub([True, False]),
|
||||
cart_comm.Sub([False, True])]
|
||||
# Here we assume clients are on the same node, so we restrict which device
|
||||
# they can use based on their rank
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "%d" % (rank + 1)
|
||||
|
||||
|
||||
jaxdecomp.init()
|
||||
|
||||
# Setup random keys
|
||||
master_key = jax.random.PRNGKey(42)
|
||||
|
@ -29,50 +31,77 @@ key = jax.random.split(master_key, size)[rank]
|
|||
|
||||
# Size and parameters of the simulation volume
|
||||
N = 256
|
||||
mesh_shape = [N, N, N]
|
||||
box_size = [205, 205, 205] # Mpc/h
|
||||
mesh_shape = (N, N, N)
|
||||
box_size = [500, 500, 500] # Mpc/h
|
||||
halo_size = 32
|
||||
sharding_info = ShardingInfo(global_shape=mesh_shape,
|
||||
pdims=(1,2),
|
||||
halo_extents=(halo_size, halo_size, 0),
|
||||
rank=rank)
|
||||
cosmo = jc.Planck15()
|
||||
halo_size = 16
|
||||
a = 0.1
|
||||
|
||||
|
||||
@jax.jit
|
||||
def run_sim(cosmo, key):
|
||||
initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
|
||||
comms=comms)
|
||||
init_field = ifft3d(initial_conditions, comms=comms).real
|
||||
sharding_info=sharding_info)
|
||||
|
||||
init_field = ifft3d(initial_conditions, sharding_info=sharding_info).real
|
||||
|
||||
# Initialize particles
|
||||
pos = meshgrid3d(mesh_shape, comms=comms)
|
||||
pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
|
||||
|
||||
# Initial displacement by LPT
|
||||
cosmo = jc.Planck15()
|
||||
dx, p, f = lpt(cosmo, pos, initial_conditions, a, comms=comms)
|
||||
dx, p, f = lpt(cosmo, pos, initial_conditions, a, halo_size=halo_size, sharding_info=sharding_info)
|
||||
|
||||
# And now, we run an actual nbody
|
||||
res = odeint(make_ode_fn(mesh_shape, halo_size, comms),
|
||||
res = odeint(make_ode_fn(mesh_shape, halo_size, sharding_info),
|
||||
[pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
|
||||
rtol=1e-5, atol=1e-5)
|
||||
|
||||
rtol=1e-3, atol=1e-3)
|
||||
# Painting on a new mesh
|
||||
field = cic_paint(zeros(mesh_shape, comms=comms),
|
||||
res[0][-1], halo_size, comms=comms)
|
||||
|
||||
field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
|
||||
res[0][-1], halo_size, sharding_info=sharding_info)
|
||||
|
||||
# field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
|
||||
# pos+dx, halo_size, sharding_info=sharding_info)
|
||||
return init_field, field
|
||||
|
||||
# initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
|
||||
# sharding_info=sharding_info)
|
||||
|
||||
# Recover the real space initial conditions
|
||||
# init_field = ifft3d(initial_conditions, sharding_info=sharding_info).real
|
||||
|
||||
# print("hello", init_field.shape)
|
||||
|
||||
# cosmo = jc.Planck15()
|
||||
# pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
|
||||
# dx, p, f = lpt(cosmo, pos, initial_conditions, a, sharding_info=sharding_info)
|
||||
|
||||
# #dx = 3*jax.random.normal(key=key, shape=[1048576, 3])
|
||||
# # Initialize particles
|
||||
# # pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
|
||||
|
||||
# field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
|
||||
# pos+dx, halo_size, sharding_info=sharding_info)
|
||||
|
||||
# # Recover the real space initial conditions
|
||||
init_field, field = run_sim(cosmo, key)
|
||||
|
||||
# Testing that the result is actually looking like what we expect
|
||||
total_array, token = mpi4jax.allgather(field, comm=comms[0])
|
||||
total_array = total_array.reshape([N, N//2, N])
|
||||
total_array, token = mpi4jax.allgather(
|
||||
total_array.transpose([1, 0, 2]), comm=comms[1], token=token)
|
||||
total_array = total_array.reshape([N, N, N])
|
||||
total_array = total_array.transpose([1, 0, 2])
|
||||
# import jaxdecomp
|
||||
# field = jaxdecomp.halo_exchange(field,
|
||||
# halo_extents=sharding_info.halo_extents,
|
||||
# halo_periods=(True,True,True),
|
||||
# pdims=sharding_info.pdims,
|
||||
# global_shape=sharding_info.global_shape)
|
||||
|
||||
if rank == 0:
|
||||
onp.save('simulation.npy', total_array)
|
||||
# time1 = time.time()
|
||||
# init_field, field = run_sim(cosmo, key)
|
||||
# init_field.block_until_ready()
|
||||
# time2 = time.time()
|
||||
|
||||
print('Done !')
|
||||
# if rank == 0:
|
||||
onp.save('simulation_%d.npy'%rank, field)
|
||||
|
||||
# print('Done in', time2-time1)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue