# Mirror of https://github.com/DifferentiableUniverseInitiative/JaxPM.git
# (synced 2025-04-24 19:50:55 +00:00) -- 78 lines, 2.3 KiB, Python
from dataclasses import fields
|
|
from mpi4py import MPI
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import numpy as onp
|
|
import mpi4jax
|
|
from jaxpm.ops import fft3d, ifft3d, normal, meshgrid3d, zeros
|
|
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
|
from jaxpm.painting import cic_paint
|
|
from jax.experimental.ode import odeint
|
|
import jax_cosmo as jc
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Global setup: MPI communicators, per-rank random keys, simulation parameters
# ---------------------------------------------------------------------------

# World communicator, plus this process' rank and the total process count.
world = MPI.COMM_WORLD
rank = world.Get_rank()
size = world.Get_size()

# Periodic 2 x 2 Cartesian process grid, then one sub-communicator per grid
# axis (the mask selects which Cartesian dimension each sub-communicator keeps).
cart_comm = world.Create_cart(dims=[2, 2], periods=[True, True])
comms = [cart_comm.Sub(axis_mask)
         for axis_mask in ([True, False], [False, True])]

# One PRNG key per rank, all derived from the same master seed.
master_key = jax.random.PRNGKey(42)
key = jax.random.split(master_key, size)[rank]

# Size and parameters of the simulation volume.
N = 256
mesh_shape = [N] * 3
box_size = [205] * 3  # Mpc/h
cosmo = jc.Planck15()
halo_size = 16
a = 0.1
|
|
|
|
|
|
@jax.jit
def run_sim(cosmo, key):
    """Run the distributed particle-mesh simulation and paint the result.

    The function relies on the module-level ``mesh_shape``, ``box_size``,
    ``halo_size``, ``a`` and ``comms``.

    Steps:
      1. Draw initial conditions with ``linear_field`` and keep the real
         part of their ``ifft3d`` as the initial field.
      2. Create a particle grid with ``meshgrid3d`` and displace it with
         ``lpt`` evaluated at scale factor ``a``.
      3. Integrate the system built by ``make_ode_fn`` with ``odeint`` from
         ``a`` to 1.0.
      4. Paint the final particle positions onto an empty mesh with
         ``cic_paint``.

    Args:
        cosmo: Cosmology object forwarded to ``linear_field``.
        key: Random key forwarded to ``linear_field``.

    Returns:
        A tuple ``(init_field, field)``: the real part of the inverse
        transform of the initial conditions, and the mesh painted from the
        particle positions at the last integration time.
    """
    initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
                                      comms=comms)
    init_field = ifft3d(initial_conditions, comms=comms).real

    # Initialize particles
    pos = meshgrid3d(mesh_shape, comms=comms)

    # Initial displacement by LPT
    # NOTE(review): this rebinding discards the ``cosmo`` argument for the LPT
    # and N-body steps below; only ``linear_field`` sees the caller's value.
    # Confirm whether this is intentional before removing it.
    cosmo = jc.Planck15()
    # The third value returned by ``lpt`` is not used here.
    dx, p, _ = lpt(cosmo, pos, initial_conditions, a, comms=comms)

    # And now, we run an actual nbody. The time grid starts at ``a`` so that
    # the integration begins at the same scale factor the LPT step used.
    res = odeint(make_ode_fn(mesh_shape, halo_size, comms),
                 [pos + dx, p], jnp.linspace(a, 1.0, 2), cosmo,
                 rtol=1e-5, atol=1e-5)

    # Painting on a new mesh: res[0] holds the first state component
    # (positions) at every requested time; [-1] selects the last one.
    field = cic_paint(zeros(mesh_shape, comms=comms),
                      res[0][-1], halo_size, comms=comms)

    return init_field, field
|
|
|
|
|
|
# Run the simulation; each rank gets back its initial field and final field.
init_field, field = run_sim(cosmo, key)

# Reassemble the full field on every rank in two gather passes, one per
# sub-communicator, chaining the mpi4jax token between the two calls.
gathered, token = mpi4jax.allgather(field, comm=comms[0])
gathered = gathered.reshape([N, N // 2, N])

# Swap the first two axes before the second gather ...
swapped = gathered.transpose([1, 0, 2])
gathered, token = mpi4jax.allgather(swapped, comm=comms[1], token=token)

# ... and swap them back once the result has been reshaped to the full cube.
total_array = gathered.reshape([N, N, N]).transpose([1, 0, 2])

# Only the first rank writes the assembled field to disk.
if rank == 0:
    onp.save('simulation.npy', total_array)

print('Done !')
|