mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
pm ok
This commit is contained in:
parent
055ceedb7e
commit
179030377b
4 changed files with 63 additions and 38 deletions
|
@ -1,6 +1,7 @@
|
|||
from mpi4py import MPI
|
||||
import os
|
||||
import jax
|
||||
from jax import jit
|
||||
import jax.numpy as jnp
|
||||
import numpy as onp
|
||||
import jaxdecomp
|
||||
|
@ -53,12 +54,6 @@ initial_conditions = linear_field(cosmo, mesh, mesh_shape, box_size, key,
|
|||
def ifft3d_c2r(initial_conditions):
|
||||
return ifft3d(initial_conditions, sharding_info=sharding_info).real
|
||||
|
||||
@jax.jit
|
||||
def compute_displacement(p , dx):
|
||||
return p + dx
|
||||
|
||||
|
||||
|
||||
def run_sim(mesh , initial_conditions, cosmo, key):
|
||||
|
||||
with mesh:
|
||||
|
@ -77,12 +72,10 @@ def run_sim(mesh , initial_conditions, cosmo, key):
|
|||
# [pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
|
||||
# rtol=1e-3, atol=1e-3)
|
||||
## Painting on a new mesh
|
||||
print(f"shape of p {p.shape}")
|
||||
print(f"shape of dx {dx.shape}")
|
||||
with mesh:
|
||||
displacement = compute_displacement(p , dx)
|
||||
displacement = jit(jnp.add)(p , dx)
|
||||
|
||||
empty_field = zeros(mesh_shape, sharding_info=sharding_info)
|
||||
empty_field = zeros(mesh , mesh_shape, sharding_info=sharding_info)
|
||||
|
||||
field = cic_paint(mesh , empty_field,
|
||||
displacement, halo_size, sharding_info=sharding_info)
|
||||
|
@ -108,7 +101,7 @@ def run_sim(mesh , initial_conditions, cosmo, key):
|
|||
# pos+dx, halo_size, sharding_info=sharding_info)
|
||||
|
||||
# # Recover the real space initial conditions
|
||||
run_sim(mesh , initial_conditions,cosmo, key)
|
||||
init_field, field = run_sim(mesh , initial_conditions,cosmo, key)
|
||||
#init_field, field = run_sim(mesh , initial_conditions,cosmo, key)
|
||||
|
||||
# import jaxdecomp
|
||||
|
@ -123,8 +116,11 @@ run_sim(mesh , initial_conditions,cosmo, key)
|
|||
# init_field.block_until_ready()
|
||||
# time2 = time.time()
|
||||
|
||||
|
||||
|
||||
# if rank == 0:
|
||||
#onp.save('simulation_%d.npy'%rank, field)
|
||||
onp.save('simulation_init_field_float16_%d.npy'%rank, init_field.addressable_data(0).astype(onp.float16))
|
||||
onp.save('simulation_field_float16_%d.npy'%rank, field.addressable_data(0).astype(onp.float16))
|
||||
|
||||
# print('Done in', time2-time1)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue