This commit is contained in:
Wassim KABALAN 2024-08-03 00:23:40 +02:00
parent 831291c1f9
commit ece8c93540
12 changed files with 210 additions and 170 deletions

View file

@ -10,13 +10,14 @@ size = jax.process_count()
import argparse import argparse
import time import time
from hpc_plotter.timer import Timer
import jax.numpy as jnp import jax.numpy as jnp
import jax_cosmo as jc import jax_cosmo as jc
import numpy as np import numpy as np
from cupy.cuda.nvtx import RangePop, RangePush from cupy.cuda.nvtx import RangePop, RangePush
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm, from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
PIDController, SaveAt, Tsit5, diffeqsolve) PIDController, SaveAt, Tsit5, diffeqsolve)
from hpc_plotter.timer import Timer
from jax.experimental import mesh_utils from jax.experimental import mesh_utils
from jax.experimental.multihost_utils import sync_global_devices from jax.experimental.multihost_utils import sync_global_devices
from jax.sharding import Mesh, NamedSharding from jax.sharding import Mesh, NamedSharding
@ -27,7 +28,6 @@ from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn from jaxpm.pm import linear_field, lpt, make_ode_fn
def run_simulation(mesh_shape, def run_simulation(mesh_shape,
box_size, box_size,
halo_size, halo_size,
@ -69,7 +69,7 @@ def run_simulation(mesh_shape,
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size) ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
term = ODETerm( term = ODETerm(
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0)) lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
if solver_choice == "Dopri5" or solver_choice == "Tsit5": if solver_choice == "Dopri5" or solver_choice == "Tsit5":
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4) stepsize_controller = PIDController(rtol=1e-4, atol=1e-4)
elif solver_choice == "LeapfrogMidpoint" or solver_choice == "Euler": elif solver_choice == "LeapfrogMidpoint" or solver_choice == "Euler":
@ -94,12 +94,18 @@ def run_simulation(mesh_shape,
# Warm start # Warm start
chrono_fun = Timer() chrono_fun = Timer()
RangePush("warmup") RangePush("warmup")
final_field, stats = chrono_fun.chrono_jit(simulate, 0.32, 0.8 , ndarray_arg = 0) final_field, stats = chrono_fun.chrono_jit(simulate,
0.32,
0.8,
ndarray_arg=0)
RangePop() RangePop()
sync_global_devices("warmup") sync_global_devices("warmup")
for i in range(iterations): for i in range(iterations):
RangePush(f"sim iter {i}") RangePush(f"sim iter {i}")
final_field, stats = chrono_fun.chrono_fun(simulate, 0.32, 0.8 , ndarray_arg = 0) final_field, stats = chrono_fun.chrono_fun(simulate,
0.32,
0.8,
ndarray_arg=0)
RangePop() RangePop()
return final_field, stats, chrono_fun return final_field, stats, chrono_fun
@ -134,11 +140,13 @@ if __name__ == "__main__":
type=str, type=str,
help='Processor dimensions', help='Processor dimensions',
default=None) default=None)
parser.add_argument('-pr', parser.add_argument(
'--precision', '-pr',
type=str, '--precision',
help='Precision', type=str,
choices=["float32", "float64"],) help='Precision',
choices=["float32", "float64"],
)
parser.add_argument('-hs', parser.add_argument('-hs',
'--halo_size', '--halo_size',
type=int, type=int,
@ -173,7 +181,7 @@ if __name__ == "__main__":
type=int, type=int,
help='Number of nodes', help='Number of nodes',
default=1) default=1)
args = parser.parse_args() args = parser.parse_args()
mesh_size = args.mesh_size mesh_size = args.mesh_size
box_size = [args.box_size] * 3 box_size = [args.box_size] * 3
@ -182,14 +190,14 @@ if __name__ == "__main__":
iterations = args.iterations iterations = args.iterations
output_path = args.output_path output_path = args.output_path
os.makedirs(output_path, exist_ok=True) os.makedirs(output_path, exist_ok=True)
print(f"solver choice: {solver_choice}") print(f"solver choice: {solver_choice}")
match solver_choice: match solver_choice:
case "Dopri5" | "dopri5"| "d5": case "Dopri5" | "dopri5" | "d5":
solver_choice = "Dopri5" solver_choice = "Dopri5"
case "Tsit5"| "tsit5"| "t5": case "Tsit5" | "tsit5" | "t5":
solver_choice = "Tsit5" solver_choice = "Tsit5"
case "LeapfrogMidpoint"| "leapfrogmidpoint"| "lfm": case "LeapfrogMidpoint" | "leapfrogmidpoint" | "lfm":
solver_choice = "LeapfrogMidpoint" solver_choice = "LeapfrogMidpoint"
case "lpt": case "lpt":
solver_choice = "lpt" solver_choice = "lpt"
@ -199,7 +207,7 @@ if __name__ == "__main__":
) )
if args.precision == "float32": if args.precision == "float32":
jax.config.update("jax_enable_x64", False) jax.config.update("jax_enable_x64", False)
elif args.precision == "float64": elif args.precision == "float64":
jax.config.update("jax_enable_x64", True) jax.config.update("jax_enable_x64", True)
if args.pdims: if args.pdims:
@ -209,22 +217,26 @@ if __name__ == "__main__":
mesh_shape = [mesh_size] * 3 mesh_shape = [mesh_size] * 3
final_field , stats, chrono_fun = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, pdims) final_field, stats, chrono_fun = run_simulation(mesh_shape, box_size,
halo_size, solver_choice,
print(f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}") iterations, pdims)
print(
f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}"
)
metadata = { metadata = {
'rank': rank, 'rank': rank,
'function_name': f'JAXPM-{solver_choice}', 'function_name': f'JAXPM-{solver_choice}',
'precision': args.precision, 'precision': args.precision,
'x': str(mesh_size), 'x': str(mesh_size),
'y': str(mesh_size), 'y': str(mesh_size),
'z': str(stats["num_steps"]), 'z': str(stats["num_steps"]),
'px': str(pdims[0]), 'px': str(pdims[0]),
'py': str(pdims[1]), 'py': str(pdims[1]),
'backend': 'NCCL', 'backend': 'NCCL',
'nodes': str(args.nodes) 'nodes': str(args.nodes)
} }
# Print the results to a CSV file # Print the results to a CSV file
chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata) chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata)
@ -236,8 +248,8 @@ if __name__ == "__main__":
with open(f'{field_folder}/jaxpm.log', 'w') as f: with open(f'{field_folder}/jaxpm.log', 'w') as f:
f.write(f"Args: {args}\n") f.write(f"Args: {args}\n")
f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n") f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
for i , time in enumerate(chrono_fun.times): for i, time in enumerate(chrono_fun.times):
f.write(f"Time {i}: {time:.4f} ms\n") f.write(f"Time {i}: {time:.4f} ms\n")
f.write(f"Stats: {stats}\n") f.write(f"Stats: {stats}\n")
if args.save_fields: if args.save_fields:
np.save(f'{field_folder}/final_field_0_{rank}.npy', np.save(f'{field_folder}/final_field_0_{rank}.npy',

View file

@ -3,34 +3,41 @@ import os
# Change JAX GPU memory preallocation fraction # Change JAX GPU memory preallocation fraction
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95' os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95'
import jax
import argparse import argparse
import numpy as np
import jax
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from pmwd import ( import numpy as np
Configuration, from hpc_plotter.timer import Timer
Cosmology, SimpleLCDM, from pmwd import (Configuration, Cosmology, SimpleLCDM, boltzmann, growth,
boltzmann, linear_power, growth, linear_modes, linear_power, lpt, nbody, scatter, white_noise)
white_noise, linear_modes,
lpt, nbody, scatter
)
from pmwd.pm_util import fftinv from pmwd.pm_util import fftinv
from pmwd.spec_util import powspec from pmwd.spec_util import powspec
from pmwd.vis_util import simshow from pmwd.vis_util import simshow
from hpc_plotter.timer import Timer
# Simulation configuration # Simulation configuration
def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver , iterations): def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver, iterations):
@jax.jit @jax.jit
def simulate(omega_m, sigma8): def simulate(omega_m, sigma8):
conf = Configuration(ptcl_spacing, ptcl_grid_shape=ptcl_grid_shape, mesh_shape=1,lpt_order=1,a_nbody_maxstep=1/91)
print(conf)
print(f'Simulating {conf.ptcl_num} particles with a {conf.mesh_shape} mesh for {conf.a_nbody_num} time steps.')
cosmo = Cosmology(conf, A_s_1e9=2.0, n_s=0.96, Omega_m=omega_m, Omega_b=sigma8, h=0.7) conf = Configuration(ptcl_spacing,
ptcl_grid_shape=ptcl_grid_shape,
mesh_shape=1,
lpt_order=1,
a_nbody_maxstep=1 / 91)
print(conf)
print(
f'Simulating {conf.ptcl_num} particles with a {conf.mesh_shape} mesh for {conf.a_nbody_num} time steps.'
)
cosmo = Cosmology(conf,
A_s_1e9=2.0,
n_s=0.96,
Omega_m=omega_m,
Omega_b=sigma8,
h=0.7)
print(cosmo) print(cosmo)
# Boltzmann calculation # Boltzmann calculation
@ -46,71 +53,95 @@ def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver , iterations):
# Solve LPT at some early time # Solve LPT at some early time
ptcl, obsvbl = lpt(modes, cosmo, conf) ptcl, obsvbl = lpt(modes, cosmo, conf)
print("LPT solved.") print("LPT solved.")
if solver == "lfm": if solver == "lfm":
# N-body time integration from LPT initial conditions # N-body time integration from LPT initial conditions
ptcl, obsvbl = jax.block_until_ready(nbody(ptcl, obsvbl, cosmo, conf)) ptcl, obsvbl = jax.block_until_ready(
print("N-body time integration completed.") nbody(ptcl, obsvbl, cosmo, conf))
print("N-body time integration completed.")
# Scatter particles to mesh to get the density field # Scatter particles to mesh to get the density field
dens = scatter(ptcl, conf) dens = scatter(ptcl, conf)
return dens return dens
chrono_timer = Timer() chrono_timer = Timer()
final_field = chrono_timer.chrono_jit(simulate, 0.3, 0.05) final_field = chrono_timer.chrono_jit(simulate, 0.3, 0.05)
for _ in range(iterations): for _ in range(iterations):
final_field = chrono_timer.chrono_fun(simulate, 0.3, 0.05) final_field = chrono_timer.chrono_fun(simulate, 0.3, 0.05)
return final_field , chrono_timer return final_field, chrono_timer
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='PMWD Simulation') parser = argparse.ArgumentParser(description='PMWD Simulation')
parser.add_argument('-m', '--mesh_size', type=int, help='Mesh size', required=True) parser.add_argument('-m',
parser.add_argument('-b', '--box_size', type=float, help='Box size', required=True) '--mesh_size',
parser.add_argument('-i', '--iterations', type=int, help='Number of iterations', default=10) type=int,
parser.add_argument('-o', '--output_path', type=str, help='Output path', default=".") help='Mesh size',
parser.add_argument('-f', '--save_fields', action='store_true', help='Save fields') required=True)
parser.add_argument('-s', '--solver', type=str, help='Solver', choices=["lfm" , "lpt"]) parser.add_argument('-b',
parser.add_argument('-pr', '--box_size',
'--precision', type=float,
type=str, help='Box size',
help='Precision', required=True)
choices=["float32", "float64"],) parser.add_argument('-i',
'--iterations',
type=int,
help='Number of iterations',
default=10)
parser.add_argument('-o',
'--output_path',
type=str,
help='Output path',
default=".")
parser.add_argument('-f',
'--save_fields',
action='store_true',
help='Save fields')
parser.add_argument('-s',
'--solver',
type=str,
help='Solver',
choices=["lfm", "lpt"])
parser.add_argument(
'-pr',
'--precision',
type=str,
help='Precision',
choices=["float32", "float64"],
)
args = parser.parse_args() args = parser.parse_args()
mesh_shape = [args.mesh_size] * 3 mesh_shape = [args.mesh_size] * 3
ptcl_spacing = args.box_size /args.mesh_size ptcl_spacing = args.box_size / args.mesh_size
iterations = args.iterations iterations = args.iterations
solver = args.solver solver = args.solver
output_path = args.output_path output_path = args.output_path
if args.precision == "float32": if args.precision == "float32":
jax.config.update("jax_enable_x64", False) jax.config.update("jax_enable_x64", False)
elif args.precision == "float64": elif args.precision == "float64":
jax.config.update("jax_enable_x64", True) jax.config.update("jax_enable_x64", True)
os.makedirs(output_path, exist_ok=True) os.makedirs(output_path, exist_ok=True)
final_field , chrono_fun = run_pmwd_simulation(mesh_shape, ptcl_spacing, solver, iterations) final_field, chrono_fun = run_pmwd_simulation(mesh_shape, ptcl_spacing,
solver, iterations)
print("PMWD simulation completed.") print("PMWD simulation completed.")
metadata = { metadata = {
'rank': 0, 'rank': 0,
'function_name': f'PMWD-{solver}', 'function_name': f'PMWD-{solver}',
'precision': args.precision, 'precision': args.precision,
'x': str(mesh_shape[0]), 'x': str(mesh_shape[0]),
'y': str(mesh_shape[0]), 'y': str(mesh_shape[0]),
'z': str(mesh_shape[0]), 'z': str(mesh_shape[0]),
'px': "1", 'px': "1",
'py': "1", 'py': "1",
'backend': 'NCCL', 'backend': 'NCCL',
'nodes': "1" 'nodes': "1"
} }
chrono_fun.print_to_csv(f"{output_path}/pmwd.csv", **metadata) chrono_fun.print_to_csv(f"{output_path}/pmwd.csv", **metadata)
field_folder = f"{output_path}/final_field/pmwd/1/{args.mesh_size}_{int(args.box_size)}/1x1/{args.solver}/halo_0" field_folder = f"{output_path}/final_field/pmwd/1/{args.mesh_size}_{int(args.box_size)}/1x1/{args.solver}/halo_0"
os.makedirs(field_folder, exist_ok=True) os.makedirs(field_folder, exist_ok=True)
@ -118,14 +149,11 @@ if __name__ == "__main__":
f.write(f"PMWD simulation completed.\n") f.write(f"PMWD simulation completed.\n")
f.write(f"Args : {args}\n") f.write(f"Args : {args}\n")
f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n") f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
for i , time in enumerate(chrono_fun.times): for i, time in enumerate(chrono_fun.times):
f.write(f"Time {i}: {time:.4f} ms\n") f.write(f"Time {i}: {time:.4f} ms\n")
if args.save_fields: if args.save_fields:
np.save(f"{field_folder}/final_field_0_0.npy", final_field) np.save(f"{field_folder}/final_field_0_0.npy", final_field)
print("Fields saved.") print("Fields saved.")
print(f"saving to {output_path}/pmwd.csv") print(f"saving to {output_path}/pmwd.csv")
print(f"saving field and logs to {field_folder}/pmwd.log") print(f"saving field and logs to {field_folder}/pmwd.log")

