diff --git a/scripts/bench_pm.py b/scripts/bench_pm.py index e73a895..73f2708 100644 --- a/scripts/bench_pm.py +++ b/scripts/bench_pm.py @@ -10,7 +10,7 @@ size = jax.process_count() import argparse import time - +from hpc_plotter.timer import Timer import jax.numpy as jnp import jax_cosmo as jc import numpy as np @@ -27,13 +27,6 @@ from jaxpm.painting import cic_paint_dx from jaxpm.pm import linear_field, lpt, make_ode_fn -def chrono_fun(fun, *args): - start = time.perf_counter() - out = fun(*args) - out[0].block_until_ready() - end = time.perf_counter() - return out, end - start - def run_simulation(mesh_shape, box_size, @@ -59,7 +52,6 @@ def run_simulation(mesh_shape, # Create particles cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8) dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size) - if solver_choice == "Dopri5": solver = Dopri5() elif solver_choice == "LeapfrogMidpoint": @@ -68,7 +60,8 @@ def run_simulation(mesh_shape, solver = Tsit5() elif solver_choice == "lpt": lpt_field = cic_paint_dx(dx, halo_size=halo_size) - return lpt_field, {"num_steps": 0, "Solver": "LPT"} + print(f"TYPE of lpt_field: {type(lpt_field)}") + return lpt_field, {"num_steps": 0} else: raise ValueError( "Invalid solver choice. Use 'Dopri5' or 'LeapfrogMidpoint'.") @@ -76,8 +69,11 @@ def run_simulation(mesh_shape, ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size) term = ODETerm( lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0)) - - stepsize_controller = PIDController(rtol=1e-4, atol=1e-4) + + if solver_choice == "Dopri5" or solver_choice == "Tsit5": + stepsize_controller = PIDController(rtol=1e-4, atol=1e-4) + elif solver_choice == "LeapfrogMidpoint" or solver_choice == "Euler": + stepsize_controller = ConstantStepSize() res = diffeqsolve(term, solver, t0=0.1, @@ -96,28 +92,27 @@ def run_simulation(mesh_shape, def run(): # Warm start - times = [] + chrono_fun = Timer() RangePush("warmup") - (final_field, stats), warmup_time = chrono_fun(simulate, 0.32, 0.8) + final_field, stats = chrono_fun.chrono_jit(simulate, 0.32, 0.8 , ndarray_arg = 0) RangePop() sync_global_devices("warmup") for i in range(iterations): RangePush(f"sim iter {i}") - (final_field, stats), sim_time = chrono_fun(simulate, 0.32, 0.8) + final_field, stats = chrono_fun.chrono_fun(simulate, 0.32, 0.8 , ndarray_arg = 0) RangePop() - times.append(sim_time) - return stats, warmup_time, times, final_field + return final_field, stats, chrono_fun if jax.device_count() > 1: devices = mesh_utils.create_device_mesh(pdims) mesh = Mesh(devices.T, axis_names=('x', 'y')) with mesh: # Warm start - stats, warmup_time, times, final_field = run() + final_field, stats, chrono_fun = run() else: - stats, warmup_time, times, final_field = run() + final_field, stats, chrono_fun = run() - return stats, warmup_time, times, final_field + return final_field, stats, chrono_fun if __name__ == "__main__": @@ -139,6 +134,11 @@ if __name__ == "__main__": type=str, help='Processor dimensions', default=None) + parser.add_argument('-pr', + '--precision', + type=str, + help='Precision', + choices=["float32", "float64"],) parser.add_argument('-hs', '--halo_size', type=int, @@ -168,9 +168,13 @@ if __name__ == "__main__": '--save_fields', action='store_true', help='Save fields') - + parser.add_argument('-n', + '--nodes', + type=int, + help='Number of nodes', + default=1) + args = parser.parse_args() - mesh_size = args.mesh_size box_size = [args.box_size] * 3 halo_size = args.halo_size @@ -178,13 +182,14 @@ if __name__ == "__main__": iterations = args.iterations output_path = args.output_path os.makedirs(output_path, exist_ok=True) - + + print(f"solver choice: {solver_choice}") match solver_choice: - case "Dopri5", "dopri5", "d5": + case "Dopri5" | "dopri5"| "d5": solver_choice = "Dopri5" - case "Tsit5", "tsit5", "t5": + case "Tsit5"| "tsit5"| "t5": solver_choice = "Tsit5" - case "LeapfrogMidpoint", "leapfrogmidpoint", "lfm": + case "LeapfrogMidpoint"| "leapfrogmidpoint"| "lfm": solver_choice = "LeapfrogMidpoint" case "lpt": solver_choice = "lpt" @@ -192,6 +197,10 @@ if __name__ == "__main__": raise ValueError( "Invalid solver choice. Use 'Dopri5', 'Tsit5', 'LeapfrogMidpoint' or 'lpt" ) + if args.precision == "float32": + jax.config.update("jax_enable_x64", False) + elif args.precision == "float64": + jax.config.update("jax_enable_x64", True) if args.pdims: pdims = tuple(map(int, args.pdims.split("x"))) @@ -200,26 +209,24 @@ if __name__ == "__main__": mesh_shape = [mesh_size] * 3 - stats, warmup_time, times, final_field = run_simulation(mesh_shape, - box_size, - halo_size, - solver_choice, - iterations, - pdims=pdims) + final_field , stats, chrono_fun = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, pdims) + + print(f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}") - # Write benchmark results to CSV - # RANK SIZE MESHSIZE BOX HALO SOLVER NUM_STEPS JITTIME MIN MAX MEAN STD - times = np.array(times) - jit_in_ms = (warmup_time * 1000) - min_time = np.min(times) * 1000 - max_time = np.max(times) * 1000 - mean_time = np.mean(times) * 1000 - std_time = np.std(times) * 1000 - - with open(f"{output_path}/jax_pm_benchmark.csv", 'a') as f: - f.write( - f"{rank},{size},{mesh_size},{box_size[0]},{halo_size},{solver_choice},{stats['num_steps']},{jit_in_ms},{min_time},{max_time},{mean_time},{std_time}\n" - ) + metadata = { + 'rank': rank, + 'function_name': f'JAXPM-{solver_choice}', + 'precision': args.precision, + 'x': str(mesh_size), + 'y': str(mesh_size), + 'z': str(stats["num_steps"]), + 'px': str(pdims[0]), + 'py': str(pdims[1]), + 'backend': 'NCCL', + 'nodes': str(args.nodes) + } + # Print the results to a CSV file + chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata) # Save the final field nb_gpus = jax.device_count() @@ -228,18 +235,15 @@ if __name__ == "__main__": os.makedirs(field_folder, exist_ok=True) with open(f'{field_folder}/jaxpm.log', 'w') as f: f.write(f"Args: {args}\n") - f.write(f"JIT time: {jit_in_ms:.4f} ms\n") - f.write(f"Min time: {min_time:.4f} ms\n") - f.write(f"Max time: {max_time:.4f} ms\n") - f.write(f"Mean time: {mean_time:.4f} ms\n") - f.write(f"Std time: {std_time:.4f} ms\n") + f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n") + for i , time in enumerate(chrono_fun.times): + f.write(f"Time {i}: {time:.4f} ms\n") f.write(f"Stats: {stats}\n") if args.save_fields: np.save(f'{field_folder}/final_field_0_{rank}.npy', final_field.addressable_data(0)) - print(f"Finished! Warmup time: {warmup_time:.4f} seconds") - print(f"mean times: {np.mean(times):.4f}") + print(f"Finished! ") print(f"Stats {stats}") print(f"Saving to {output_path}/jax_pm_benchmark.csv") print(f"Saving field and logs in {field_folder}")