This commit is contained in:
Wassim KABALAN 2024-08-03 00:23:40 +02:00
parent 831291c1f9
commit ece8c93540
12 changed files with 210 additions and 170 deletions

View file

@ -10,13 +10,14 @@ size = jax.process_count()
import argparse import argparse
import time import time
from hpc_plotter.timer import Timer
import jax.numpy as jnp import jax.numpy as jnp
import jax_cosmo as jc import jax_cosmo as jc
import numpy as np import numpy as np
from cupy.cuda.nvtx import RangePop, RangePush from cupy.cuda.nvtx import RangePop, RangePush
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm, from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
PIDController, SaveAt, Tsit5, diffeqsolve) PIDController, SaveAt, Tsit5, diffeqsolve)
from hpc_plotter.timer import Timer
from jax.experimental import mesh_utils from jax.experimental import mesh_utils
from jax.experimental.multihost_utils import sync_global_devices from jax.experimental.multihost_utils import sync_global_devices
from jax.sharding import Mesh, NamedSharding from jax.sharding import Mesh, NamedSharding
@ -27,7 +28,6 @@ from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn from jaxpm.pm import linear_field, lpt, make_ode_fn
def run_simulation(mesh_shape, def run_simulation(mesh_shape,
box_size, box_size,
halo_size, halo_size,
@ -94,12 +94,18 @@ def run_simulation(mesh_shape,
# Warm start # Warm start
chrono_fun = Timer() chrono_fun = Timer()
RangePush("warmup") RangePush("warmup")
final_field, stats = chrono_fun.chrono_jit(simulate, 0.32, 0.8 , ndarray_arg = 0) final_field, stats = chrono_fun.chrono_jit(simulate,
0.32,
0.8,
ndarray_arg=0)
RangePop() RangePop()
sync_global_devices("warmup") sync_global_devices("warmup")
for i in range(iterations): for i in range(iterations):
RangePush(f"sim iter {i}") RangePush(f"sim iter {i}")
final_field, stats = chrono_fun.chrono_fun(simulate, 0.32, 0.8 , ndarray_arg = 0) final_field, stats = chrono_fun.chrono_fun(simulate,
0.32,
0.8,
ndarray_arg=0)
RangePop() RangePop()
return final_field, stats, chrono_fun return final_field, stats, chrono_fun
@ -134,11 +140,13 @@ if __name__ == "__main__":
type=str, type=str,
help='Processor dimensions', help='Processor dimensions',
default=None) default=None)
parser.add_argument('-pr', parser.add_argument(
'-pr',
'--precision', '--precision',
type=str, type=str,
help='Precision', help='Precision',
choices=["float32", "float64"],) choices=["float32", "float64"],
)
parser.add_argument('-hs', parser.add_argument('-hs',
'--halo_size', '--halo_size',
type=int, type=int,
@ -185,11 +193,11 @@ if __name__ == "__main__":
print(f"solver choice: {solver_choice}") print(f"solver choice: {solver_choice}")
match solver_choice: match solver_choice:
case "Dopri5" | "dopri5"| "d5": case "Dopri5" | "dopri5" | "d5":
solver_choice = "Dopri5" solver_choice = "Dopri5"
case "Tsit5"| "tsit5"| "t5": case "Tsit5" | "tsit5" | "t5":
solver_choice = "Tsit5" solver_choice = "Tsit5"
case "LeapfrogMidpoint"| "leapfrogmidpoint"| "lfm": case "LeapfrogMidpoint" | "leapfrogmidpoint" | "lfm":
solver_choice = "LeapfrogMidpoint" solver_choice = "LeapfrogMidpoint"
case "lpt": case "lpt":
solver_choice = "lpt" solver_choice = "lpt"
@ -209,9 +217,13 @@ if __name__ == "__main__":
mesh_shape = [mesh_size] * 3 mesh_shape = [mesh_size] * 3
final_field , stats, chrono_fun = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, pdims) final_field, stats, chrono_fun = run_simulation(mesh_shape, box_size,
halo_size, solver_choice,
iterations, pdims)
print(f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}") print(
f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}"
)
metadata = { metadata = {
'rank': rank, 'rank': rank,
@ -236,7 +248,7 @@ if __name__ == "__main__":
with open(f'{field_folder}/jaxpm.log', 'w') as f: with open(f'{field_folder}/jaxpm.log', 'w') as f:
f.write(f"Args: {args}\n") f.write(f"Args: {args}\n")
f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n") f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
for i , time in enumerate(chrono_fun.times): for i, time in enumerate(chrono_fun.times):
f.write(f"Time {i}: {time:.4f} ms\n") f.write(f"Time {i}: {time:.4f} ms\n")
f.write(f"Stats: {stats}\n") f.write(f"Stats: {stats}\n")
if args.save_fields: if args.save_fields:

View file

@ -3,34 +3,41 @@ import os
# Change JAX GPU memory preallocation fraction # Change JAX GPU memory preallocation fraction
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95' os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95'
import jax
import argparse import argparse
import numpy as np
import jax
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from pmwd import ( import numpy as np
Configuration, from hpc_plotter.timer import Timer
Cosmology, SimpleLCDM, from pmwd import (Configuration, Cosmology, SimpleLCDM, boltzmann, growth,
boltzmann, linear_power, growth, linear_modes, linear_power, lpt, nbody, scatter, white_noise)
white_noise, linear_modes,
lpt, nbody, scatter
)
from pmwd.pm_util import fftinv from pmwd.pm_util import fftinv
from pmwd.spec_util import powspec from pmwd.spec_util import powspec
from pmwd.vis_util import simshow from pmwd.vis_util import simshow
from hpc_plotter.timer import Timer
# Simulation configuration # Simulation configuration
def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver , iterations): def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver, iterations):
@jax.jit @jax.jit
def simulate(omega_m, sigma8): def simulate(omega_m, sigma8):
conf = Configuration(ptcl_spacing,
conf = Configuration(ptcl_spacing, ptcl_grid_shape=ptcl_grid_shape, mesh_shape=1,lpt_order=1,a_nbody_maxstep=1/91) ptcl_grid_shape=ptcl_grid_shape,
mesh_shape=1,
lpt_order=1,
a_nbody_maxstep=1 / 91)
print(conf) print(conf)
print(f'Simulating {conf.ptcl_num} particles with a {conf.mesh_shape} mesh for {conf.a_nbody_num} time steps.') print(
f'Simulating {conf.ptcl_num} particles with a {conf.mesh_shape} mesh for {conf.a_nbody_num} time steps.'
)
cosmo = Cosmology(conf, A_s_1e9=2.0, n_s=0.96, Omega_m=omega_m, Omega_b=sigma8, h=0.7) cosmo = Cosmology(conf,
A_s_1e9=2.0,
n_s=0.96,
Omega_m=omega_m,
Omega_b=sigma8,
h=0.7)
print(cosmo) print(cosmo)
# Boltzmann calculation # Boltzmann calculation
@ -49,7 +56,8 @@ def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver , iterations):
if solver == "lfm": if solver == "lfm":
# N-body time integration from LPT initial conditions # N-body time integration from LPT initial conditions
ptcl, obsvbl = jax.block_until_ready(nbody(ptcl, obsvbl, cosmo, conf)) ptcl, obsvbl = jax.block_until_ready(
nbody(ptcl, obsvbl, cosmo, conf))
print("N-body time integration completed.") print("N-body time integration completed.")
# Scatter particles to mesh to get the density field # Scatter particles to mesh to get the density field
@ -62,28 +70,52 @@ def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver , iterations):
for _ in range(iterations): for _ in range(iterations):
final_field = chrono_timer.chrono_fun(simulate, 0.3, 0.05) final_field = chrono_timer.chrono_fun(simulate, 0.3, 0.05)
return final_field , chrono_timer return final_field, chrono_timer
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='PMWD Simulation') parser = argparse.ArgumentParser(description='PMWD Simulation')
parser.add_argument('-m', '--mesh_size', type=int, help='Mesh size', required=True) parser.add_argument('-m',
parser.add_argument('-b', '--box_size', type=float, help='Box size', required=True) '--mesh_size',
parser.add_argument('-i', '--iterations', type=int, help='Number of iterations', default=10) type=int,
parser.add_argument('-o', '--output_path', type=str, help='Output path', default=".") help='Mesh size',
parser.add_argument('-f', '--save_fields', action='store_true', help='Save fields') required=True)
parser.add_argument('-s', '--solver', type=str, help='Solver', choices=["lfm" , "lpt"]) parser.add_argument('-b',
parser.add_argument('-pr', '--box_size',
type=float,
help='Box size',
required=True)
parser.add_argument('-i',
'--iterations',
type=int,
help='Number of iterations',
default=10)
parser.add_argument('-o',
'--output_path',
type=str,
help='Output path',
default=".")
parser.add_argument('-f',
'--save_fields',
action='store_true',
help='Save fields')
parser.add_argument('-s',
'--solver',
type=str,
help='Solver',
choices=["lfm", "lpt"])
parser.add_argument(
'-pr',
'--precision', '--precision',
type=str, type=str,
help='Precision', help='Precision',
choices=["float32", "float64"],) choices=["float32", "float64"],
)
args = parser.parse_args() args = parser.parse_args()
mesh_shape = [args.mesh_size] * 3 mesh_shape = [args.mesh_size] * 3
ptcl_spacing = args.box_size /args.mesh_size ptcl_spacing = args.box_size / args.mesh_size
iterations = args.iterations iterations = args.iterations
solver = args.solver solver = args.solver
output_path = args.output_path output_path = args.output_path
@ -92,13 +124,12 @@ if __name__ == "__main__":
elif args.precision == "float64": elif args.precision == "float64":
jax.config.update("jax_enable_x64", True) jax.config.update("jax_enable_x64", True)
os.makedirs(output_path, exist_ok=True) os.makedirs(output_path, exist_ok=True)
final_field , chrono_fun = run_pmwd_simulation(mesh_shape, ptcl_spacing, solver, iterations) final_field, chrono_fun = run_pmwd_simulation(mesh_shape, ptcl_spacing,
solver, iterations)
print("PMWD simulation completed.") print("PMWD simulation completed.")
metadata = { metadata = {
'rank': 0, 'rank': 0,
'function_name': f'PMWD-{solver}', 'function_name': f'PMWD-{solver}',
@ -118,14 +149,11 @@ if __name__ == "__main__":
f.write(f"PMWD simulation completed.\n") f.write(f"PMWD simulation completed.\n")
f.write(f"Args : {args}\n") f.write(f"Args : {args}\n")
f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n") f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
for i , time in enumerate(chrono_fun.times): for i, time in enumerate(chrono_fun.times):
f.write(f"Time {i}: {time:.4f} ms\n") f.write(f"Time {i}: {time:.4f} ms\n")
if args.save_fields: if args.save_fields:
np.save(f"{field_folder}/final_field_0_0.npy", final_field) np.save(f"{field_folder}/final_field_0_0.npy", final_field)
print("Fields saved.") print("Fields saved.")
print(f"saving to {output_path}/pmwd.csv") print(f"saving to {output_path}/pmwd.csv")
print(f"saving field and logs to {field_folder}/pmwd.log") print(f"saving field and logs to {field_folder}/pmwd.log")

View file

@ -177,7 +177,3 @@ for pr in "${precisions[@]}"; do
done done
done done
done done

View file

@ -179,6 +179,3 @@ for pr in "${precisions[@]}"; do
done done
done done
done done

View file

@ -160,6 +160,3 @@ for pr in "${precisions[@]}"; do
done done
done done
done done

View file

@ -165,6 +165,3 @@ for pr in "${precisions[@]}"; do
done done
done done
done done

View file

@ -44,18 +44,21 @@ def autoshmap(f: Callable,
return f return f
else: else:
if in_fourrier_space and 1 in mesh.devices.shape: if in_fourrier_space and 1 in mesh.devices.shape:
in_specs , out_specs = switch_specs((in_specs , out_specs)) in_specs, out_specs = switch_specs((in_specs, out_specs))
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto) return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
def switch_specs(specs): def switch_specs(specs):
if isinstance(specs, P): if isinstance(specs, P):
new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax for ax in specs) new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax
for ax in specs)
return P(*new_axes) return P(*new_axes)
elif isinstance(specs, tuple): elif isinstance(specs, tuple):
return tuple(switch_specs(sub_spec) for sub_spec in specs) return tuple(switch_specs(sub_spec) for sub_spec in specs)
else: else:
raise TypeError("Element must be either a PartitionSpec or a tuple") raise TypeError("Element must be either a PartitionSpec or a tuple")
def fft3d(x): def fft3d(x):
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
return jaxdecomp.pfft3d(x.astype(jnp.complex64)) return jaxdecomp.pfft3d(x.astype(jnp.complex64))
@ -108,12 +111,13 @@ def slice_unpad_impl(x, pad_width):
unpad_slice = [slice(None)] * 3 unpad_slice = [slice(None)] * 3
if halo_x > 0: if halo_x > 0:
unpad_slice[0] = slice(halo_x , -halo_x) unpad_slice[0] = slice(halo_x, -halo_x)
if halo_y > 0: if halo_y > 0:
unpad_slice[1] = slice(halo_y , -halo_y) unpad_slice[1] = slice(halo_y, -halo_y)
return x[tuple(unpad_slice)] return x[tuple(unpad_slice)]
def slice_pad(x, pad_width): def slice_pad(x, pad_width):
mesh = mesh_lib.thread_resources.env.physical_mesh mesh = mesh_lib.thread_resources.env.physical_mesh
if distributed and not (mesh.empty) and (pad_width[0][0] > 0 if distributed and not (mesh.empty) and (pad_width[0][0] > 0

View file

@ -1,3 +1,4 @@
from enum import Enum
from functools import partial from functools import partial
import jax.numpy as jnp import jax.numpy as jnp
@ -7,7 +8,7 @@ from jax._src import mesh as mesh_lib
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
from jaxpm.distributed import autoshmap from jaxpm.distributed import autoshmap
from enum import Enum
class PencilType(Enum): class PencilType(Enum):
NO_DECOMP = 0 NO_DECOMP = 0
@ -15,6 +16,7 @@ class PencilType(Enum):
SLAB_YZ = 2 SLAB_YZ = 2
PENCILS = 3 PENCILS = 3
def get_pencil_type(): def get_pencil_type():
mesh = mesh_lib.thread_resources.env.physical_mesh mesh = mesh_lib.thread_resources.env.physical_mesh
if mesh.empty: if mesh.empty:
@ -31,6 +33,7 @@ def get_pencil_type():
else: else:
return PencilType.PENCILS return PencilType.PENCILS
def fftk(shape, dtype=np.float32): def fftk(shape, dtype=np.float32):
""" """
Generate Fourier transform wave numbers for a given mesh. Generate Fourier transform wave numbers for a given mesh.
@ -46,7 +49,8 @@ def fftk(shape, dtype=np.float32):
@partial(autoshmap, @partial(autoshmap,
in_specs=(P('x'), P('y'), P(None)), in_specs=(P('x'), P('y'), P(None)),
out_specs=(P('x'), P(None, 'y'), P(None)),in_fourrier_space=True) out_specs=(P('x'), P(None, 'y'), P(None)),
in_fourrier_space=True)
def get_kvec(ky, kz, kx): def get_kvec(ky, kz, kx):
return (ky.reshape([-1, 1, 1]), return (ky.reshape([-1, 1, 1]),
kz.reshape([1, -1, 1]), kz.reshape([1, -1, 1]),
@ -73,7 +77,10 @@ def interpolate_power_spectrum(input, k, pk):
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
).reshape(x.shape) ).reshape(x.shape)
return autoshmap(pk_fn, in_specs=P('x', 'y'), out_specs=P('x', 'y'),in_fourrier_space=True)(input) return autoshmap(pk_fn,
in_specs=P('x', 'y'),
out_specs=P('x', 'y'),
in_fourrier_space=True)(input)
def gradient_kernel(kvec, direction, order=1): def gradient_kernel(kvec, direction, order=1):

View file

@ -173,12 +173,14 @@ def cic_paint_dx(displacements, halo_size=0):
return mesh return mesh
def cic_read_dx_impl(mesh , halo_size): def cic_read_dx_impl(mesh, halo_size):
halo_x, _ = halo_size[0] halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1] halo_y, _ = halo_size[1]
original_shape = [dim - 2 * halo[0] for dim , halo in zip(mesh.shape, halo_size)] original_shape = [
dim - 2 * halo[0] for dim, halo in zip(mesh.shape, halo_size)
]
a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]), a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]),
jnp.arange(original_shape[1]), jnp.arange(original_shape[1]),
jnp.arange(original_shape[2]), jnp.arange(original_shape[2]),
@ -199,7 +201,7 @@ def cic_read_dx(mesh, halo_size=0):
mesh = halo_exchange(mesh, mesh = halo_exchange(mesh,
halo_extents=halo_extents, halo_extents=halo_extents,
halo_periods=(True, True, True)) halo_periods=(True, True, True))
displacements = autoshmap(partial(cic_read_dx_impl , halo_size=halo_size), displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
in_specs=(P('x', 'y')), in_specs=(P('x', 'y')),
out_specs=P('x', 'y'))(mesh) out_specs=P('x', 'y'))(mesh)

View file

@ -19,7 +19,8 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
Computes gravitational forces on particles using a PM scheme Computes gravitational forces on particles using a PM scheme
""" """
if mesh_shape is None: if mesh_shape is None:
assert(delta is not None) , "If mesh_shape is not provided, delta should be provided" assert (delta is not None
), "If mesh_shape is not provided, delta should be provided"
mesh_shape = delta.shape mesh_shape = delta.shape
kvec = fftk(mesh_shape) kvec = fftk(mesh_shape)
@ -33,8 +34,8 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
r_split=r_split) r_split=r_split)
# Computes gravitational forces # Computes gravitational forces
forces = jnp.stack([ forces = jnp.stack([
cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k), halo_size=halo_size) cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k),
for i in range(3) halo_size=halo_size) for i in range(3)
], ],
axis=-1) axis=-1)

View file

@ -47,7 +47,6 @@ def run_simulation(omega_c, sigma8):
pk_fn, pk_fn,
seed=jax.random.PRNGKey(0)) seed=jax.random.PRNGKey(0))
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8) cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
# Initial displacement # Initial displacement