This commit is contained in:
Wassim KABALAN 2024-04-19 10:32:38 +02:00
parent 055ceedb7e
commit 179030377b
4 changed files with 63 additions and 38 deletions

View file

@ -1,6 +1,7 @@
from mpi4py import MPI
import os
import jax
from jax import jit
import jax.numpy as jnp
import numpy as onp
import jaxdecomp
@ -53,12 +54,6 @@ initial_conditions = linear_field(cosmo, mesh, mesh_shape, box_size, key,
def ifft3d_c2r(initial_conditions):
return ifft3d(initial_conditions, sharding_info=sharding_info).real
@jax.jit
def compute_displacement(p , dx):
return p + dx
def run_sim(mesh , initial_conditions, cosmo, key):
with mesh:
@ -77,12 +72,10 @@ def run_sim(mesh , initial_conditions, cosmo, key):
# [pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
# rtol=1e-3, atol=1e-3)
## Painting on a new mesh
print(f"shape of p {p.shape}")
print(f"shape of dx {dx.shape}")
with mesh:
displacement = compute_displacement(p , dx)
displacement = jit(jnp.add)(p , dx)
empty_field = zeros(mesh_shape, sharding_info=sharding_info)
empty_field = zeros(mesh , mesh_shape, sharding_info=sharding_info)
field = cic_paint(mesh , empty_field,
displacement, halo_size, sharding_info=sharding_info)
@ -108,7 +101,7 @@ def run_sim(mesh , initial_conditions, cosmo, key):
# pos+dx, halo_size, sharding_info=sharding_info)
# # Recover the real space initial conditions
run_sim(mesh , initial_conditions,cosmo, key)
init_field, field = run_sim(mesh , initial_conditions,cosmo, key)
#init_field, field = run_sim(mesh , initial_conditions,cosmo, key)
# import jaxdecomp
@ -123,8 +116,11 @@ run_sim(mesh , initial_conditions,cosmo, key)
# init_field.block_until_ready()
# time2 = time.time()
# if rank == 0:
#onp.save('simulation_%d.npy'%rank, field)
onp.save('simulation_init_field_float16_%d.npy'%rank, init_field.addressable_data(0).astype(onp.float16))
onp.save('simulation_field_float16_%d.npy'%rank, field.addressable_data(0).astype(onp.float16))
# print('Done in', time2-time1)