View file

@ -177,7 +177,3 @@ for pr in "${precisions[@]}"; do
done done
done done
done done

View file

@ -179,6 +179,3 @@ for pr in "${precisions[@]}"; do
done done
done done
done done

View file

@ -156,10 +156,7 @@ echo "Output dir is : $out_dir"
for pr in "${precisions[@]}"; do for pr in "${precisions[@]}"; do
for g in "${grid[@]}"; do for g in "${grid[@]}"; do
for solver in "${solvers[@]}"; do for solver in "${solvers[@]}"; do
launch bench_pmwd.py -m $g -b $g -p $p -pr $pr -s $solver -i 4 -o $out_dir -f launch bench_pmwd.py -m $g -b $g -p $p -pr $pr -s $solver -i 4 -o $out_dir -f
done done
done done
done done

View file

@ -161,10 +161,7 @@ echo "Output dir is : $out_dir"
for pr in "${precisions[@]}"; do for pr in "${precisions[@]}"; do
for g in "${grid[@]}"; do for g in "${grid[@]}"; do
for solver in "${solvers[@]}"; do for solver in "${solvers[@]}"; do
slaunch bench_pmwd.py -m $g -b $g -pr $pr -s $solver -i 4 -o $out_dir -f slaunch bench_pmwd.py -m $g -b $g -pr $pr -s $solver -i 4 -o $out_dir -f
done done
done done
done done

View file

@ -44,17 +44,20 @@ def autoshmap(f: Callable,
return f return f
else: else:
if in_fourrier_space and 1 in mesh.devices.shape: if in_fourrier_space and 1 in mesh.devices.shape:
in_specs , out_specs = switch_specs((in_specs , out_specs)) in_specs, out_specs = switch_specs((in_specs, out_specs))
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto) return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
def switch_specs(specs):
    """Swap the ``'x'`` and ``'y'`` axis names in a (nested) partition spec.

    Args:
        specs: Either a ``PartitionSpec`` or a tuple whose elements are
            themselves ``PartitionSpec`` objects or further tuples.

    Returns:
        An object with the same nesting as ``specs`` in which every
        ``PartitionSpec`` entry equal to ``'x'`` has become ``'y'`` and vice
        versa; all other entries (e.g. ``None``) are kept as they are.

    Raises:
        TypeError: If ``specs`` (or a nested element) is neither a
            ``PartitionSpec`` nor a tuple.
    """
    # The PartitionSpec test must come before the generic tuple test so that
    # a spec is never mistaken for a plain container of specs.
    if isinstance(specs, P):
        swapped = []
        for axis in specs:
            if axis == 'x':
                swapped.append('y')
            elif axis == 'y':
                swapped.append('x')
            else:
                swapped.append(axis)
        return P(*swapped)

    if isinstance(specs, tuple):
        # Recurse element-wise, preserving the tuple structure.
        return tuple(map(switch_specs, specs))

    raise TypeError("Element must be either a PartitionSpec or a tuple")
