mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-04 13:31:12 +00:00
format
This commit is contained in:
parent
831291c1f9
commit
ece8c93540
12 changed files with 210 additions and 170 deletions
|
@ -10,13 +10,14 @@ size = jax.process_count()
|
|||
|
||||
import argparse
|
||||
import time
|
||||
from hpc_plotter.timer import Timer
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
import numpy as np
|
||||
from cupy.cuda.nvtx import RangePop, RangePush
|
||||
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
||||
PIDController, SaveAt, Tsit5, diffeqsolve)
|
||||
from hpc_plotter.timer import Timer
|
||||
from jax.experimental import mesh_utils
|
||||
from jax.experimental.multihost_utils import sync_global_devices
|
||||
from jax.sharding import Mesh, NamedSharding
|
||||
|
@ -27,7 +28,6 @@ from jaxpm.painting import cic_paint_dx
|
|||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||
|
||||
|
||||
|
||||
def run_simulation(mesh_shape,
|
||||
box_size,
|
||||
halo_size,
|
||||
|
@ -69,7 +69,7 @@ def run_simulation(mesh_shape,
|
|||
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
|
||||
term = ODETerm(
|
||||
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
|
||||
|
||||
|
||||
if solver_choice == "Dopri5" or solver_choice == "Tsit5":
|
||||
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4)
|
||||
elif solver_choice == "LeapfrogMidpoint" or solver_choice == "Euler":
|
||||
|
@ -94,12 +94,18 @@ def run_simulation(mesh_shape,
|
|||
# Warm start
|
||||
chrono_fun = Timer()
|
||||
RangePush("warmup")
|
||||
final_field, stats = chrono_fun.chrono_jit(simulate, 0.32, 0.8 , ndarray_arg = 0)
|
||||
final_field, stats = chrono_fun.chrono_jit(simulate,
|
||||
0.32,
|
||||
0.8,
|
||||
ndarray_arg=0)
|
||||
RangePop()
|
||||
sync_global_devices("warmup")
|
||||
for i in range(iterations):
|
||||
RangePush(f"sim iter {i}")
|
||||
final_field, stats = chrono_fun.chrono_fun(simulate, 0.32, 0.8 , ndarray_arg = 0)
|
||||
final_field, stats = chrono_fun.chrono_fun(simulate,
|
||||
0.32,
|
||||
0.8,
|
||||
ndarray_arg=0)
|
||||
RangePop()
|
||||
return final_field, stats, chrono_fun
|
||||
|
||||
|
@ -134,11 +140,13 @@ if __name__ == "__main__":
|
|||
type=str,
|
||||
help='Processor dimensions',
|
||||
default=None)
|
||||
parser.add_argument('-pr',
|
||||
'--precision',
|
||||
type=str,
|
||||
help='Precision',
|
||||
choices=["float32", "float64"],)
|
||||
parser.add_argument(
|
||||
'-pr',
|
||||
'--precision',
|
||||
type=str,
|
||||
help='Precision',
|
||||
choices=["float32", "float64"],
|
||||
)
|
||||
parser.add_argument('-hs',
|
||||
'--halo_size',
|
||||
type=int,
|
||||
|
@ -173,7 +181,7 @@ if __name__ == "__main__":
|
|||
type=int,
|
||||
help='Number of nodes',
|
||||
default=1)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
mesh_size = args.mesh_size
|
||||
box_size = [args.box_size] * 3
|
||||
|
@ -182,14 +190,14 @@ if __name__ == "__main__":
|
|||
iterations = args.iterations
|
||||
output_path = args.output_path
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
|
||||
print(f"solver choice: {solver_choice}")
|
||||
match solver_choice:
|
||||
case "Dopri5" | "dopri5"| "d5":
|
||||
case "Dopri5" | "dopri5" | "d5":
|
||||
solver_choice = "Dopri5"
|
||||
case "Tsit5"| "tsit5"| "t5":
|
||||
case "Tsit5" | "tsit5" | "t5":
|
||||
solver_choice = "Tsit5"
|
||||
case "LeapfrogMidpoint"| "leapfrogmidpoint"| "lfm":
|
||||
case "LeapfrogMidpoint" | "leapfrogmidpoint" | "lfm":
|
||||
solver_choice = "LeapfrogMidpoint"
|
||||
case "lpt":
|
||||
solver_choice = "lpt"
|
||||
|
@ -199,7 +207,7 @@ if __name__ == "__main__":
|
|||
)
|
||||
if args.precision == "float32":
|
||||
jax.config.update("jax_enable_x64", False)
|
||||
elif args.precision == "float64":
|
||||
elif args.precision == "float64":
|
||||
jax.config.update("jax_enable_x64", True)
|
||||
|
||||
if args.pdims:
|
||||
|
@ -209,22 +217,26 @@ if __name__ == "__main__":
|
|||
|
||||
mesh_shape = [mesh_size] * 3
|
||||
|
||||
final_field , stats, chrono_fun = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, pdims)
|
||||
|
||||
print(f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}")
|
||||
final_field, stats, chrono_fun = run_simulation(mesh_shape, box_size,
|
||||
halo_size, solver_choice,
|
||||
iterations, pdims)
|
||||
|
||||
print(
|
||||
f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}"
|
||||
)
|
||||
|
||||
metadata = {
|
||||
'rank': rank,
|
||||
'function_name': f'JAXPM-{solver_choice}',
|
||||
'precision': args.precision,
|
||||
'x': str(mesh_size),
|
||||
'y': str(mesh_size),
|
||||
'z': str(stats["num_steps"]),
|
||||
'px': str(pdims[0]),
|
||||
'py': str(pdims[1]),
|
||||
'backend': 'NCCL',
|
||||
'nodes': str(args.nodes)
|
||||
}
|
||||
'rank': rank,
|
||||
'function_name': f'JAXPM-{solver_choice}',
|
||||
'precision': args.precision,
|
||||
'x': str(mesh_size),
|
||||
'y': str(mesh_size),
|
||||
'z': str(stats["num_steps"]),
|
||||
'px': str(pdims[0]),
|
||||
'py': str(pdims[1]),
|
||||
'backend': 'NCCL',
|
||||
'nodes': str(args.nodes)
|
||||
}
|
||||
# Print the results to a CSV file
|
||||
chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata)
|
||||
|
||||
|
@ -236,8 +248,8 @@ if __name__ == "__main__":
|
|||
with open(f'{field_folder}/jaxpm.log', 'w') as f:
|
||||
f.write(f"Args: {args}\n")
|
||||
f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
|
||||
for i , time in enumerate(chrono_fun.times):
|
||||
f.write(f"Time {i}: {time:.4f} ms\n")
|
||||
for i, time in enumerate(chrono_fun.times):
|
||||
f.write(f"Time {i}: {time:.4f} ms\n")
|
||||
f.write(f"Stats: {stats}\n")
|
||||
if args.save_fields:
|
||||
np.save(f'{field_folder}/final_field_0_{rank}.npy',
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue