mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 08:31:11 +00:00
fixed a whole lot of issues
This commit is contained in:
parent
429813ad92
commit
72ae0fd88f
5 changed files with 251 additions and 155 deletions
78
scripts/test_nbody.py
Normal file
78
scripts/test_nbody.py
Normal file
|
@ -0,0 +1,78 @@
|
|||
from dataclasses import fields
|
||||
from mpi4py import MPI
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as onp
|
||||
import mpi4jax
|
||||
from jaxpm.ops import fft3d, ifft3d, normal, meshgrid3d, zeros
|
||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||
from jaxpm.painting import cic_paint
|
||||
from jax.experimental.ode import odeint
|
||||
import jax_cosmo as jc
|
||||
|
||||
|
||||
### Setting up a whole bunch of things #######
|
||||
# Create communicators
|
||||
world = MPI.COMM_WORLD
|
||||
rank = world.Get_rank()
|
||||
size = world.Get_size()
|
||||
|
||||
cart_comm = MPI.COMM_WORLD.Create_cart(dims=[2, 2],
|
||||
periods=[True, True])
|
||||
comms = [cart_comm.Sub([True, False]),
|
||||
cart_comm.Sub([False, True])]
|
||||
|
||||
# Setup random keys
|
||||
master_key = jax.random.PRNGKey(42)
|
||||
key = jax.random.split(master_key, size)[rank]
|
||||
################################################
|
||||
|
||||
# Size and parameters of the simulation volume
|
||||
N = 256
|
||||
mesh_shape = [N, N, N]
|
||||
box_size = [205, 205, 205] # Mpc/h
|
||||
cosmo = jc.Planck15()
|
||||
halo_size = 16
|
||||
a = 0.1
|
||||
|
||||
|
||||
@jax.jit
|
||||
def run_sim(cosmo, key):
|
||||
initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
|
||||
comms=comms)
|
||||
init_field = ifft3d(initial_conditions, comms=comms).real
|
||||
|
||||
# Initialize particles
|
||||
pos = meshgrid3d(mesh_shape, comms=comms)
|
||||
|
||||
# Initial displacement by LPT
|
||||
cosmo = jc.Planck15()
|
||||
dx, p, f = lpt(cosmo, pos, initial_conditions, a, comms=comms)
|
||||
|
||||
# And now, we run an actual nbody
|
||||
res = odeint(make_ode_fn(mesh_shape, halo_size, comms),
|
||||
[pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
|
||||
rtol=1e-5, atol=1e-5)
|
||||
|
||||
# Painting on a new mesh
|
||||
field = cic_paint(zeros(mesh_shape, comms=comms),
|
||||
res[0][-1], halo_size, comms=comms)
|
||||
|
||||
return init_field, field
|
||||
|
||||
|
||||
# Recover the real space initial conditions
|
||||
init_field, field = run_sim(cosmo, key)
|
||||
|
||||
# Testing that the result is actually looking like what we expect
|
||||
total_array, token = mpi4jax.allgather(field, comm=comms[0])
|
||||
total_array = total_array.reshape([N, N//2, N])
|
||||
total_array, token = mpi4jax.allgather(
|
||||
total_array.transpose([1, 0, 2]), comm=comms[1], token=token)
|
||||
total_array = total_array.reshape([N, N, N])
|
||||
total_array = total_array.transpose([1, 0, 2])
|
||||
|
||||
if rank == 0:
|
||||
onp.save('simulation.npy', total_array)
|
||||
|
||||
print('Done !')
|
Loading…
Add table
Add a link
Reference in a new issue