mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
ability to choose snapshots number with MultiHost script
This commit is contained in:
parent
ad4566620f
commit
12c74e2601
2 changed files with 318 additions and 292 deletions
File diff suppressed because one or more lines are too long
|
@ -57,6 +57,12 @@ def parse_arguments():
|
|||
help=
|
||||
"Box size of the simulation as three values (e.g., 500.0 500.0 1000.0)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-st",
|
||||
"--snapshots",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Number of snapshots to save during the simulation.")
|
||||
parser.add_argument("-H",
|
||||
"--halo_size",
|
||||
type=int,
|
||||
|
@ -71,7 +77,7 @@ def parse_arguments():
|
|||
return parser.parse_args()
|
||||
|
||||
|
||||
def create_mesh_and_sharding(mesh_shape, pdims):
|
||||
def create_mesh_and_sharding(pdims):
|
||||
devices = create_device_mesh(pdims)
|
||||
mesh = Mesh(devices, axis_names=('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
|
@ -80,7 +86,7 @@ def create_mesh_and_sharding(mesh_shape, pdims):
|
|||
|
||||
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
|
||||
def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
||||
solver_choice, sharding):
|
||||
solver_choice,nb_snapshots, sharding):
|
||||
k = jnp.logspace(-4, 1, 128)
|
||||
pk = jc.power.linear_matter_power(
|
||||
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
||||
|
@ -115,7 +121,7 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
|||
dt0=0.01,
|
||||
y0=jnp.stack([dx, p], axis=0),
|
||||
args=cosmo,
|
||||
saveat=SaveAt(ts=jnp.array([0.5, 1.0])),
|
||||
saveat=SaveAt(ts=jnp.linspace(0.2, 1., nb_snapshots)),
|
||||
stepsize_controller=stepsize_controller)
|
||||
|
||||
ode_fields = [
|
||||
|
@ -132,18 +138,25 @@ def main():
|
|||
box_size = args.box_size
|
||||
halo_size = args.halo_size
|
||||
solver_choice = args.solver
|
||||
nb_snapshots = args.snapshots
|
||||
|
||||
mesh, sharding = create_mesh_and_sharding(mesh_shape, args.pdims)
|
||||
sharding = create_mesh_and_sharding(args.pdims)
|
||||
|
||||
initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(
|
||||
0.25, 0.8, tuple(mesh_shape), tuple(box_size), halo_size,
|
||||
solver_choice, sharding)
|
||||
solver_choice, nb_snapshots, sharding)
|
||||
|
||||
if rank == 0:
|
||||
os.makedirs("fields", exist_ok=True)
|
||||
print(f"[{rank}] Simulation done")
|
||||
print(f"Solver stats: {solver_stats}")
|
||||
|
||||
|
||||
# Save initial conditions
|
||||
initial_conditions_g = all_gather(initial_conditions)
|
||||
if rank == 0:
|
||||
print(f"[{rank}] Saving initial_conditions")
|
||||
np.save("initial_conditions.npy", initial_conditions_g)
|
||||
np.save("fields/initial_conditions.npy", initial_conditions_g)
|
||||
print(f"[{rank}] initial_conditions saved")
|
||||
del initial_conditions_g, initial_conditions
|
||||
|
||||
|
@ -151,7 +164,7 @@ def main():
|
|||
lpt_displacements_g = all_gather(lpt_displacements)
|
||||
if rank == 0:
|
||||
print(f"[{rank}] Saving lpt_displacements")
|
||||
np.save("lpt_displacements.npy", lpt_displacements_g)
|
||||
np.save("fields/lpt_displacements.npy", lpt_displacements_g)
|
||||
print(f"[{rank}] lpt_displacements saved")
|
||||
del lpt_displacements_g, lpt_displacements
|
||||
|
||||
|
@ -160,7 +173,7 @@ def main():
|
|||
sol_g = all_gather(sol)
|
||||
if rank == 0:
|
||||
print(f"[{rank}] Saving ode_solution_{i}")
|
||||
np.save(f"ode_solution_{i}.npy", sol_g)
|
||||
np.save(f"fields/ode_solution_{i}.npy", sol_g)
|
||||
print(f"[{rank}] ode_solution_{i} saved")
|
||||
del sol_g
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue