mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-16 16:10:54 +00:00
times in ms
This commit is contained in:
parent
5f6d42eaeb
commit
1f6b9c3217
1 changed files with 18 additions and 9 deletions
|
@ -9,18 +9,22 @@ rank = jax.process_index()
|
||||||
size = jax.process_count()
|
size = jax.process_count()
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import time
|
||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
|
import numpy as np
|
||||||
|
from cupy.cuda.nvtx import RangePop, RangePush
|
||||||
|
from diffrax import (Dopri5, LeapfrogMidpoint, ODETerm, PIDController, SaveAt,
|
||||||
|
diffeqsolve)
|
||||||
|
from jax.experimental import mesh_utils
|
||||||
|
from jax.experimental.multihost_utils import sync_global_devices
|
||||||
|
from jax.sharding import Mesh, NamedSharding
|
||||||
|
from jax.sharding import PartitionSpec as P
|
||||||
|
|
||||||
|
from jaxpm.kernels import interpolate_power_spectrum
|
||||||
from jaxpm.painting import cic_paint_dx
|
from jaxpm.painting import cic_paint_dx
|
||||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||||
from diffrax import diffeqsolve, ODETerm, Dopri5, LeapfrogMidpoint, SaveAt, PIDController
|
|
||||||
import numpy as np
|
|
||||||
from jax.experimental import mesh_utils
|
|
||||||
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
|
|
||||||
from jaxpm.kernels import interpolate_power_spectrum
|
|
||||||
import time
|
|
||||||
from cupy.cuda.nvtx import RangePush, RangePop
|
|
||||||
from jax.experimental.multihost_utils import sync_global_devices
|
|
||||||
|
|
||||||
|
|
||||||
def chrono_fun(fun, *args):
|
def chrono_fun(fun, *args):
|
||||||
|
@ -191,9 +195,14 @@ if __name__ == "__main__":
|
||||||
# Write benchmark results to CSV
|
# Write benchmark results to CSV
|
||||||
# RANK SIZE MESHSIZE BOX HALO SOLVER NUM_STEPS JITTIME MIN MAX MEAN STD
|
# RANK SIZE MESHSIZE BOX HALO SOLVER NUM_STEPS JITTIME MIN MAX MEAN STD
|
||||||
times = np.array(times)
|
times = np.array(times)
|
||||||
|
jit_in_ms = (warmup_time * 1000)
|
||||||
|
min_time = np.min(times) * 1000
|
||||||
|
max_time = np.max(times) * 1000
|
||||||
|
mean_time = np.mean(times) * 1000
|
||||||
|
std_time = np.std(times) * 1000
|
||||||
with open(f"{output_path}/jax_pm_benchmark.csv", 'a') as f:
|
with open(f"{output_path}/jax_pm_benchmark.csv", 'a') as f:
|
||||||
f.write(
|
f.write(
|
||||||
f"{rank},{size},{mesh_size},{box_size[0]},{halo_size},{solver_choice},{iterations},{warmup_time},{np.min(times)},{np.max(times)},{np.mean(times)},{np.std(times)}\n"
|
f"{rank},{size},{mesh_size},{box_size[0]},{halo_size},{solver_choice},{iterations},{jit_in_ms},{min_time},{max_time},{mean_time},{std_time}\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Finished! Warmup time: {warmup_time:.4f} seconds")
|
print(f"Finished! Warmup time: {warmup_time:.4f} seconds")
|
||||||
|
|
Loading…
Add table
Reference in a new issue