def fft3d(x): def fft3d(x):
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
@ -105,14 +108,15 @@ def slice_unpad_impl(x, pad_width):
# Apply corrections along y # Apply corrections along y
x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2]) x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:]) x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])
unpad_slice = [slice(None)] * 3 unpad_slice = [slice(None)] * 3
if halo_x > 0: if halo_x > 0:
unpad_slice[0] = slice(halo_x , -halo_x) unpad_slice[0] = slice(halo_x, -halo_x)
if halo_y > 0: if halo_y > 0:
unpad_slice[1] = slice(halo_y , -halo_y) unpad_slice[1] = slice(halo_y, -halo_y)
return x[tuple(unpad_slice)] return x[tuple(unpad_slice)]
def slice_pad(x, pad_width): def slice_pad(x, pad_width):
mesh = mesh_lib.thread_resources.env.physical_mesh mesh = mesh_lib.thread_resources.env.physical_mesh

View file

@ -1,3 +1,4 @@
from enum import Enum
from functools import partial from functools import partial
import jax.numpy as jnp import jax.numpy as jnp
@ -7,29 +8,31 @@ from jax._src import mesh as mesh_lib
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
from jaxpm.distributed import autoshmap from jaxpm.distributed import autoshmap
from enum import Enum
class PencilType(Enum):
    """Kinds of domain decomposition reported by ``get_pencil_type``.

    The comments below restate how ``get_pencil_type`` selects each member,
    where ``pdims`` is the device-mesh shape in reversed order.
    """
    # No active mesh, or pdims == (1, 1).
    NO_DECOMP = 0
    # pdims[0] == 1.
    SLAB_XY = 1
    # pdims[1] == 1.
    SLAB_YZ = 2
    # Neither entry of pdims is 1.
    PENCILS = 3
