mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-18 17:10:54 +00:00
adjust test for hpc-plotter
This commit is contained in:
parent
ccbfee3615
commit
aebc3e72c0
1 changed files with 56 additions and 52 deletions
|
@ -10,7 +10,7 @@ size = jax.process_count()
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
|
from hpc_plotter.timer import Timer
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -27,13 +27,6 @@ from jaxpm.painting import cic_paint_dx
|
||||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||||
|
|
||||||
|
|
||||||
def chrono_fun(fun, *args):
|
|
||||||
start = time.perf_counter()
|
|
||||||
out = fun(*args)
|
|
||||||
out[0].block_until_ready()
|
|
||||||
end = time.perf_counter()
|
|
||||||
return out, end - start
|
|
||||||
|
|
||||||
|
|
||||||
def run_simulation(mesh_shape,
|
def run_simulation(mesh_shape,
|
||||||
box_size,
|
box_size,
|
||||||
|
@ -59,7 +52,6 @@ def run_simulation(mesh_shape,
|
||||||
# Create particles
|
# Create particles
|
||||||
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
|
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
|
||||||
dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size)
|
dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size)
|
||||||
|
|
||||||
if solver_choice == "Dopri5":
|
if solver_choice == "Dopri5":
|
||||||
solver = Dopri5()
|
solver = Dopri5()
|
||||||
elif solver_choice == "LeapfrogMidpoint":
|
elif solver_choice == "LeapfrogMidpoint":
|
||||||
|
@ -68,7 +60,8 @@ def run_simulation(mesh_shape,
|
||||||
solver = Tsit5()
|
solver = Tsit5()
|
||||||
elif solver_choice == "lpt":
|
elif solver_choice == "lpt":
|
||||||
lpt_field = cic_paint_dx(dx, halo_size=halo_size)
|
lpt_field = cic_paint_dx(dx, halo_size=halo_size)
|
||||||
return lpt_field, {"num_steps": 0, "Solver": "LPT"}
|
print(f"TYPE of lpt_field: {type(lpt_field)}")
|
||||||
|
return lpt_field, {"num_steps": 0}
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid solver choice. Use 'Dopri5' or 'LeapfrogMidpoint'.")
|
"Invalid solver choice. Use 'Dopri5' or 'LeapfrogMidpoint'.")
|
||||||
|
@ -77,7 +70,10 @@ def run_simulation(mesh_shape,
|
||||||
term = ODETerm(
|
term = ODETerm(
|
||||||
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
|
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
|
||||||
|
|
||||||
|
if solver_choice == "Dopri5" or solver_choice == "Tsit5":
|
||||||
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4)
|
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4)
|
||||||
|
elif solver_choice == "LeapfrogMidpoint" or solver_choice == "Euler":
|
||||||
|
stepsize_controller = ConstantStepSize()
|
||||||
res = diffeqsolve(term,
|
res = diffeqsolve(term,
|
||||||
solver,
|
solver,
|
||||||
t0=0.1,
|
t0=0.1,
|
||||||
|
@ -96,28 +92,27 @@ def run_simulation(mesh_shape,
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
# Warm start
|
# Warm start
|
||||||
times = []
|
chrono_fun = Timer()
|
||||||
RangePush("warmup")
|
RangePush("warmup")
|
||||||
(final_field, stats), warmup_time = chrono_fun(simulate, 0.32, 0.8)
|
final_field, stats = chrono_fun.chrono_jit(simulate, 0.32, 0.8 , ndarray_arg = 0)
|
||||||
RangePop()
|
RangePop()
|
||||||
sync_global_devices("warmup")
|
sync_global_devices("warmup")
|
||||||
for i in range(iterations):
|
for i in range(iterations):
|
||||||
RangePush(f"sim iter {i}")
|
RangePush(f"sim iter {i}")
|
||||||
(final_field, stats), sim_time = chrono_fun(simulate, 0.32, 0.8)
|
final_field, stats = chrono_fun.chrono_fun(simulate, 0.32, 0.8 , ndarray_arg = 0)
|
||||||
RangePop()
|
RangePop()
|
||||||
times.append(sim_time)
|
return final_field, stats, chrono_fun
|
||||||
return stats, warmup_time, times, final_field
|
|
||||||
|
|
||||||
if jax.device_count() > 1:
|
if jax.device_count() > 1:
|
||||||
devices = mesh_utils.create_device_mesh(pdims)
|
devices = mesh_utils.create_device_mesh(pdims)
|
||||||
mesh = Mesh(devices.T, axis_names=('x', 'y'))
|
mesh = Mesh(devices.T, axis_names=('x', 'y'))
|
||||||
with mesh:
|
with mesh:
|
||||||
# Warm start
|
# Warm start
|
||||||
stats, warmup_time, times, final_field = run()
|
final_field, stats, chrono_fun = run()
|
||||||
else:
|
else:
|
||||||
stats, warmup_time, times, final_field = run()
|
final_field, stats, chrono_fun = run()
|
||||||
|
|
||||||
return stats, warmup_time, times, final_field
|
return final_field, stats, chrono_fun
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -139,6 +134,11 @@ if __name__ == "__main__":
|
||||||
type=str,
|
type=str,
|
||||||
help='Processor dimensions',
|
help='Processor dimensions',
|
||||||
default=None)
|
default=None)
|
||||||
|
parser.add_argument('-pr',
|
||||||
|
'--precision',
|
||||||
|
type=str,
|
||||||
|
help='Precision',
|
||||||
|
choices=["float32", "float64"],)
|
||||||
parser.add_argument('-hs',
|
parser.add_argument('-hs',
|
||||||
'--halo_size',
|
'--halo_size',
|
||||||
type=int,
|
type=int,
|
||||||
|
@ -168,9 +168,13 @@ if __name__ == "__main__":
|
||||||
'--save_fields',
|
'--save_fields',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='Save fields')
|
help='Save fields')
|
||||||
|
parser.add_argument('-n',
|
||||||
|
'--nodes',
|
||||||
|
type=int,
|
||||||
|
help='Number of nodes',
|
||||||
|
default=1)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
mesh_size = args.mesh_size
|
mesh_size = args.mesh_size
|
||||||
box_size = [args.box_size] * 3
|
box_size = [args.box_size] * 3
|
||||||
halo_size = args.halo_size
|
halo_size = args.halo_size
|
||||||
|
@ -179,12 +183,13 @@ if __name__ == "__main__":
|
||||||
output_path = args.output_path
|
output_path = args.output_path
|
||||||
os.makedirs(output_path, exist_ok=True)
|
os.makedirs(output_path, exist_ok=True)
|
||||||
|
|
||||||
|
print(f"solver choice: {solver_choice}")
|
||||||
match solver_choice:
|
match solver_choice:
|
||||||
case "Dopri5", "dopri5", "d5":
|
case "Dopri5" | "dopri5"| "d5":
|
||||||
solver_choice = "Dopri5"
|
solver_choice = "Dopri5"
|
||||||
case "Tsit5", "tsit5", "t5":
|
case "Tsit5"| "tsit5"| "t5":
|
||||||
solver_choice = "Tsit5"
|
solver_choice = "Tsit5"
|
||||||
case "LeapfrogMidpoint", "leapfrogmidpoint", "lfm":
|
case "LeapfrogMidpoint"| "leapfrogmidpoint"| "lfm":
|
||||||
solver_choice = "LeapfrogMidpoint"
|
solver_choice = "LeapfrogMidpoint"
|
||||||
case "lpt":
|
case "lpt":
|
||||||
solver_choice = "lpt"
|
solver_choice = "lpt"
|
||||||
|
@ -192,6 +197,10 @@ if __name__ == "__main__":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid solver choice. Use 'Dopri5', 'Tsit5', 'LeapfrogMidpoint' or 'lpt"
|
"Invalid solver choice. Use 'Dopri5', 'Tsit5', 'LeapfrogMidpoint' or 'lpt"
|
||||||
)
|
)
|
||||||
|
if args.precision == "float32":
|
||||||
|
jax.config.update("jax_enable_x64", False)
|
||||||
|
elif args.precision == "float64":
|
||||||
|
jax.config.update("jax_enable_x64", True)
|
||||||
|
|
||||||
if args.pdims:
|
if args.pdims:
|
||||||
pdims = tuple(map(int, args.pdims.split("x")))
|
pdims = tuple(map(int, args.pdims.split("x")))
|
||||||
|
@ -200,26 +209,24 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
mesh_shape = [mesh_size] * 3
|
mesh_shape = [mesh_size] * 3
|
||||||
|
|
||||||
stats, warmup_time, times, final_field = run_simulation(mesh_shape,
|
final_field , stats, chrono_fun = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, pdims)
|
||||||
box_size,
|
|
||||||
halo_size,
|
|
||||||
solver_choice,
|
|
||||||
iterations,
|
|
||||||
pdims=pdims)
|
|
||||||
|
|
||||||
# Write benchmark results to CSV
|
print(f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}")
|
||||||
# RANK SIZE MESHSIZE BOX HALO SOLVER NUM_STEPS JITTIME MIN MAX MEAN STD
|
|
||||||
times = np.array(times)
|
|
||||||
jit_in_ms = (warmup_time * 1000)
|
|
||||||
min_time = np.min(times) * 1000
|
|
||||||
max_time = np.max(times) * 1000
|
|
||||||
mean_time = np.mean(times) * 1000
|
|
||||||
std_time = np.std(times) * 1000
|
|
||||||
|
|
||||||
with open(f"{output_path}/jax_pm_benchmark.csv", 'a') as f:
|
metadata = {
|
||||||
f.write(
|
'rank': rank,
|
||||||
f"{rank},{size},{mesh_size},{box_size[0]},{halo_size},{solver_choice},{stats['num_steps']},{jit_in_ms},{min_time},{max_time},{mean_time},{std_time}\n"
|
'function_name': f'JAXPM-{solver_choice}',
|
||||||
)
|
'precision': args.precision,
|
||||||
|
'x': str(mesh_size),
|
||||||
|
'y': str(mesh_size),
|
||||||
|
'z': str(stats["num_steps"]),
|
||||||
|
'px': str(pdims[0]),
|
||||||
|
'py': str(pdims[1]),
|
||||||
|
'backend': 'NCCL',
|
||||||
|
'nodes': str(args.nodes)
|
||||||
|
}
|
||||||
|
# Print the results to a CSV file
|
||||||
|
chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata)
|
||||||
|
|
||||||
# Save the final field
|
# Save the final field
|
||||||
nb_gpus = jax.device_count()
|
nb_gpus = jax.device_count()
|
||||||
|
@ -228,18 +235,15 @@ if __name__ == "__main__":
|
||||||
os.makedirs(field_folder, exist_ok=True)
|
os.makedirs(field_folder, exist_ok=True)
|
||||||
with open(f'{field_folder}/jaxpm.log', 'w') as f:
|
with open(f'{field_folder}/jaxpm.log', 'w') as f:
|
||||||
f.write(f"Args: {args}\n")
|
f.write(f"Args: {args}\n")
|
||||||
f.write(f"JIT time: {jit_in_ms:.4f} ms\n")
|
f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
|
||||||
f.write(f"Min time: {min_time:.4f} ms\n")
|
for i , time in enumerate(chrono_fun.times):
|
||||||
f.write(f"Max time: {max_time:.4f} ms\n")
|
f.write(f"Time {i}: {time:.4f} ms\n")
|
||||||
f.write(f"Mean time: {mean_time:.4f} ms\n")
|
|
||||||
f.write(f"Std time: {std_time:.4f} ms\n")
|
|
||||||
f.write(f"Stats: {stats}\n")
|
f.write(f"Stats: {stats}\n")
|
||||||
if args.save_fields:
|
if args.save_fields:
|
||||||
np.save(f'{field_folder}/final_field_0_{rank}.npy',
|
np.save(f'{field_folder}/final_field_0_{rank}.npy',
|
||||||
final_field.addressable_data(0))
|
final_field.addressable_data(0))
|
||||||
|
|
||||||
print(f"Finished! Warmup time: {warmup_time:.4f} seconds")
|
print(f"Finished! ")
|
||||||
print(f"mean times: {np.mean(times):.4f}")
|
|
||||||
print(f"Stats {stats}")
|
print(f"Stats {stats}")
|
||||||
print(f"Saving to {output_path}/jax_pm_benchmark.csv")
|
print(f"Saving to {output_path}/jax_pm_benchmark.csv")
|
||||||
print(f"Saving field and logs in {field_folder}")
|
print(f"Saving field and logs in {field_folder}")
|
||||||
|
|
Loading…
Add table
Reference in a new issue