# Mirror of https://github.com/DifferentiableUniverseInitiative/JaxPM.git
# (synced 2025-04-24 19:50:55 +00:00; 75 lines, 1.9 KiB, Python)
import jax
|
|
# Start the multi-process JAX runtime first; the remaining imports and all
# computation in this script come after this call.
jax.distributed.initialize()

# Identify this process within the job and report it.
rank, size = jax.process_index(), jax.process_count()
print(f"Started process {rank} of {size}")
|
|
|
|
import diffrax
|
|
import jaxpm as jpm
|
|
from jaxpm import solvers
|
|
import jax_cosmo as jc
|
|
import numpy as np
|
|
from jax.experimental import mesh_utils, multihost_utils
|
|
from jax.sharding import Mesh
|
|
from jax.sharding import PartitionSpec as P
|
|
|
|
import jaxpm as jaxpm
|
|
|
|
|
|
# Cosmology handed to the solver in `fn`.
cosmo = jc.Planck15(Omega_c=0.25, sigma8=0.8)

# Create initial field
# NOTE: this rebinds `size`, which earlier held jax.process_count(); from here
# on it is the number of mesh cells per side.
size = 256
mesh_shape = (size,) * 3  # cubic mesh: (size, size, size)
box_size = [float(size)] * 3  # same extent as the mesh, as floats
|
|
|
|
|
|
def gen_input():
    """Build the two array inputs consumed by ``fn``.

    Uses the module-level ``mesh_shape`` and a fixed ``PRNGKey(0)``, so every
    call requests the same random draw.

    Returns:
        A ``(initial_field, kvec)`` tuple: the result of
        ``jpm.ops.normal(mesh_shape, key)`` and the result of
        ``jpm.ops.fftk(mesh_shape, symmetric=False)``.
    """
    seed_key = jax.random.PRNGKey(0)
    return (
        jpm.ops.normal(mesh_shape, seed_key),
        jpm.ops.fftk(mesh_shape, symmetric=False),
    )
|
|
|
|
@jax.jit
def fn(cosmo, initial_field, kvec):
    """Run the FastPM pipeline and paint the resulting displacements.

    The steps are: build initial particle positions for ``mesh_shape``,
    initialise the solver state, apply ``lpt`` at ``a=0.1``, integrate with
    ``nbody`` from ``t0=0.1`` to ``t1=1`` using an adaptive Dopri5 scheme, and
    finally call ``cic_paint_dx`` on the state's displacements.

    Reads the module-level ``mesh_shape`` and ``box_size``.

    Args:
        cosmo: Cosmology object forwarded to ``solver.init_state``.
        initial_field: Initial field, as produced by ``gen_input``.
        kvec: Wave-vector grid, as produced by ``gen_input``.

    Returns:
        The field returned by ``jpm.painting.cic_paint_dx``.
    """
    # The LPT step and the start of the N-body integration share this value
    # (presumably a scale factor -- confirm against the solver API).
    a_start = 0.1

    pm_solver = solvers.FastPM()
    grid_positions = jpm.ops.generate_initial_positions(mesh_shape)

    state = pm_solver.init_state(
        cosmo, grid_positions, kvec, initial_field, box_size
    )
    state = pm_solver.lpt(state, a=a_start)

    # Adaptive stepping: Dopri5 with a PID step-size controller.
    state = pm_solver.nbody(
        state,
        solver=diffrax.Dopri5(),
        stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-3),
        t0=a_start,
        t1=1,
        dt0=0.01,
    )

    return jpm.painting.cic_paint_dx(state.displacements)
|
|
|
|
|
|
|
|
# One GPU: build the inputs, run the jitted simulation, then write the painted
# field to 'file.npy'.
initial_field, kvec = gen_input()
final_field = fn(cosmo, initial_field, kvec)
np.save('file.npy', final_field)
|
|
|
|
# Multiple GPUs
# pdims = (2 , 2)
# devices = mesh_utils.create_device_mesh(pdims)
# mesh = Mesh(devices, axis_names=('y', 'z'))
# sharding = jax.sharding.NamedSharding(mesh, P('z', 'y'))
#
# with jaxpm.SPMDConfig(sharding):
#     initial_field , kvec = gen_input()
#     final_field = fn(cosmo , initial_field , kvec)
#
# final_field = multihost_utils.process_allgather(final_field , tiled=True)
# np.save('file_spmd.npy',final_field)
|