times in ms

This commit is contained in:
Wassim KABALAN 2024-07-18 13:23:53 +02:00
parent 5f6d42eaeb
commit 1f6b9c3217

View file

@ -9,18 +9,22 @@ rank = jax.process_index()
size = jax.process_count()
import argparse
import time
import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from cupy.cuda.nvtx import RangePop, RangePush
from diffrax import (Dopri5, LeapfrogMidpoint, ODETerm, PIDController, SaveAt,
diffeqsolve)
from jax.experimental import mesh_utils
from jax.experimental.multihost_utils import sync_global_devices
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn
from diffrax import diffeqsolve, ODETerm, Dopri5, LeapfrogMidpoint, SaveAt, PIDController
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jaxpm.kernels import interpolate_power_spectrum
import time
from cupy.cuda.nvtx import RangePush, RangePop
from jax.experimental.multihost_utils import sync_global_devices
def chrono_fun(fun, *args):
@ -191,9 +195,14 @@ if __name__ == "__main__":
# Write benchmark results to CSV
# RANK SIZE MESHSIZE BOX HALO SOLVER NUM_STEPS JITTIME MIN MAX MEAN STD
times = np.array(times)
jit_in_ms = (warmup_time * 1000)
min_time = np.min(times) * 1000
max_time = np.max(times) * 1000
mean_time = np.mean(times) * 1000
std_time = np.std(times) * 1000
with open(f"{output_path}/jax_pm_benchmark.csv", 'a') as f:
f.write(
f"{rank},{size},{mesh_size},{box_size[0]},{halo_size},{solver_choice},{iterations},{warmup_time},{np.min(times)},{np.max(times)},{np.mean(times)},{np.std(times)}\n"
f"{rank},{size},{mesh_size},{box_size[0]},{halo_size},{solver_choice},{iterations},{jit_in_ms},{min_time},{max_time},{mean_time},{std_time}\n"
)
print(f"Finished! Warmup time: {warmup_time:.4f} seconds")