mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
add example
This commit is contained in:
parent
abde5439f6
commit
82be56836a
2 changed files with 366 additions and 0 deletions
95
scripts/distributed_pm.py
Normal file
95
scripts/distributed_pm.py
Normal file
|
@ -0,0 +1,95 @@
|
|||
import os
|
||||
os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
|
||||
import jax
|
||||
jax.distributed.initialize()
|
||||
|
||||
rank = jax.process_index()
|
||||
size = jax.process_count()
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
|
||||
from jax.experimental.ode import odeint
|
||||
|
||||
from jaxpm.painting import cic_paint, cic_read , cic_paint_dx , cic_read_dx
|
||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||
from diffrax import diffeqsolve, ODETerm, Dopri5, PIDController, SaveAt
|
||||
import numpy as np
|
||||
from jax.experimental import mesh_utils
|
||||
from jax.sharding import Mesh, PartitionSpec as P , NamedSharding
|
||||
from jaxpm.distributed import normal_field
|
||||
from jaxpm.kernels import interpolate_power_spectrum
|
||||
|
||||
|
||||
size = 256
|
||||
mesh_shape= [size]*3
|
||||
box_size = [float(size)]*3
|
||||
snapshots = jnp.linspace(0.1,1.,4)
|
||||
halo_size = 64
|
||||
if jax.device_count() > 1 :
|
||||
|
||||
pdims = (4 , 2)
|
||||
devices = mesh_utils.create_device_mesh(pdims)
|
||||
mesh = Mesh(devices.T, axis_names=('x', 'y'))
|
||||
sharding = NamedSharding(mesh , P('x' , 'y'))
|
||||
|
||||
|
||||
@jax.jit
|
||||
def run_simulation(omega_c, sigma8):
|
||||
# Create a small function to generate the matter power spectrum
|
||||
k = jnp.logspace(-4, 1, 128)
|
||||
pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
||||
|
||||
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk)
|
||||
# Create initial conditions
|
||||
initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))
|
||||
|
||||
|
||||
# Create particles
|
||||
particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),axis=-1).reshape([-1,3])
|
||||
|
||||
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
|
||||
|
||||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo, initial_conditions, 0.1 , halo_size=halo_size)
|
||||
|
||||
# Evolve the simulation forward
|
||||
ode_fn = make_ode_fn(mesh_shape , halo_size=halo_size)
|
||||
term = ODETerm(lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
|
||||
solver = Dopri5()
|
||||
|
||||
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4)
|
||||
res = diffeqsolve(term, solver, t0=0.1, t1=1., dt0=0.01, y0=jnp.stack([dx, p],axis=0),
|
||||
args=cosmo,
|
||||
saveat=SaveAt(ts=snapshots),
|
||||
stepsize_controller=stepsize_controller)
|
||||
|
||||
# Return the simulation volume at requested
|
||||
states = res.ys
|
||||
field = cic_paint_dx(dx , halo_size = halo_size)
|
||||
final_fields = [cic_paint_dx(state[0] , halo_size = halo_size) for state in states]
|
||||
|
||||
return initial_conditions , field ,final_fields , res.stats
|
||||
|
||||
|
||||
# Run the simulation
|
||||
if jax.device_count() > 1 :
|
||||
with mesh:
|
||||
init, field , final_fields , stats = run_simulation(0.32, 0.8)
|
||||
|
||||
else:
|
||||
init, field, final_fields , stats = run_simulation(0.32, 0.8)
|
||||
|
||||
# # Print the statistics
|
||||
print(stats)
|
||||
|
||||
# # save the final state
|
||||
np.save(f'initial_conditions_{rank}.npy', init.addressable_data(0))
|
||||
np.save(f'field_{rank}.npy', field.addressable_data(0))
|
||||
|
||||
if final_fields is not None:
|
||||
for i, final_field in enumerate(final_fields):
|
||||
np.save(f'final_field_{i}_{rank}.npy',
|
||||
final_field.addressable_data(0))
|
||||
|
||||
print(f"Finished!!")
|
271
scripts/eval_decomp_ode.ipynb
Normal file
271
scripts/eval_decomp_ode.ipynb
Normal file
File diff suppressed because one or more lines are too long
Loading…
Add table
Reference in a new issue