mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
times in ms
This commit is contained in:
parent
5f6d42eaeb
commit
1f6b9c3217
1 changed files with 18 additions and 9 deletions
|
@ -9,18 +9,22 @@ rank = jax.process_index()
|
|||
size = jax.process_count()
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
import numpy as np
|
||||
from cupy.cuda.nvtx import RangePop, RangePush
|
||||
from diffrax import (Dopri5, LeapfrogMidpoint, ODETerm, PIDController, SaveAt,
|
||||
diffeqsolve)
|
||||
from jax.experimental import mesh_utils
|
||||
from jax.experimental.multihost_utils import sync_global_devices
|
||||
from jax.sharding import Mesh, NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jaxpm.kernels import interpolate_power_spectrum
|
||||
from jaxpm.painting import cic_paint_dx
|
||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||
from diffrax import diffeqsolve, ODETerm, Dopri5, LeapfrogMidpoint, SaveAt, PIDController
|
||||
import numpy as np
|
||||
from jax.experimental import mesh_utils
|
||||
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
|
||||
from jaxpm.kernels import interpolate_power_spectrum
|
||||
import time
|
||||
from cupy.cuda.nvtx import RangePush, RangePop
|
||||
from jax.experimental.multihost_utils import sync_global_devices
|
||||
|
||||
|
||||
def chrono_fun(fun, *args):
|
||||
|
@ -191,9 +195,14 @@ if __name__ == "__main__":
|
|||
# Write benchmark results to CSV
|
||||
# RANK SIZE MESHSIZE BOX HALO SOLVER NUM_STEPS JITTIME MIN MAX MEAN STD
|
||||
times = np.array(times)
|
||||
jit_in_ms = (warmup_time * 1000)
|
||||
min_time = np.min(times) * 1000
|
||||
max_time = np.max(times) * 1000
|
||||
mean_time = np.mean(times) * 1000
|
||||
std_time = np.std(times) * 1000
|
||||
with open(f"{output_path}/jax_pm_benchmark.csv", 'a') as f:
|
||||
f.write(
|
||||
f"{rank},{size},{mesh_size},{box_size[0]},{halo_size},{solver_choice},{iterations},{warmup_time},{np.min(times)},{np.max(times)},{np.mean(times)},{np.std(times)}\n"
|
||||
f"{rank},{size},{mesh_size},{box_size[0]},{halo_size},{solver_choice},{iterations},{jit_in_ms},{min_time},{max_time},{mean_time},{std_time}\n"
|
||||
)
|
||||
|
||||
print(f"Finished! Warmup time: {warmup_time:.4f} seconds")
|
||||
|
|
Loading…
Add table
Reference in a new issue