def get_pencil_type():
    """Classify the decomposition implied by the currently active mesh.

    The physical mesh is read from ``mesh_lib.thread_resources.env`` and its
    device-grid shape is inspected in reversed order (``pdims``).

    Returns:
        PencilType: ``NO_DECOMP`` when no mesh is active or
        ``pdims == (1, 1)``; ``SLAB_XY`` when ``pdims[0] == 1``; ``SLAB_YZ``
        when ``pdims[1] == 1``; ``PENCILS`` otherwise.
    """
    mesh = mesh_lib.thread_resources.env.physical_mesh
    pdims = None if mesh.empty else mesh.devices.shape[::-1]

    # Test the None sentinel by identity (PEP 8) and before any indexing.
    if pdims is None or pdims == (1, 1):
        return PencilType.NO_DECOMP
    # NOTE(review): pdims[1] is indexed unguarded below, so a mesh with a
    # single axis would raise IndexError -- assumes a 2-D mesh; confirm.
    if pdims[0] == 1:
        return PencilType.SLAB_XY
    if pdims[1] == 1:
        return PencilType.SLAB_YZ
    return PencilType.PENCILS
def fftk(shape, dtype=np.float32): def fftk(shape, dtype=np.float32):
""" """
@ -46,22 +49,23 @@ def fftk(shape, dtype=np.float32):
@partial(autoshmap, @partial(autoshmap,
in_specs=(P('x'), P('y'), P(None)), in_specs=(P('x'), P('y'), P(None)),
out_specs=(P('x'), P(None, 'y'), P(None)),in_fourrier_space=True) out_specs=(P('x'), P(None, 'y'), P(None)),
in_fourrier_space=True)
def get_kvec(ky, kz, kx): def get_kvec(ky, kz, kx):
return (ky.reshape([-1, 1, 1]), return (ky.reshape([-1, 1, 1]),
kz.reshape([1, -1, 1]), kz.reshape([1, -1, 1]),
kx.reshape([1, 1, -1])) # yapf: disable kx.reshape([1, 1, -1])) # yapf: disable
pencil_type = get_pencil_type() pencil_type = get_pencil_type()
# YZ returns Y pencil # YZ returns Y pencil
# XY and pencils returns a Z pencil # XY and pencils returns a Z pencil
# NO_DECOMP returns a X pencil # NO_DECOMP returns a X pencil
if pencil_type == PencilType.NO_DECOMP: if pencil_type == PencilType.NO_DECOMP:
kx, ky, kz = get_kvec(kx, ky, kz) # Z Y X ==> X pencil kx, ky, kz = get_kvec(kx, ky, kz) # Z Y X ==> X pencil
elif pencil_type == PencilType.SLAB_YZ: elif pencil_type == PencilType.SLAB_YZ:
kz, kx, ky = get_kvec(kz, kx, ky) # X Z Y ==> Y pencil kz, kx, ky = get_kvec(kz, kx, ky) # X Z Y ==> Y pencil
elif pencil_type == PencilType.SLAB_XY or pencil_type == PencilType.PENCILS: elif pencil_type == PencilType.SLAB_XY or pencil_type == PencilType.PENCILS:
ky, kz, kx = get_kvec(ky, kz, kx) # Z X Y ==> Z pencil ky, kz, kx = get_kvec(ky, kz, kx) # Z X Y ==> Z pencil
else: else:
raise ValueError("Unknown pencil type") raise ValueError("Unknown pencil type")
@ -73,7 +77,10 @@ def interpolate_power_spectrum(input, k, pk):
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
).reshape(x.shape) ).reshape(x.shape)
return autoshmap(pk_fn, in_specs=P('x', 'y'), out_specs=P('x', 'y'),in_fourrier_space=True)(input) return autoshmap(pk_fn,
in_specs=P('x', 'y'),
out_specs=P('x', 'y'),
in_fourrier_space=True)(input)
def gradient_kernel(kvec, direction, order=1): def gradient_kernel(kvec, direction, order=1):

