This commit is contained in:
Wassim KABALAN 2024-08-03 00:23:40 +02:00
parent 831291c1f9
commit ece8c93540
12 changed files with 210 additions and 170 deletions

View file

@ -10,13 +10,14 @@ size = jax.process_count()
import argparse
import time
from hpc_plotter.timer import Timer
import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from cupy.cuda.nvtx import RangePop, RangePush
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
PIDController, SaveAt, Tsit5, diffeqsolve)
from hpc_plotter.timer import Timer
from jax.experimental import mesh_utils
from jax.experimental.multihost_utils import sync_global_devices
from jax.sharding import Mesh, NamedSharding
@ -27,7 +28,6 @@ from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn
def run_simulation(mesh_shape,
box_size,
halo_size,
@ -94,12 +94,18 @@ def run_simulation(mesh_shape,
# Warm start
chrono_fun = Timer()
RangePush("warmup")
final_field, stats = chrono_fun.chrono_jit(simulate, 0.32, 0.8 , ndarray_arg = 0)
final_field, stats = chrono_fun.chrono_jit(simulate,
0.32,
0.8,
ndarray_arg=0)
RangePop()
sync_global_devices("warmup")
for i in range(iterations):
RangePush(f"sim iter {i}")
final_field, stats = chrono_fun.chrono_fun(simulate, 0.32, 0.8 , ndarray_arg = 0)
final_field, stats = chrono_fun.chrono_fun(simulate,
0.32,
0.8,
ndarray_arg=0)
RangePop()
return final_field, stats, chrono_fun
@ -134,11 +140,13 @@ if __name__ == "__main__":
type=str,
help='Processor dimensions',
default=None)
parser.add_argument('-pr',
parser.add_argument(
'-pr',
'--precision',
type=str,
help='Precision',
choices=["float32", "float64"],)
choices=["float32", "float64"],
)
parser.add_argument('-hs',
'--halo_size',
type=int,
@ -185,11 +193,11 @@ if __name__ == "__main__":
print(f"solver choice: {solver_choice}")
match solver_choice:
case "Dopri5" | "dopri5"| "d5":
case "Dopri5" | "dopri5" | "d5":
solver_choice = "Dopri5"
case "Tsit5"| "tsit5"| "t5":
case "Tsit5" | "tsit5" | "t5":
solver_choice = "Tsit5"
case "LeapfrogMidpoint"| "leapfrogmidpoint"| "lfm":
case "LeapfrogMidpoint" | "leapfrogmidpoint" | "lfm":
solver_choice = "LeapfrogMidpoint"
case "lpt":
solver_choice = "lpt"
@ -209,9 +217,13 @@ if __name__ == "__main__":
mesh_shape = [mesh_size] * 3
final_field , stats, chrono_fun = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, pdims)
final_field, stats, chrono_fun = run_simulation(mesh_shape, box_size,
halo_size, solver_choice,
iterations, pdims)
print(f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}")
print(
f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}"
)
metadata = {
'rank': rank,
@ -236,7 +248,7 @@ if __name__ == "__main__":
with open(f'{field_folder}/jaxpm.log', 'w') as f:
f.write(f"Args: {args}\n")
f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
for i , time in enumerate(chrono_fun.times):
for i, time in enumerate(chrono_fun.times):
f.write(f"Time {i}: {time:.4f} ms\n")
f.write(f"Stats: {stats}\n")
if args.save_fields:

View file

@ -3,34 +3,41 @@ import os
# Change JAX GPU memory preallocation fraction
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95'
import jax
import argparse
import numpy as np
import jax
import matplotlib.pyplot as plt
from pmwd import (
Configuration,
Cosmology, SimpleLCDM,
boltzmann, linear_power, growth,
white_noise, linear_modes,
lpt, nbody, scatter
)
import numpy as np
from hpc_plotter.timer import Timer
from pmwd import (Configuration, Cosmology, SimpleLCDM, boltzmann, growth,
linear_modes, linear_power, lpt, nbody, scatter, white_noise)
from pmwd.pm_util import fftinv
from pmwd.spec_util import powspec
from pmwd.vis_util import simshow
from hpc_plotter.timer import Timer
# Simulation configuration
def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver , iterations):
def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver, iterations):
@jax.jit
def simulate(omega_m, sigma8):
conf = Configuration(ptcl_spacing, ptcl_grid_shape=ptcl_grid_shape, mesh_shape=1,lpt_order=1,a_nbody_maxstep=1/91)
conf = Configuration(ptcl_spacing,
ptcl_grid_shape=ptcl_grid_shape,
mesh_shape=1,
lpt_order=1,
a_nbody_maxstep=1 / 91)
print(conf)
print(f'Simulating {conf.ptcl_num} particles with a {conf.mesh_shape} mesh for {conf.a_nbody_num} time steps.')
print(
f'Simulating {conf.ptcl_num} particles with a {conf.mesh_shape} mesh for {conf.a_nbody_num} time steps.'
)
cosmo = Cosmology(conf, A_s_1e9=2.0, n_s=0.96, Omega_m=omega_m, Omega_b=sigma8, h=0.7)
cosmo = Cosmology(conf,
A_s_1e9=2.0,
n_s=0.96,
Omega_m=omega_m,
Omega_b=sigma8,
h=0.7)
print(cosmo)
# Boltzmann calculation
@ -49,7 +56,8 @@ def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver , iterations):
if solver == "lfm":
# N-body time integration from LPT initial conditions
ptcl, obsvbl = jax.block_until_ready(nbody(ptcl, obsvbl, cosmo, conf))
ptcl, obsvbl = jax.block_until_ready(
nbody(ptcl, obsvbl, cosmo, conf))
print("N-body time integration completed.")
# Scatter particles to mesh to get the density field
@ -62,28 +70,52 @@ def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver , iterations):
for _ in range(iterations):
final_field = chrono_timer.chrono_fun(simulate, 0.3, 0.05)
return final_field , chrono_timer
return final_field, chrono_timer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='PMWD Simulation')
parser.add_argument('-m', '--mesh_size', type=int, help='Mesh size', required=True)
parser.add_argument('-b', '--box_size', type=float, help='Box size', required=True)
parser.add_argument('-i', '--iterations', type=int, help='Number of iterations', default=10)
parser.add_argument('-o', '--output_path', type=str, help='Output path', default=".")
parser.add_argument('-f', '--save_fields', action='store_true', help='Save fields')
parser.add_argument('-s', '--solver', type=str, help='Solver', choices=["lfm" , "lpt"])
parser.add_argument('-pr',
parser.add_argument('-m',
'--mesh_size',
type=int,
help='Mesh size',
required=True)
parser.add_argument('-b',
'--box_size',
type=float,
help='Box size',
required=True)
parser.add_argument('-i',
'--iterations',
type=int,
help='Number of iterations',
default=10)
parser.add_argument('-o',
'--output_path',
type=str,
help='Output path',
default=".")
parser.add_argument('-f',
'--save_fields',
action='store_true',
help='Save fields')
parser.add_argument('-s',
'--solver',
type=str,
help='Solver',
choices=["lfm", "lpt"])
parser.add_argument(
'-pr',
'--precision',
type=str,
help='Precision',
choices=["float32", "float64"],)
choices=["float32", "float64"],
)
args = parser.parse_args()
mesh_shape = [args.mesh_size] * 3
ptcl_spacing = args.box_size /args.mesh_size
ptcl_spacing = args.box_size / args.mesh_size
iterations = args.iterations
solver = args.solver
output_path = args.output_path
@ -92,13 +124,12 @@ if __name__ == "__main__":
elif args.precision == "float64":
jax.config.update("jax_enable_x64", True)
os.makedirs(output_path, exist_ok=True)
final_field , chrono_fun = run_pmwd_simulation(mesh_shape, ptcl_spacing, solver, iterations)
final_field, chrono_fun = run_pmwd_simulation(mesh_shape, ptcl_spacing,
solver, iterations)
print("PMWD simulation completed.")
metadata = {
'rank': 0,
'function_name': f'PMWD-{solver}',
@ -118,14 +149,11 @@ if __name__ == "__main__":
f.write(f"PMWD simulation completed.\n")
f.write(f"Args : {args}\n")
f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
for i , time in enumerate(chrono_fun.times):
for i, time in enumerate(chrono_fun.times):
f.write(f"Time {i}: {time:.4f} ms\n")
if args.save_fields:
np.save(f"{field_folder}/final_field_0_0.npy", final_field)
print("Fields saved.")
print(f"saving to {output_path}/pmwd.csv")
print(f"saving field and logs to {field_folder}/pmwd.log")

View file

@ -177,7 +177,3 @@ for pr in "${precisions[@]}"; do
done
done
done

View file

@ -179,6 +179,3 @@ for pr in "${precisions[@]}"; do
done
done
done

View file

@ -160,6 +160,3 @@ for pr in "${precisions[@]}"; do
done
done
done

View file

@ -165,6 +165,3 @@ for pr in "${precisions[@]}"; do
done
done
done

View file

@ -44,18 +44,21 @@ def autoshmap(f: Callable,
return f
else:
if in_fourrier_space and 1 in mesh.devices.shape:
in_specs , out_specs = switch_specs((in_specs , out_specs))
in_specs, out_specs = switch_specs((in_specs, out_specs))
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
def switch_specs(specs):
if isinstance(specs, P):
new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax for ax in specs)
new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax
for ax in specs)
return P(*new_axes)
elif isinstance(specs, tuple):
return tuple(switch_specs(sub_spec) for sub_spec in specs)
else:
raise TypeError("Element must be either a PartitionSpec or a tuple")
def fft3d(x):
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
return jaxdecomp.pfft3d(x.astype(jnp.complex64))
@ -108,12 +111,13 @@ def slice_unpad_impl(x, pad_width):
unpad_slice = [slice(None)] * 3
if halo_x > 0:
unpad_slice[0] = slice(halo_x , -halo_x)
unpad_slice[0] = slice(halo_x, -halo_x)
if halo_y > 0:
unpad_slice[1] = slice(halo_y , -halo_y)
unpad_slice[1] = slice(halo_y, -halo_y)
return x[tuple(unpad_slice)]
def slice_pad(x, pad_width):
mesh = mesh_lib.thread_resources.env.physical_mesh
if distributed and not (mesh.empty) and (pad_width[0][0] > 0

View file

@ -1,3 +1,4 @@
from enum import Enum
from functools import partial
import jax.numpy as jnp
@ -7,7 +8,7 @@ from jax._src import mesh as mesh_lib
from jax.sharding import PartitionSpec as P
from jaxpm.distributed import autoshmap
from enum import Enum
class PencilType(Enum):
NO_DECOMP = 0
@ -15,6 +16,7 @@ class PencilType(Enum):
SLAB_YZ = 2
PENCILS = 3
def get_pencil_type():
mesh = mesh_lib.thread_resources.env.physical_mesh
if mesh.empty:
@ -31,6 +33,7 @@ def get_pencil_type():
else:
return PencilType.PENCILS
def fftk(shape, dtype=np.float32):
"""
Generate Fourier transform wave numbers for a given mesh.
@ -46,7 +49,8 @@ def fftk(shape, dtype=np.float32):
@partial(autoshmap,
in_specs=(P('x'), P('y'), P(None)),
out_specs=(P('x'), P(None, 'y'), P(None)),in_fourrier_space=True)
out_specs=(P('x'), P(None, 'y'), P(None)),
in_fourrier_space=True)
def get_kvec(ky, kz, kx):
return (ky.reshape([-1, 1, 1]),
kz.reshape([1, -1, 1]),
@ -73,7 +77,10 @@ def interpolate_power_spectrum(input, k, pk):
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
).reshape(x.shape)
return autoshmap(pk_fn, in_specs=P('x', 'y'), out_specs=P('x', 'y'),in_fourrier_space=True)(input)
return autoshmap(pk_fn,
in_specs=P('x', 'y'),
out_specs=P('x', 'y'),
in_fourrier_space=True)(input)
def gradient_kernel(kvec, direction, order=1):

View file

@ -173,12 +173,14 @@ def cic_paint_dx(displacements, halo_size=0):
return mesh
def cic_read_dx_impl(mesh , halo_size):
def cic_read_dx_impl(mesh, halo_size):
halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1]
original_shape = [dim - 2 * halo[0] for dim , halo in zip(mesh.shape, halo_size)]
original_shape = [
dim - 2 * halo[0] for dim, halo in zip(mesh.shape, halo_size)
]
a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]),
jnp.arange(original_shape[1]),
jnp.arange(original_shape[2]),
@ -199,7 +201,7 @@ def cic_read_dx(mesh, halo_size=0):
mesh = halo_exchange(mesh,
halo_extents=halo_extents,
halo_periods=(True, True, True))
displacements = autoshmap(partial(cic_read_dx_impl , halo_size=halo_size),
displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
in_specs=(P('x', 'y')),
out_specs=P('x', 'y'))(mesh)

View file

@ -19,7 +19,8 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
Computes gravitational forces on particles using a PM scheme
"""
if mesh_shape is None:
assert(delta is not None) , "If mesh_shape is not provided, delta should be provided"
assert (delta is not None
), "If mesh_shape is not provided, delta should be provided"
mesh_shape = delta.shape
kvec = fftk(mesh_shape)
@ -33,8 +34,8 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
r_split=r_split)
# Computes gravitational forces
forces = jnp.stack([
cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k), halo_size=halo_size)
for i in range(3)
cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k),
halo_size=halo_size) for i in range(3)
],
axis=-1)

View file

@ -47,7 +47,6 @@ def run_simulation(omega_c, sigma8):
pk_fn,
seed=jax.random.PRNGKey(0))
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
# Initial displacement