From 1f6b9c3217c0bdf27bf9a516b3341f42e6258dbd Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Thu, 18 Jul 2024 13:23:53 +0200 Subject: [PATCH] times in ms --- scripts/bench_pm.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/scripts/bench_pm.py b/scripts/bench_pm.py index 30cb3db..a572dad 100644 --- a/scripts/bench_pm.py +++ b/scripts/bench_pm.py @@ -9,18 +9,22 @@ rank = jax.process_index() size = jax.process_count() import argparse +import time + import jax.numpy as jnp import jax_cosmo as jc +import numpy as np +from cupy.cuda.nvtx import RangePop, RangePush +from diffrax import (Dopri5, LeapfrogMidpoint, ODETerm, PIDController, SaveAt, + diffeqsolve) +from jax.experimental import mesh_utils +from jax.experimental.multihost_utils import sync_global_devices +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P + +from jaxpm.kernels import interpolate_power_spectrum from jaxpm.painting import cic_paint_dx from jaxpm.pm import linear_field, lpt, make_ode_fn -from diffrax import diffeqsolve, ODETerm, Dopri5, LeapfrogMidpoint, SaveAt, PIDController -import numpy as np -from jax.experimental import mesh_utils -from jax.sharding import Mesh, PartitionSpec as P, NamedSharding -from jaxpm.kernels import interpolate_power_spectrum -import time -from cupy.cuda.nvtx import RangePush, RangePop -from jax.experimental.multihost_utils import sync_global_devices def chrono_fun(fun, *args): @@ -191,9 +195,14 @@ if __name__ == "__main__": # Write benchmark results to CSV # RANK SIZE MESHSIZE BOX HALO SOLVER NUM_STEPS JITTIME MIN MAX MEAN STD times = np.array(times) + jit_in_ms = (warmup_time * 1000) + min_time = np.min(times) * 1000 + max_time = np.max(times) * 1000 + mean_time = np.mean(times) * 1000 + std_time = np.std(times) * 1000 with open(f"{output_path}/jax_pm_benchmark.csv", 'a') as f: f.write( - f"{rank},{size},{mesh_size},{box_size[0]},{halo_size},{solver_choice},{iterations},{warmup_time},{np.min(times)},{np.max(times)},{np.mean(times)},{np.std(times)}\n" + f"{rank},{size},{mesh_size},{box_size[0]},{halo_size},{solver_choice},{iterations},{jit_in_ms},{min_time},{max_time},{mean_time},{std_time}\n" ) print(f"Finished! Warmup time: {warmup_time:.4f} seconds")