View file

@ -150,7 +150,7 @@ def cic_paint_dx_impl(displacements, halo_size):
jnp.arange(particle_mesh.shape[1]), jnp.arange(particle_mesh.shape[1]),
jnp.arange(particle_mesh.shape[2]), jnp.arange(particle_mesh.shape[2]),
indexing='ij') indexing='ij')
particle_mesh = jnp.pad(particle_mesh, halo_size) particle_mesh = jnp.pad(particle_mesh, halo_size)
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1) pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
pmid = pmid.reshape([-1, 3]) pmid = pmid.reshape([-1, 3])
@ -159,13 +159,13 @@ def cic_paint_dx_impl(displacements, halo_size):
@partial(jax.jit, static_argnums=(1, )) @partial(jax.jit, static_argnums=(1, ))
def cic_paint_dx(displacements, halo_size=0): def cic_paint_dx(displacements, halo_size=0):
halo_size, halo_extents = get_halo_size(halo_size) halo_size, halo_extents = get_halo_size(halo_size)
mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size), mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size),
in_specs=(P('x', 'y')), in_specs=(P('x', 'y')),
out_specs=P('x', 'y'))(displacements) out_specs=P('x', 'y'))(displacements)
mesh = halo_exchange(mesh, mesh = halo_exchange(mesh,
halo_extents=halo_extents, halo_extents=halo_extents,
halo_periods=(True, True, True)) halo_periods=(True, True, True))
@ -173,19 +173,21 @@ def cic_paint_dx(displacements, halo_size=0):
return mesh return mesh
def cic_read_dx_impl(mesh, halo_size):
    """Read ``mesh`` back onto the grid obtained by stripping the halo.

    Args:
        mesh: Array whose ``shape`` is zipped with ``halo_size``.
            NOTE(review): presumably the halo-padded local 3-D mesh --
            confirm against the caller.
        halo_size: Sequence with one 2-element entry per mesh dimension; the
            first element of each entry is subtracted twice from the matching
            mesh dimension.

    Returns:
        The result of ``gather`` evaluated at the integer grid positions
        (shifted by the x/y halo widths) with zero displacement, reshaped to
        the halo-free shape.
    """
    x_halo, _ = halo_size[0]
    y_halo, _ = halo_size[1]

    # Shape of the mesh once the halo has been removed on both sides.
    original_shape = []
    for dim, halo in zip(mesh.shape, halo_size):
        original_shape.append(dim - 2 * halo[0])

    ix, iy, iz = jnp.meshgrid(jnp.arange(original_shape[0]),
                              jnp.arange(original_shape[1]),
                              jnp.arange(original_shape[2]),
                              indexing='ij')

    # Integer grid positions offset by the halo along the first two axes,
    # flattened to one (i, j, k) row per grid point.
    grid_ids = jnp.stack([ix + x_halo, iy + y_halo, iz], axis=-1)
    grid_ids = grid_ids.reshape([-1, 3])

    values = gather(grid_ids, jnp.zeros_like(grid_ids), mesh)
    return values.reshape(original_shape)
