temp commit

This commit is contained in:
Wassim KABALAN 2024-04-19 01:11:25 +02:00
parent 6ca4c9191e
commit 055ceedb7e
5 changed files with 220 additions and 110 deletions

View file

@ -10,19 +10,20 @@ from jaxpm.painting import cic_paint
from jax.experimental.ode import odeint
import jax_cosmo as jc
import time
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P,NamedSharding
from functools import partial
### Setting up a whole bunch of things #######
# Create communicators
world = MPI.COMM_WORLD
rank = world.Get_rank()
size = world.Get_size()
jax.config.update("jax_enable_x64", True)
# Here we assume clients are on the same node, so we restrict which device
# they can use based on their rank
os.environ["CUDA_VISIBLE_DEVICES"] = "%d" % (rank + 1)
jaxdecomp.init()
jax.distributed.initialize()
# Setup random keys
master_key = jax.random.PRNGKey(42)
@ -35,37 +36,57 @@ mesh_shape = (N, N, N)
box_size = [500, 500, 500] # Mpc/h
halo_size = 32
sharding_info = ShardingInfo(global_shape=mesh_shape,
pdims=(1,2),
pdims=(2,2),
halo_extents=(halo_size, halo_size, 0),
rank=rank)
cosmo = jc.Planck15()
a = 0.1
devices = mesh_utils.create_device_mesh(sharding_info.pdims[::-1])
mesh = Mesh(devices, axis_names=('z', 'y'))
initial_conditions = linear_field(cosmo, mesh, mesh_shape, box_size, key,
sharding_info=sharding_info)
@jax.jit
def run_sim(cosmo, key):
initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
sharding_info=sharding_info)
def ifft3d_c2r(initial_conditions):
return ifft3d(initial_conditions, sharding_info=sharding_info).real
init_field = ifft3d(initial_conditions, sharding_info=sharding_info).real
@jax.jit
def compute_displacement(p , dx):
return p + dx
# Initialize particles
pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
def run_sim(mesh , initial_conditions, cosmo, key):
with mesh:
init_field = ifft3d_c2r(initial_conditions)
# Initialize particles
pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
# Initial displacement by LPT
cosmo = jc.Planck15()
dx, p, f = lpt(cosmo, pos, initial_conditions, a, halo_size=halo_size, sharding_info=sharding_info)
dx, p, f = lpt(mesh , cosmo, pos, initial_conditions, a, halo_size=halo_size, sharding_info=sharding_info)
# And now, we run an actual nbody
res = odeint(make_ode_fn(mesh_shape, halo_size, sharding_info),
[pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
rtol=1e-3, atol=1e-3)
# Painting on a new mesh
field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
res[0][-1], halo_size, sharding_info=sharding_info)
# field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
# pos+dx, halo_size, sharding_info=sharding_info)
#res = odeint(make_ode_fn(mesh_shape, halo_size, sharding_info),
# [pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
# rtol=1e-3, atol=1e-3)
## Painting on a new mesh
print(f"shape of p {p.shape}")
print(f"shape of dx {dx.shape}")
with mesh:
displacement = compute_displacement(p , dx)
empty_field = zeros(mesh_shape, sharding_info=sharding_info)
field = cic_paint(mesh , empty_field,
displacement, halo_size, sharding_info=sharding_info)
return init_field, field
# initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
@ -87,7 +108,8 @@ def run_sim(cosmo, key):
# pos+dx, halo_size, sharding_info=sharding_info)
# # Recover the real space initial conditions
init_field, field = run_sim(cosmo, key)
run_sim(mesh , initial_conditions,cosmo, key)
#init_field, field = run_sim(mesh , initial_conditions,cosmo, key)
# import jaxdecomp
# field = jaxdecomp.halo_exchange(field,
@ -102,6 +124,10 @@ init_field, field = run_sim(cosmo, key)
# time2 = time.time()
# if rank == 0:
onp.save('simulation_%d.npy'%rank, field)
#onp.save('simulation_%d.npy'%rank, field)
# print('Done in', time2-time1)
print("Done")
jaxdecomp.finalize()