diff --git a/scripts/distributed_pm.py b/scripts/distributed_pm.py index 29d0b19..ef891e2 100644 --- a/scripts/distributed_pm.py +++ b/scripts/distributed_pm.py @@ -14,18 +14,19 @@ import numpy as np from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve) from jax.experimental import mesh_utils +from jax.experimental.multihost_utils import process_allgather from jax.sharding import Mesh, NamedSharding from jax.sharding import PartitionSpec as P -from jax.experimental.multihost_utils import process_allgather + from jaxpm.kernels import interpolate_power_spectrum from jaxpm.painting import cic_paint_dx from jaxpm.pm import linear_field, lpt, make_ode_fn -size = 256 +size = 64 mesh_shape = [size] * 3 box_size = [float(size)] * 3 snapshots = jnp.linspace(0.1, 1., 4) -halo_size = 32 +halo_size = 4 pdims = (1, 1) mesh = None sharding = None @@ -59,9 +60,9 @@ def run_simulation(omega_c, sigma8): 0.1, halo_size=halo_size, sharding=sharding) - return initial_conditions, cic_paint_dx(dx, - halo_size=halo_size, - sharding=sharding), None, None + # return initial_conditions, cic_paint_dx(dx, + # halo_size=halo_size, + # sharding=sharding), None, None # Evolve the simulation forward ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding) term = ODETerm( @@ -91,6 +92,8 @@ def run_simulation(omega_c, sigma8): # Run the simulation +distributed_str = "distributed" if mesh is not None else "single device" +print(f"running {distributed_str} simulation") init, field, final_fields, stats = run_simulation(0.32, 0.8) # # Print the statistics