mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-15 12:31:11 +00:00
format
This commit is contained in:
parent
831291c1f9
commit
ece8c93540
12 changed files with 210 additions and 170 deletions
|
@ -10,13 +10,14 @@ size = jax.process_count()
|
|||
|
||||
import argparse
|
||||
import time
|
||||
from hpc_plotter.timer import Timer
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
import numpy as np
|
||||
from cupy.cuda.nvtx import RangePop, RangePush
|
||||
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
||||
PIDController, SaveAt, Tsit5, diffeqsolve)
|
||||
from hpc_plotter.timer import Timer
|
||||
from jax.experimental import mesh_utils
|
||||
from jax.experimental.multihost_utils import sync_global_devices
|
||||
from jax.sharding import Mesh, NamedSharding
|
||||
|
@ -27,7 +28,6 @@ from jaxpm.painting import cic_paint_dx
|
|||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||
|
||||
|
||||
|
||||
def run_simulation(mesh_shape,
|
||||
box_size,
|
||||
halo_size,
|
||||
|
@ -69,7 +69,7 @@ def run_simulation(mesh_shape,
|
|||
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
|
||||
term = ODETerm(
|
||||
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
|
||||
|
||||
|
||||
if solver_choice == "Dopri5" or solver_choice == "Tsit5":
|
||||
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4)
|
||||
elif solver_choice == "LeapfrogMidpoint" or solver_choice == "Euler":
|
||||
|
@ -94,12 +94,18 @@ def run_simulation(mesh_shape,
|
|||
# Warm start
|
||||
chrono_fun = Timer()
|
||||
RangePush("warmup")
|
||||
final_field, stats = chrono_fun.chrono_jit(simulate, 0.32, 0.8 , ndarray_arg = 0)
|
||||
final_field, stats = chrono_fun.chrono_jit(simulate,
|
||||
0.32,
|
||||
0.8,
|
||||
ndarray_arg=0)
|
||||
RangePop()
|
||||
sync_global_devices("warmup")
|
||||
for i in range(iterations):
|
||||
RangePush(f"sim iter {i}")
|
||||
final_field, stats = chrono_fun.chrono_fun(simulate, 0.32, 0.8 , ndarray_arg = 0)
|
||||
final_field, stats = chrono_fun.chrono_fun(simulate,
|
||||
0.32,
|
||||
0.8,
|
||||
ndarray_arg=0)
|
||||
RangePop()
|
||||
return final_field, stats, chrono_fun
|
||||
|
||||
|
@ -134,11 +140,13 @@ if __name__ == "__main__":
|
|||
type=str,
|
||||
help='Processor dimensions',
|
||||
default=None)
|
||||
parser.add_argument('-pr',
|
||||
'--precision',
|
||||
type=str,
|
||||
help='Precision',
|
||||
choices=["float32", "float64"],)
|
||||
parser.add_argument(
|
||||
'-pr',
|
||||
'--precision',
|
||||
type=str,
|
||||
help='Precision',
|
||||
choices=["float32", "float64"],
|
||||
)
|
||||
parser.add_argument('-hs',
|
||||
'--halo_size',
|
||||
type=int,
|
||||
|
@ -173,7 +181,7 @@ if __name__ == "__main__":
|
|||
type=int,
|
||||
help='Number of nodes',
|
||||
default=1)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
mesh_size = args.mesh_size
|
||||
box_size = [args.box_size] * 3
|
||||
|
@ -182,14 +190,14 @@ if __name__ == "__main__":
|
|||
iterations = args.iterations
|
||||
output_path = args.output_path
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
|
||||
print(f"solver choice: {solver_choice}")
|
||||
match solver_choice:
|
||||
case "Dopri5" | "dopri5"| "d5":
|
||||
case "Dopri5" | "dopri5" | "d5":
|
||||
solver_choice = "Dopri5"
|
||||
case "Tsit5"| "tsit5"| "t5":
|
||||
case "Tsit5" | "tsit5" | "t5":
|
||||
solver_choice = "Tsit5"
|
||||
case "LeapfrogMidpoint"| "leapfrogmidpoint"| "lfm":
|
||||
case "LeapfrogMidpoint" | "leapfrogmidpoint" | "lfm":
|
||||
solver_choice = "LeapfrogMidpoint"
|
||||
case "lpt":
|
||||
solver_choice = "lpt"
|
||||
|
@ -199,7 +207,7 @@ if __name__ == "__main__":
|
|||
)
|
||||
if args.precision == "float32":
|
||||
jax.config.update("jax_enable_x64", False)
|
||||
elif args.precision == "float64":
|
||||
elif args.precision == "float64":
|
||||
jax.config.update("jax_enable_x64", True)
|
||||
|
||||
if args.pdims:
|
||||
|
@ -209,22 +217,26 @@ if __name__ == "__main__":
|
|||
|
||||
mesh_shape = [mesh_size] * 3
|
||||
|
||||
final_field , stats, chrono_fun = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, pdims)
|
||||
|
||||
print(f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}")
|
||||
final_field, stats, chrono_fun = run_simulation(mesh_shape, box_size,
|
||||
halo_size, solver_choice,
|
||||
iterations, pdims)
|
||||
|
||||
print(
|
||||
f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}"
|
||||
)
|
||||
|
||||
metadata = {
|
||||
'rank': rank,
|
||||
'function_name': f'JAXPM-{solver_choice}',
|
||||
'precision': args.precision,
|
||||
'x': str(mesh_size),
|
||||
'y': str(mesh_size),
|
||||
'z': str(stats["num_steps"]),
|
||||
'px': str(pdims[0]),
|
||||
'py': str(pdims[1]),
|
||||
'backend': 'NCCL',
|
||||
'nodes': str(args.nodes)
|
||||
}
|
||||
'rank': rank,
|
||||
'function_name': f'JAXPM-{solver_choice}',
|
||||
'precision': args.precision,
|
||||
'x': str(mesh_size),
|
||||
'y': str(mesh_size),
|
||||
'z': str(stats["num_steps"]),
|
||||
'px': str(pdims[0]),
|
||||
'py': str(pdims[1]),
|
||||
'backend': 'NCCL',
|
||||
'nodes': str(args.nodes)
|
||||
}
|
||||
# Print the results to a CSV file
|
||||
chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata)
|
||||
|
||||
|
@ -236,8 +248,8 @@ if __name__ == "__main__":
|
|||
with open(f'{field_folder}/jaxpm.log', 'w') as f:
|
||||
f.write(f"Args: {args}\n")
|
||||
f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
|
||||
for i , time in enumerate(chrono_fun.times):
|
||||
f.write(f"Time {i}: {time:.4f} ms\n")
|
||||
for i, time in enumerate(chrono_fun.times):
|
||||
f.write(f"Time {i}: {time:.4f} ms\n")
|
||||
f.write(f"Stats: {stats}\n")
|
||||
if args.save_fields:
|
||||
np.save(f'{field_folder}/final_field_0_{rank}.npy',
|
||||
|
|
|
@ -3,34 +3,41 @@ import os
|
|||
# Change JAX GPU memory preallocation fraction
|
||||
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95'
|
||||
|
||||
import jax
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import matplotlib.pyplot as plt
|
||||
from pmwd import (
|
||||
Configuration,
|
||||
Cosmology, SimpleLCDM,
|
||||
boltzmann, linear_power, growth,
|
||||
white_noise, linear_modes,
|
||||
lpt, nbody, scatter
|
||||
)
|
||||
import numpy as np
|
||||
from hpc_plotter.timer import Timer
|
||||
from pmwd import (Configuration, Cosmology, SimpleLCDM, boltzmann, growth,
|
||||
linear_modes, linear_power, lpt, nbody, scatter, white_noise)
|
||||
from pmwd.pm_util import fftinv
|
||||
from pmwd.spec_util import powspec
|
||||
from pmwd.vis_util import simshow
|
||||
from hpc_plotter.timer import Timer
|
||||
|
||||
|
||||
# Simulation configuration
|
||||
def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver , iterations):
|
||||
def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver, iterations):
|
||||
|
||||
@jax.jit
|
||||
def simulate(omega_m, sigma8):
|
||||
|
||||
|
||||
conf = Configuration(ptcl_spacing, ptcl_grid_shape=ptcl_grid_shape, mesh_shape=1,lpt_order=1,a_nbody_maxstep=1/91)
|
||||
print(conf)
|
||||
print(f'Simulating {conf.ptcl_num} particles with a {conf.mesh_shape} mesh for {conf.a_nbody_num} time steps.')
|
||||
|
||||
cosmo = Cosmology(conf, A_s_1e9=2.0, n_s=0.96, Omega_m=omega_m, Omega_b=sigma8, h=0.7)
|
||||
conf = Configuration(ptcl_spacing,
|
||||
ptcl_grid_shape=ptcl_grid_shape,
|
||||
mesh_shape=1,
|
||||
lpt_order=1,
|
||||
a_nbody_maxstep=1 / 91)
|
||||
print(conf)
|
||||
print(
|
||||
f'Simulating {conf.ptcl_num} particles with a {conf.mesh_shape} mesh for {conf.a_nbody_num} time steps.'
|
||||
)
|
||||
|
||||
cosmo = Cosmology(conf,
|
||||
A_s_1e9=2.0,
|
||||
n_s=0.96,
|
||||
Omega_m=omega_m,
|
||||
Omega_b=sigma8,
|
||||
h=0.7)
|
||||
print(cosmo)
|
||||
|
||||
# Boltzmann calculation
|
||||
|
@ -46,71 +53,95 @@ def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver , iterations):
|
|||
# Solve LPT at some early time
|
||||
ptcl, obsvbl = lpt(modes, cosmo, conf)
|
||||
print("LPT solved.")
|
||||
|
||||
|
||||
if solver == "lfm":
|
||||
# N-body time integration from LPT initial conditions
|
||||
ptcl, obsvbl = jax.block_until_ready(nbody(ptcl, obsvbl, cosmo, conf))
|
||||
print("N-body time integration completed.")
|
||||
# N-body time integration from LPT initial conditions
|
||||
ptcl, obsvbl = jax.block_until_ready(
|
||||
nbody(ptcl, obsvbl, cosmo, conf))
|
||||
print("N-body time integration completed.")
|
||||
|
||||
# Scatter particles to mesh to get the density field
|
||||
dens = scatter(ptcl, conf)
|
||||
return dens
|
||||
|
||||
|
||||
chrono_timer = Timer()
|
||||
final_field = chrono_timer.chrono_jit(simulate, 0.3, 0.05)
|
||||
|
||||
|
||||
for _ in range(iterations):
|
||||
final_field = chrono_timer.chrono_fun(simulate, 0.3, 0.05)
|
||||
|
||||
return final_field , chrono_timer
|
||||
|
||||
return final_field, chrono_timer
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='PMWD Simulation')
|
||||
parser.add_argument('-m', '--mesh_size', type=int, help='Mesh size', required=True)
|
||||
parser.add_argument('-b', '--box_size', type=float, help='Box size', required=True)
|
||||
parser.add_argument('-i', '--iterations', type=int, help='Number of iterations', default=10)
|
||||
parser.add_argument('-o', '--output_path', type=str, help='Output path', default=".")
|
||||
parser.add_argument('-f', '--save_fields', action='store_true', help='Save fields')
|
||||
parser.add_argument('-s', '--solver', type=str, help='Solver', choices=["lfm" , "lpt"])
|
||||
parser.add_argument('-pr',
|
||||
'--precision',
|
||||
type=str,
|
||||
help='Precision',
|
||||
choices=["float32", "float64"],)
|
||||
|
||||
parser.add_argument('-m',
|
||||
'--mesh_size',
|
||||
type=int,
|
||||
help='Mesh size',
|
||||
required=True)
|
||||
parser.add_argument('-b',
|
||||
'--box_size',
|
||||
type=float,
|
||||
help='Box size',
|
||||
required=True)
|
||||
parser.add_argument('-i',
|
||||
'--iterations',
|
||||
type=int,
|
||||
help='Number of iterations',
|
||||
default=10)
|
||||
parser.add_argument('-o',
|
||||
'--output_path',
|
||||
type=str,
|
||||
help='Output path',
|
||||
default=".")
|
||||
parser.add_argument('-f',
|
||||
'--save_fields',
|
||||
action='store_true',
|
||||
help='Save fields')
|
||||
parser.add_argument('-s',
|
||||
'--solver',
|
||||
type=str,
|
||||
help='Solver',
|
||||
choices=["lfm", "lpt"])
|
||||
parser.add_argument(
|
||||
'-pr',
|
||||
'--precision',
|
||||
type=str,
|
||||
help='Precision',
|
||||
choices=["float32", "float64"],
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
mesh_shape = [args.mesh_size] * 3
|
||||
ptcl_spacing = args.box_size /args.mesh_size
|
||||
ptcl_spacing = args.box_size / args.mesh_size
|
||||
iterations = args.iterations
|
||||
solver = args.solver
|
||||
output_path = args.output_path
|
||||
if args.precision == "float32":
|
||||
jax.config.update("jax_enable_x64", False)
|
||||
elif args.precision == "float64":
|
||||
elif args.precision == "float64":
|
||||
jax.config.update("jax_enable_x64", True)
|
||||
|
||||
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
final_field , chrono_fun = run_pmwd_simulation(mesh_shape, ptcl_spacing, solver, iterations)
|
||||
|
||||
final_field, chrono_fun = run_pmwd_simulation(mesh_shape, ptcl_spacing,
|
||||
solver, iterations)
|
||||
print("PMWD simulation completed.")
|
||||
|
||||
|
||||
metadata = {
|
||||
'rank': 0,
|
||||
'function_name': f'PMWD-{solver}',
|
||||
'precision': args.precision,
|
||||
'x': str(mesh_shape[0]),
|
||||
'y': str(mesh_shape[0]),
|
||||
'z': str(mesh_shape[0]),
|
||||
'px': "1",
|
||||
'py': "1",
|
||||
'backend': 'NCCL',
|
||||
'nodes': "1"
|
||||
}
|
||||
'rank': 0,
|
||||
'function_name': f'PMWD-{solver}',
|
||||
'precision': args.precision,
|
||||
'x': str(mesh_shape[0]),
|
||||
'y': str(mesh_shape[0]),
|
||||
'z': str(mesh_shape[0]),
|
||||
'px': "1",
|
||||
'py': "1",
|
||||
'backend': 'NCCL',
|
||||
'nodes': "1"
|
||||
}
|
||||
chrono_fun.print_to_csv(f"{output_path}/pmwd.csv", **metadata)
|
||||
field_folder = f"{output_path}/final_field/pmwd/1/{args.mesh_size}_{int(args.box_size)}/1x1/{args.solver}/halo_0"
|
||||
os.makedirs(field_folder, exist_ok=True)
|
||||
|
@ -118,14 +149,11 @@ if __name__ == "__main__":
|
|||
f.write(f"PMWD simulation completed.\n")
|
||||
f.write(f"Args : {args}\n")
|
||||
f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
|
||||
for i , time in enumerate(chrono_fun.times):
|
||||
f.write(f"Time {i}: {time:.4f} ms\n")
|
||||
for i, time in enumerate(chrono_fun.times):
|
||||
f.write(f"Time {i}: {time:.4f} ms\n")
|
||||
if args.save_fields:
|
||||
np.save(f"{field_folder}/final_field_0_0.npy", final_field)
|
||||
print("Fields saved.")
|
||||
|
||||
|
||||
|
||||
print(f"saving to {output_path}/pmwd.csv")
|
||||
print(f"saving field and logs to {field_folder}/pmwd.log")
|
||||
|
||||
|
||||
|
|
|
@ -177,7 +177,3 @@ for pr in "${precisions[@]}"; do
|
|||
done
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -179,6 +179,3 @@ for pr in "${precisions[@]}"; do
|
|||
done
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -156,10 +156,7 @@ echo "Output dir is : $out_dir"
|
|||
for pr in "${precisions[@]}"; do
|
||||
for g in "${grid[@]}"; do
|
||||
for solver in "${solvers[@]}"; do
|
||||
launch bench_pmwd.py -m $g -b $g -p $p -pr $pr -s $solver -i 4 -o $out_dir -f
|
||||
launch bench_pmwd.py -m $g -b $g -p $p -pr $pr -s $solver -i 4 -o $out_dir -f
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -161,10 +161,7 @@ echo "Output dir is : $out_dir"
|
|||
for pr in "${precisions[@]}"; do
|
||||
for g in "${grid[@]}"; do
|
||||
for solver in "${solvers[@]}"; do
|
||||
slaunch bench_pmwd.py -m $g -b $g -pr $pr -s $solver -i 4 -o $out_dir -f
|
||||
slaunch bench_pmwd.py -m $g -b $g -pr $pr -s $solver -i 4 -o $out_dir -f
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue