mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-16 16:10:54 +00:00
ability to choose snapshots number with MultiHost script
This commit is contained in:
parent
ad4566620f
commit
12c74e2601
2 changed files with 318 additions and 292 deletions
|
@ -160,34 +160,47 @@
|
||||||
"- **`--box_size`** (`-b`): Sets the physical box size of the simulation as three floating-point values, e.g., `1000. 1000. 1000.` (default is `[500.0, 500.0, 500.0]`).\n",
|
"- **`--box_size`** (`-b`): Sets the physical box size of the simulation as three floating-point values, e.g., `1000. 1000. 1000.` (default is `[500.0, 500.0, 500.0]`).\n",
|
||||||
"- **`--halo_size`** (`-H`): Specifies the halo size for boundary overlap across nodes (default is `64`).\n",
|
"- **`--halo_size`** (`-H`): Specifies the halo size for boundary overlap across nodes (default is `64`).\n",
|
||||||
"- **`--solver`** (`-s`): Chooses the ODE solver (`leapfrog` or `dopri8`). The `leapfrog` solver uses a fixed step size, while `dopri8` is an adaptive Runge-Kutta solver with a PID controller (default is `leapfrog`).\n",
|
"- **`--solver`** (`-s`): Chooses the ODE solver (`leapfrog` or `dopri8`). The `leapfrog` solver uses a fixed step size, while `dopri8` is an adaptive Runge-Kutta solver with a PID controller (default is `leapfrog`).\n",
|
||||||
|
"- **`--snapthots`** (`-st`) : Number of snapshots to save (warning, increases memory usage)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"The script also saves results across nodes.\n"
|
"The script also saves results across nodes.\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 24,
|
"execution_count": null,
|
||||||
"id": "b7eabac5",
|
"id": "c6d13679",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"SIZE is 32\n",
|
|
||||||
"[0] Saving initial_conditions\n",
|
|
||||||
"[0] initial_conditions saved\n",
|
|
||||||
"[0] Saving lpt_displacements\n",
|
|
||||||
"[0] lpt_displacements saved\n",
|
|
||||||
"[0] Saving ode_solution_0\n",
|
|
||||||
"[0] ode_solution_0 saved\n",
|
|
||||||
"[0] Saving ode_solution_1\n",
|
|
||||||
"[0] ode_solution_1 saved\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"!srun --jobid=467745 -n 32 python 05-MultiHost_PM.py --mesh_shape 1024 1024 1024 --box_size 1000. 1000. 1000. --halo_size 128 -s leapfrog --pdims 16 2"
|
"import subprocess\n",
|
||||||
|
"\n",
|
||||||
|
"# Define parameters as variables\n",
|
||||||
|
"jobid = \"467745\"\n",
|
||||||
|
"num_processes = 32\n",
|
||||||
|
"script_name = \"05-MultiHost_PM.py\"\n",
|
||||||
|
"mesh_shape = (1024, 1024, 1024)\n",
|
||||||
|
"box_size = (1000., 1000., 1000.)\n",
|
||||||
|
"halo_size = 128\n",
|
||||||
|
"solver = \"leapfrog\"\n",
|
||||||
|
"pdims = (16, 2)\n",
|
||||||
|
"snapshots = 2\n",
|
||||||
|
"\n",
|
||||||
|
"# Build the command as a list, incorporating variables\n",
|
||||||
|
"command = [\n",
|
||||||
|
" \"srun\",\n",
|
||||||
|
" f\"--jobid={jobid}\",\n",
|
||||||
|
" \"-n\", str(num_processes),\n",
|
||||||
|
" \"python\", script_name,\n",
|
||||||
|
" \"--mesh_shape\", str(mesh_shape[0]), str(mesh_shape[1]), str(mesh_shape[2]),\n",
|
||||||
|
" \"--box_size\", str(box_size[0]), str(box_size[1]), str(box_size[2]),\n",
|
||||||
|
" \"--halo_size\", str(halo_size),\n",
|
||||||
|
" \"-s\", solver,\n",
|
||||||
|
" \"--pdims\", str(pdims[0]), str(pdims[1]),\n",
|
||||||
|
" \"--snapshots\", str(snapshots)\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"# Execute the command as a subprocess\n",
|
||||||
|
"subprocess.run(command)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -215,10 +228,10 @@
|
||||||
"source": [
|
"source": [
|
||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
"\n",
|
"\n",
|
||||||
"initial_conditions = np.load('initial_conditions.npy')\n",
|
"initial_conditions = np.load('fields/initial_conditions.npy')\n",
|
||||||
"lpt_displacements = np.load('lpt_displacements.npy')\n",
|
"lpt_displacements = np.load('fields/lpt_displacements.npy')\n",
|
||||||
"ode_solution_0 = np.load('ode_solution_0.npy')\n",
|
"ode_solution_0 = np.load('fields/ode_solution_0.npy')\n",
|
||||||
"ode_solution_1 = np.load('ode_solution_1.npy')"
|
"ode_solution_1 = np.load('fields/ode_solution_1.npy')"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
@ -57,6 +57,12 @@ def parse_arguments():
|
||||||
help=
|
help=
|
||||||
"Box size of the simulation as three values (e.g., 500.0 500.0 1000.0)."
|
"Box size of the simulation as three values (e.g., 500.0 500.0 1000.0)."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-st",
|
||||||
|
"--snapshots",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="Number of snapshots to save during the simulation.")
|
||||||
parser.add_argument("-H",
|
parser.add_argument("-H",
|
||||||
"--halo_size",
|
"--halo_size",
|
||||||
type=int,
|
type=int,
|
||||||
|
@ -71,7 +77,7 @@ def parse_arguments():
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def create_mesh_and_sharding(mesh_shape, pdims):
|
def create_mesh_and_sharding(pdims):
|
||||||
devices = create_device_mesh(pdims)
|
devices = create_device_mesh(pdims)
|
||||||
mesh = Mesh(devices, axis_names=('x', 'y'))
|
mesh = Mesh(devices, axis_names=('x', 'y'))
|
||||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||||
|
@ -80,7 +86,7 @@ def create_mesh_and_sharding(mesh_shape, pdims):
|
||||||
|
|
||||||
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
|
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
|
||||||
def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
||||||
solver_choice, sharding):
|
solver_choice,nb_snapshots, sharding):
|
||||||
k = jnp.logspace(-4, 1, 128)
|
k = jnp.logspace(-4, 1, 128)
|
||||||
pk = jc.power.linear_matter_power(
|
pk = jc.power.linear_matter_power(
|
||||||
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
||||||
|
@ -115,7 +121,7 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
||||||
dt0=0.01,
|
dt0=0.01,
|
||||||
y0=jnp.stack([dx, p], axis=0),
|
y0=jnp.stack([dx, p], axis=0),
|
||||||
args=cosmo,
|
args=cosmo,
|
||||||
saveat=SaveAt(ts=jnp.array([0.5, 1.0])),
|
saveat=SaveAt(ts=jnp.linspace(0.2, 1., nb_snapshots)),
|
||||||
stepsize_controller=stepsize_controller)
|
stepsize_controller=stepsize_controller)
|
||||||
|
|
||||||
ode_fields = [
|
ode_fields = [
|
||||||
|
@ -132,18 +138,25 @@ def main():
|
||||||
box_size = args.box_size
|
box_size = args.box_size
|
||||||
halo_size = args.halo_size
|
halo_size = args.halo_size
|
||||||
solver_choice = args.solver
|
solver_choice = args.solver
|
||||||
|
nb_snapshots = args.snapshots
|
||||||
|
|
||||||
mesh, sharding = create_mesh_and_sharding(mesh_shape, args.pdims)
|
sharding = create_mesh_and_sharding(args.pdims)
|
||||||
|
|
||||||
initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(
|
initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(
|
||||||
0.25, 0.8, tuple(mesh_shape), tuple(box_size), halo_size,
|
0.25, 0.8, tuple(mesh_shape), tuple(box_size), halo_size,
|
||||||
solver_choice, sharding)
|
solver_choice, nb_snapshots, sharding)
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
os.makedirs("fields", exist_ok=True)
|
||||||
|
print(f"[{rank}] Simulation done")
|
||||||
|
print(f"Solver stats: {solver_stats}")
|
||||||
|
|
||||||
|
|
||||||
# Save initial conditions
|
# Save initial conditions
|
||||||
initial_conditions_g = all_gather(initial_conditions)
|
initial_conditions_g = all_gather(initial_conditions)
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
print(f"[{rank}] Saving initial_conditions")
|
print(f"[{rank}] Saving initial_conditions")
|
||||||
np.save("initial_conditions.npy", initial_conditions_g)
|
np.save("fields/initial_conditions.npy", initial_conditions_g)
|
||||||
print(f"[{rank}] initial_conditions saved")
|
print(f"[{rank}] initial_conditions saved")
|
||||||
del initial_conditions_g, initial_conditions
|
del initial_conditions_g, initial_conditions
|
||||||
|
|
||||||
|
@ -151,7 +164,7 @@ def main():
|
||||||
lpt_displacements_g = all_gather(lpt_displacements)
|
lpt_displacements_g = all_gather(lpt_displacements)
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
print(f"[{rank}] Saving lpt_displacements")
|
print(f"[{rank}] Saving lpt_displacements")
|
||||||
np.save("lpt_displacements.npy", lpt_displacements_g)
|
np.save("fields/lpt_displacements.npy", lpt_displacements_g)
|
||||||
print(f"[{rank}] lpt_displacements saved")
|
print(f"[{rank}] lpt_displacements saved")
|
||||||
del lpt_displacements_g, lpt_displacements
|
del lpt_displacements_g, lpt_displacements
|
||||||
|
|
||||||
|
@ -160,7 +173,7 @@ def main():
|
||||||
sol_g = all_gather(sol)
|
sol_g = all_gather(sol)
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
print(f"[{rank}] Saving ode_solution_{i}")
|
print(f"[{rank}] Saving ode_solution_{i}")
|
||||||
np.save(f"ode_solution_{i}.npy", sol_g)
|
np.save(f"fields/ode_solution_{i}.npy", sol_g)
|
||||||
print(f"[{rank}] ode_solution_{i} saved")
|
print(f"[{rank}] ode_solution_{i} saved")
|
||||||
del sol_g
|
del sol_g
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue