mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 08:31:11 +00:00
temp commit
This commit is contained in:
parent
6ca4c9191e
commit
055ceedb7e
5 changed files with 220 additions and 110 deletions
|
@ -10,19 +10,20 @@ from jaxpm.painting import cic_paint
|
|||
from jax.experimental.ode import odeint
|
||||
import jax_cosmo as jc
|
||||
import time
|
||||
|
||||
from jax.experimental import mesh_utils, multihost_utils
|
||||
from jax.sharding import Mesh, PartitionSpec as P,NamedSharding
|
||||
from functools import partial
|
||||
### Setting up a whole bunch of things #######
|
||||
# Create communicators
|
||||
world = MPI.COMM_WORLD
|
||||
rank = world.Get_rank()
|
||||
size = world.Get_size()
|
||||
|
||||
jax.config.update("jax_enable_x64", True)
|
||||
|
||||
# Here we assume clients are on the same node, so we restrict which device
|
||||
# they can use based on their rank
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "%d" % (rank + 1)
|
||||
|
||||
|
||||
jaxdecomp.init()
|
||||
jax.distributed.initialize()
|
||||
|
||||
# Setup random keys
|
||||
master_key = jax.random.PRNGKey(42)
|
||||
|
@ -35,37 +36,57 @@ mesh_shape = (N, N, N)
|
|||
box_size = [500, 500, 500] # Mpc/h
|
||||
halo_size = 32
|
||||
sharding_info = ShardingInfo(global_shape=mesh_shape,
|
||||
pdims=(1,2),
|
||||
pdims=(2,2),
|
||||
halo_extents=(halo_size, halo_size, 0),
|
||||
rank=rank)
|
||||
cosmo = jc.Planck15()
|
||||
a = 0.1
|
||||
|
||||
|
||||
devices = mesh_utils.create_device_mesh(sharding_info.pdims[::-1])
|
||||
mesh = Mesh(devices, axis_names=('z', 'y'))
|
||||
|
||||
initial_conditions = linear_field(cosmo, mesh, mesh_shape, box_size, key,
|
||||
sharding_info=sharding_info)
|
||||
|
||||
@jax.jit
|
||||
def run_sim(cosmo, key):
|
||||
initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
|
||||
sharding_info=sharding_info)
|
||||
def ifft3d_c2r(initial_conditions):
|
||||
return ifft3d(initial_conditions, sharding_info=sharding_info).real
|
||||
|
||||
init_field = ifft3d(initial_conditions, sharding_info=sharding_info).real
|
||||
@jax.jit
|
||||
def compute_displacement(p , dx):
|
||||
return p + dx
|
||||
|
||||
# Initialize particles
|
||||
pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
|
||||
|
||||
|
||||
def run_sim(mesh , initial_conditions, cosmo, key):
|
||||
|
||||
with mesh:
|
||||
init_field = ifft3d_c2r(initial_conditions)
|
||||
|
||||
# Initialize particles
|
||||
pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
|
||||
|
||||
# Initial displacement by LPT
|
||||
cosmo = jc.Planck15()
|
||||
dx, p, f = lpt(cosmo, pos, initial_conditions, a, halo_size=halo_size, sharding_info=sharding_info)
|
||||
|
||||
dx, p, f = lpt(mesh , cosmo, pos, initial_conditions, a, halo_size=halo_size, sharding_info=sharding_info)
|
||||
|
||||
# And now, we run an actual nbody
|
||||
res = odeint(make_ode_fn(mesh_shape, halo_size, sharding_info),
|
||||
[pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
|
||||
rtol=1e-3, atol=1e-3)
|
||||
# Painting on a new mesh
|
||||
field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
|
||||
res[0][-1], halo_size, sharding_info=sharding_info)
|
||||
|
||||
# field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
|
||||
# pos+dx, halo_size, sharding_info=sharding_info)
|
||||
#res = odeint(make_ode_fn(mesh_shape, halo_size, sharding_info),
|
||||
# [pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
|
||||
# rtol=1e-3, atol=1e-3)
|
||||
## Painting on a new mesh
|
||||
print(f"shape of p {p.shape}")
|
||||
print(f"shape of dx {dx.shape}")
|
||||
with mesh:
|
||||
displacement = compute_displacement(p , dx)
|
||||
|
||||
empty_field = zeros(mesh_shape, sharding_info=sharding_info)
|
||||
|
||||
field = cic_paint(mesh , empty_field,
|
||||
displacement, halo_size, sharding_info=sharding_info)
|
||||
|
||||
return init_field, field
|
||||
|
||||
# initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
|
||||
|
@ -87,7 +108,8 @@ def run_sim(cosmo, key):
|
|||
# pos+dx, halo_size, sharding_info=sharding_info)
|
||||
|
||||
# # Recover the real space initial conditions
|
||||
init_field, field = run_sim(cosmo, key)
|
||||
run_sim(mesh , initial_conditions,cosmo, key)
|
||||
#init_field, field = run_sim(mesh , initial_conditions,cosmo, key)
|
||||
|
||||
# import jaxdecomp
|
||||
# field = jaxdecomp.halo_exchange(field,
|
||||
|
@ -102,6 +124,10 @@ init_field, field = run_sim(cosmo, key)
|
|||
# time2 = time.time()
|
||||
|
||||
# if rank == 0:
|
||||
onp.save('simulation_%d.npy'%rank, field)
|
||||
#onp.save('simulation_%d.npy'%rank, field)
|
||||
|
||||
# print('Done in', time2-time1)
|
||||
|
||||
print("Done")
|
||||
|
||||
jaxdecomp.finalize()
|
Loading…
Add table
Add a link
Reference in a new issue