diff --git a/benchmarks/bench_pm.py b/benchmarks/bench_pm.py index 5e0c3a9..5f25aad 100644 --- a/benchmarks/bench_pm.py +++ b/benchmarks/bench_pm.py @@ -26,7 +26,7 @@ from jax.sharding import PartitionSpec as P from jaxpm.kernels import interpolate_power_spectrum from jaxpm.painting import cic_paint_dx from jaxpm.pm import linear_field, lpt, make_ode_fn -from jax import make_jaxpr + def run_simulation(mesh_shape, box_size, @@ -94,7 +94,6 @@ def run_simulation(mesh_shape, def run(): # Warm start -<<<<<<< HEAD chrono_fun = Timer() RangePush("warmup") final_field, stats = chrono_fun.chrono_jit(simulate, @@ -111,39 +110,6 @@ def run_simulation(mesh_shape, ndarray_arg=0) RangePop() return final_field, stats, chrono_fun -======= - if hlo_print: - jaxpr = make_jaxpr(simulate)(0.32, 0.8) - lowered = jax.jit(simulate).lower(0.32, 0.8) - lower_as_text = lowered.as_text() - compiled = lowered.compile() - compiled_again = jax.jit(simulate).lower(0.32, 0.8).compile() - return jaxpr , compiled , compiled_again - elif trace: - jit_output = f"{output_path}/jit_trace" - first_run_output = f"{output_path}/first_run_trace" - second_run_output = f"{output_path}/second_run_trace" - with jax.profiler.trace(jit_output , create_perfetto_trace=True): - final_field, stats = simulate(0.32, 0.8) - final_field.block_until_ready() - with jax.profiler.trace(first_run_output , create_perfetto_trace=True): - final_field, stats = simulate(0.32, 0.8) - final_field.block_until_ready() - with jax.profiler.trace(second_run_output , create_perfetto_trace=True): - final_field, stats = simulate(0.32, 0.8) - final_field.block_until_ready() - else: - chrono_fun = Timer() - RangePush("warmup") - final_field, stats = chrono_fun.chrono_jit(simulate, 0.32, 0.8 , ndarray_arg = 0) - RangePop() - sync_global_devices("warmup") - for i in range(iterations): - RangePush(f"sim iter {i}") - final_field, stats = chrono_fun.chrono_fun(simulate, 0.32, 0.8 , ndarray_arg = 0) - RangePop() - return final_field, stats, chrono_fun ->>>>>>> glab/ASKabalan/jaxdecomp_proto if jax.device_count() > 1: devices = mesh_utils.create_device_mesh(pdims) @@ -212,25 +178,6 @@ if __name__ == "__main__": type=int, help='Number of nodes', default=1) -<<<<<<< HEAD - -======= - parser.add_argument('-i', - '--iterations', - type=int, - help='Number of iterations', - default=10) - group = parser.add_mutually_exclusive_group() - group.add_argument('-hlo', - '--hlo_print', - action='store_true', - help='Print hlo generated by XLA') - group.add_argument('-t', - '--trace', - action='store_true', - help='Profile using tensorboard') - ->>>>>>> glab/ASKabalan/jaxdecomp_proto args = parser.parse_args() mesh_size = args.mesh_size box_size = [args.box_size] * 3 @@ -239,12 +186,6 @@ if __name__ == "__main__": iterations = args.iterations output_path = args.output_path os.makedirs(output_path, exist_ok=True) -<<<<<<< HEAD -======= - hlo_print = args.hlo_print - trace = args.trace - nb_gpus = jax.device_count() ->>>>>>> glab/ASKabalan/jaxdecomp_proto print(f"solver choice: {solver_choice}") match solver_choice: @@ -272,7 +213,6 @@ if __name__ == "__main__": pdm_str = f"{pdims[0]}x{pdims[1]}" mesh_shape = [mesh_size] * 3 -<<<<<<< HEAD final_field, stats, chrono_fun = run_simulation(mesh_shape, box_size, halo_size, solver_choice, @@ -283,50 +223,6 @@ if __name__ == "__main__": ) metadata = { -======= - - if trace: - trace_folder = f"{output_path}/profiling/jaxpm/{nb_gpus}/{mesh_shape[0]}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}" - os.makedirs(trace_folder, exist_ok=True) - run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, hlo_print, trace, pdims, trace_folder) - print(f"Profiling done! Check {trace_folder}") - elif hlo_print: - hlo_folder = f"{output_path}/hlo/jaxpm/{nb_gpus}/{mesh_shape[0]}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}" - os.makedirs(hlo_folder, exist_ok=True) - jaxpr , compiled , compiled2 = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, hlo_print, trace, pdims, hlo_folder) - print(f"type of memory analysis {type(compiled.memory_analysis())}") - print(f"memory analysis {compiled.memory_analysis()}") - print(f"memory analysis again {compiled2.memory_analysis()}") - jax.tree.map(lambda x: print(x), compiled.memory_analysis()) - with open(f'{hlo_folder}/hlo_jaxpm.md', 'w') as f: - f.write(f"# JAXPM HLO\n") - f.write(f"## Args: {args}\n") - f.write(f"## JAXPR is \n") - f.write(f'---\n') - f.write(f"{jaxpr}\n") - f.write(f'---\n') - f.write(f"Lowered as text is \n") - f.write(f'---\n') - # f.write(f"{lower_as_text}\n") - f.write(f'---\n') - f.write(f"Compiled is \n") - f.write(f'---\n') - f.write(f"{compiled.as_text()}\n") - f.write(f"Cost analysis is \n") - f.write(f'---\n') - f.write(f"{compiled.cost_analysis()[0]['flops']}\n") - f.write(f'---\n') - f.write(f"Memory analysis is \n") - f.write(f'---\n') - f.write(f"{compiled.memory_analysis()}\n") - f.write(f'---\n') - - print(f"Saved HLO to {hlo_folder}") - else: - final_field, stats, chrono_fun = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, hlo_print, trace, pdims, output_path) - print(f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}") - metadata = { ->>>>>>> glab/ASKabalan/jaxdecomp_proto 'rank': rank, 'function_name': f'JAXPM-{solver_choice}', 'precision': args.precision, @@ -337,7 +233,6 @@ if __name__ == "__main__": 'py': str(pdims[1]), 'backend': 'NCCL', 'nodes': str(args.nodes) -<<<<<<< HEAD } # Print the results to a CSV file chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata) @@ -356,27 +251,20 @@ if __name__ == "__main__": if args.save_fields: np.save(f'{field_folder}/final_field_0_{rank}.npy', final_field.addressable_data(0)) -======= - } - # Print the results to a CSV file - chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata) - # Save the final field ->>>>>>> glab/ASKabalan/jaxdecomp_proto - - field_folder = f"{output_path}/final_field/jaxpm/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}" - os.makedirs(field_folder, exist_ok=True) - with open(f'{field_folder}/jaxpm.log', 'w') as f: - f.write(f"Args: {args}\n") - f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n") - for i , time in enumerate(chrono_fun.times): + field_folder = f"{output_path}/final_field/jaxpm/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}" + os.makedirs(field_folder, exist_ok=True) + with open(f'{field_folder}/jaxpm.log', 'w') as f: + f.write(f"Args: {args}\n") + f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n") + for i, time in enumerate(chrono_fun.times): f.write(f"Time {i}: {time:.4f} ms\n") - f.write(f"Stats: {stats}\n") - if args.save_fields: - np.save(f'{field_folder}/final_field_0_{rank}.npy', - final_field.addressable_data(0)) + f.write(f"Stats: {stats}\n") + if args.save_fields: + np.save(f'{field_folder}/final_field_0_{rank}.npy', + final_field.addressable_data(0)) - print(f"Finished! ") - print(f"Stats {stats}") - print(f"Saving to {output_path}/jax_pm_benchmark.csv") - print(f"Saving field and logs in {field_folder}") + print(f"Finished! ") + print(f"Stats {stats}") + print(f"Saving to {output_path}/jax_pm_benchmark.csv") + print(f"Saving field and logs in {field_folder}")