@ -199,7 +201,7 @@ def cic_read_dx(mesh, halo_size=0):
mesh = halo_exchange(mesh, mesh = halo_exchange(mesh,
halo_extents=halo_extents, halo_extents=halo_extents,
halo_periods=(True, True, True)) halo_periods=(True, True, True))
displacements = autoshmap(partial(cic_read_dx_impl , halo_size=halo_size), displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
in_specs=(P('x', 'y')), in_specs=(P('x', 'y')),
out_specs=P('x', 'y'))(mesh) out_specs=P('x', 'y'))(mesh)

View file

@ -19,10 +19,11 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
Computes gravitational forces on particles using a PM scheme Computes gravitational forces on particles using a PM scheme
""" """
if mesh_shape is None: if mesh_shape is None:
assert(delta is not None) , "If mesh_shape is not provided, delta should be provided" assert (delta is not None
), "If mesh_shape is not provided, delta should be provided"
mesh_shape = delta.shape mesh_shape = delta.shape
kvec = fftk(mesh_shape) kvec = fftk(mesh_shape)
if delta is None: if delta is None:
delta_k = fft3d(cic_paint_dx(positions, halo_size=halo_size)) delta_k = fft3d(cic_paint_dx(positions, halo_size=halo_size))
else: else:
@ -33,8 +34,8 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
r_split=r_split) r_split=r_split)
# Computes gravitational forces # Computes gravitational forces
forces = jnp.stack([ forces = jnp.stack([
cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k), halo_size=halo_size) cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k),
for i in range(3) halo_size=halo_size) for i in range(3)
], ],
axis=-1) axis=-1)

View file

@ -47,7 +47,6 @@ def run_simulation(omega_c, sigma8):
pk_fn, pk_fn,
seed=jax.random.PRNGKey(0)) seed=jax.random.PRNGKey(0))
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8) cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
# Initial displacement # Initial displacement

View file

@ -1,23 +1,23 @@
#!/bin/bash #!/bin/bash
########################################## ##########################################
## SELECT EITHER tkc@a100 OR tkc@v100 ## ## SELECT EITHER tkc@a100 OR tkc@v100 ##
########################################## ##########################################
#SBATCH --account tkc@a100 #SBATCH --account tkc@a100
########################################## ##########################################
#SBATCH --job-name=Particle-Mesh # nom du job #SBATCH --job-name=Particle-Mesh # nom du job
# Il est possible d'utiliser une autre partition que celle par default # Il est possible d'utiliser une autre partition que celle par default
# en activant l'une des 5 directives suivantes : # en activant l'une des 5 directives suivantes :
########################################## ##########################################
## SELECT EITHER a100 or v100-32g ## ## SELECT EITHER a100 or v100-32g ##
########################################## ##########################################
#SBATCH -C a100 #SBATCH -C a100
########################################## ##########################################
#****************************************** #******************************************
########################################## ##########################################
## SELECT Number of nodes and GPUs per node ## SELECT Number of nodes and GPUs per node
## For A100 ntasks-per-node and gres=gpu should be 8 ## For A100 ntasks-per-node and gres=gpu should be 8
## For V100 ntasks-per-node and gres=gpu should be 4 ## For V100 ntasks-per-node and gres=gpu should be 4
########################################## ##########################################
#SBATCH --nodes=1 # nombre de noeud #SBATCH --nodes=1 # nombre de noeud
#SBATCH --ntasks-per-node=8 # nombre de tache MPI par noeud (= nombre de GPU par noeud) #SBATCH --ntasks-per-node=8 # nombre de tache MPI par noeud (= nombre de GPU par noeud)
#SBATCH --gres=gpu:8 # nombre de GPU par nœud (max 8 avec gpu_p2, gpu_p5) #SBATCH --gres=gpu:8 # nombre de GPU par nœud (max 8 avec gpu_p2, gpu_p5)
@ -57,7 +57,7 @@ fi
# Chargement des modules # Chargement des modules
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
module load nvidia-nsight-systems/2024.1.1.59 module load nvidia-nsight-systems/2024.1.1.59
echo "The number of nodes allocated for this job is: $num_nodes" echo "The number of nodes allocated for this job is: $num_nodes"
echo "The number of GPUs allocated for this job is: $nb_gpus" echo "The number of GPUs allocated for this job is: $nb_gpus"
@ -116,7 +116,7 @@ set -x
# Pour la partition "gpu_p5", le code doit etre compile avec les modules compatibles # Pour la partition "gpu_p5", le code doit etre compile avec les modules compatibles
# Execution du code avec binding via bind_gpu.sh : 1 GPU par tache # Execution du code avec binding via bind_gpu.sh : 1 GPU par tache