From ece8c935403c0083d82617b1ca47128059c7b07f Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Sat, 3 Aug 2024 00:23:40 +0200 Subject: [PATCH] format --- benchmarks/bench_pm.py | 76 ++++++++------ benchmarks/bench_pmwd.py | 150 +++++++++++++++++----------- benchmarks/particle_mesh_a100.slurm | 4 - benchmarks/particle_mesh_v100.slurm | 3 - benchmarks/pmwd_a100.slurm | 5 +- benchmarks/pmwd_v100.slurm | 5 +- jaxpm/distributed.py | 30 +++--- jaxpm/kernels.py | 55 +++++----- jaxpm/painting.py | 18 ++-- jaxpm/pm.py | 9 +- scripts/distributed_pm.py | 1 - scripts/particle_mesh.slurm | 24 ++--- 12 files changed, 210 insertions(+), 170 deletions(-) diff --git a/benchmarks/bench_pm.py b/benchmarks/bench_pm.py index 9b916f9..2bf4534 100644 --- a/benchmarks/bench_pm.py +++ b/benchmarks/bench_pm.py @@ -10,13 +10,14 @@ size = jax.process_count() import argparse import time -from hpc_plotter.timer import Timer + import jax.numpy as jnp import jax_cosmo as jc import numpy as np from cupy.cuda.nvtx import RangePop, RangePush from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm, PIDController, SaveAt, Tsit5, diffeqsolve) +from hpc_plotter.timer import Timer from jax.experimental import mesh_utils from jax.experimental.multihost_utils import sync_global_devices from jax.sharding import Mesh, NamedSharding @@ -27,7 +28,6 @@ from jaxpm.painting import cic_paint_dx from jaxpm.pm import linear_field, lpt, make_ode_fn - def run_simulation(mesh_shape, box_size, halo_size, @@ -69,7 +69,7 @@ def run_simulation(mesh_shape, ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size) term = ODETerm( lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0)) - + if solver_choice == "Dopri5" or solver_choice == "Tsit5": stepsize_controller = PIDController(rtol=1e-4, atol=1e-4) elif solver_choice == "LeapfrogMidpoint" or solver_choice == "Euler": @@ -94,12 +94,18 @@ def run_simulation(mesh_shape, # Warm start chrono_fun = Timer() RangePush("warmup") - final_field, stats = chrono_fun.chrono_jit(simulate, 0.32, 0.8 , ndarray_arg = 0) + final_field, stats = chrono_fun.chrono_jit(simulate, + 0.32, + 0.8, + ndarray_arg=0) RangePop() sync_global_devices("warmup") for i in range(iterations): RangePush(f"sim iter {i}") - final_field, stats = chrono_fun.chrono_fun(simulate, 0.32, 0.8 , ndarray_arg = 0) + final_field, stats = chrono_fun.chrono_fun(simulate, + 0.32, + 0.8, + ndarray_arg=0) RangePop() return final_field, stats, chrono_fun @@ -134,11 +140,13 @@ if __name__ == "__main__": type=str, help='Processor dimensions', default=None) - parser.add_argument('-pr', - '--precision', - type=str, - help='Precision', - choices=["float32", "float64"],) + parser.add_argument( + '-pr', + '--precision', + type=str, + help='Precision', + choices=["float32", "float64"], + ) parser.add_argument('-hs', '--halo_size', type=int, @@ -173,7 +181,7 @@ if __name__ == "__main__": type=int, help='Number of nodes', default=1) - + args = parser.parse_args() mesh_size = args.mesh_size box_size = [args.box_size] * 3 @@ -182,14 +190,14 @@ if __name__ == "__main__": iterations = args.iterations output_path = args.output_path os.makedirs(output_path, exist_ok=True) - + print(f"solver choice: {solver_choice}") match solver_choice: - case "Dopri5" | "dopri5"| "d5": + case "Dopri5" | "dopri5" | "d5": solver_choice = "Dopri5" - case "Tsit5"| "tsit5"| "t5": + case "Tsit5" | "tsit5" | "t5": solver_choice = "Tsit5" - case "LeapfrogMidpoint"| "leapfrogmidpoint"| "lfm": + case "LeapfrogMidpoint" | "leapfrogmidpoint" | "lfm": solver_choice = "LeapfrogMidpoint" case "lpt": solver_choice = "lpt" @@ -199,7 +207,7 @@ if __name__ == "__main__": ) if args.precision == "float32": jax.config.update("jax_enable_x64", False) - elif args.precision == "float64": + elif args.precision == "float64": jax.config.update("jax_enable_x64", True) if args.pdims: @@ -209,22 +217,26 @@ if __name__ == "__main__": mesh_shape = [mesh_size] * 3 - final_field , stats, chrono_fun = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, pdims) - - print(f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}") + final_field, stats, chrono_fun = run_simulation(mesh_shape, box_size, + halo_size, solver_choice, + iterations, pdims) + + print( + f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}" + ) metadata = { - 'rank': rank, - 'function_name': f'JAXPM-{solver_choice}', - 'precision': args.precision, - 'x': str(mesh_size), - 'y': str(mesh_size), - 'z': str(stats["num_steps"]), - 'px': str(pdims[0]), - 'py': str(pdims[1]), - 'backend': 'NCCL', - 'nodes': str(args.nodes) - } + 'rank': rank, + 'function_name': f'JAXPM-{solver_choice}', + 'precision': args.precision, + 'x': str(mesh_size), + 'y': str(mesh_size), + 'z': str(stats["num_steps"]), + 'px': str(pdims[0]), + 'py': str(pdims[1]), + 'backend': 'NCCL', + 'nodes': str(args.nodes) + } # Print the results to a CSV file chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata) @@ -236,8 +248,8 @@ if __name__ == "__main__": with open(f'{field_folder}/jaxpm.log', 'w') as f: f.write(f"Args: {args}\n") f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n") - for i , time in enumerate(chrono_fun.times): - f.write(f"Time {i}: {time:.4f} ms\n") + for i, time in enumerate(chrono_fun.times): + f.write(f"Time {i}: {time:.4f} ms\n") f.write(f"Stats: {stats}\n") if args.save_fields: np.save(f'{field_folder}/final_field_0_{rank}.npy', diff --git a/benchmarks/bench_pmwd.py b/benchmarks/bench_pmwd.py index 8f93a3e..bd11303 100644 --- a/benchmarks/bench_pmwd.py +++ b/benchmarks/bench_pmwd.py @@ -3,34 +3,41 @@ import os # Change JAX GPU memory preallocation fraction os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95' -import jax import argparse -import numpy as np + +import jax import matplotlib.pyplot as plt -from pmwd import ( - Configuration, - Cosmology, SimpleLCDM, - boltzmann, linear_power, growth, - white_noise, linear_modes, - lpt, nbody, scatter -) +import numpy as np +from hpc_plotter.timer import Timer +from pmwd import (Configuration, Cosmology, SimpleLCDM, boltzmann, growth, + linear_modes, linear_power, lpt, nbody, scatter, white_noise) from pmwd.pm_util import fftinv from pmwd.spec_util import powspec from pmwd.vis_util import simshow -from hpc_plotter.timer import Timer + # Simulation configuration -def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver , iterations): +def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver, iterations): @jax.jit def simulate(omega_m, sigma8): - - - conf = Configuration(ptcl_spacing, ptcl_grid_shape=ptcl_grid_shape, mesh_shape=1,lpt_order=1,a_nbody_maxstep=1/91) - print(conf) - print(f'Simulating {conf.ptcl_num} particles with a {conf.mesh_shape} mesh for {conf.a_nbody_num} time steps.') - cosmo = Cosmology(conf, A_s_1e9=2.0, n_s=0.96, Omega_m=omega_m, Omega_b=sigma8, h=0.7) + conf = Configuration(ptcl_spacing, + ptcl_grid_shape=ptcl_grid_shape, + mesh_shape=1, + lpt_order=1, + a_nbody_maxstep=1 / 91) + print(conf) + print( + f'Simulating {conf.ptcl_num} particles with a {conf.mesh_shape} mesh for {conf.a_nbody_num} time steps.' + ) + + cosmo = Cosmology(conf, + A_s_1e9=2.0, + n_s=0.96, + Omega_m=omega_m, + Omega_b=sigma8, + h=0.7) print(cosmo) # Boltzmann calculation @@ -46,71 +53,95 @@ def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver , iterations): # Solve LPT at some early time ptcl, obsvbl = lpt(modes, cosmo, conf) print("LPT solved.") - + if solver == "lfm": - # N-body time integration from LPT initial conditions - ptcl, obsvbl = jax.block_until_ready(nbody(ptcl, obsvbl, cosmo, conf)) - print("N-body time integration completed.") + # N-body time integration from LPT initial conditions + ptcl, obsvbl = jax.block_until_ready( + nbody(ptcl, obsvbl, cosmo, conf)) + print("N-body time integration completed.") # Scatter particles to mesh to get the density field dens = scatter(ptcl, conf) return dens - + chrono_timer = Timer() final_field = chrono_timer.chrono_jit(simulate, 0.3, 0.05) - + for _ in range(iterations): final_field = chrono_timer.chrono_fun(simulate, 0.3, 0.05) - return final_field , chrono_timer - + return final_field, chrono_timer + if __name__ == "__main__": parser = argparse.ArgumentParser(description='PMWD Simulation') - parser.add_argument('-m', '--mesh_size', type=int, help='Mesh size', required=True) - parser.add_argument('-b', '--box_size', type=float, help='Box size', required=True) - parser.add_argument('-i', '--iterations', type=int, help='Number of iterations', default=10) - parser.add_argument('-o', '--output_path', type=str, help='Output path', default=".") - parser.add_argument('-f', '--save_fields', action='store_true', help='Save fields') - parser.add_argument('-s', '--solver', type=str, help='Solver', choices=["lfm" , "lpt"]) - parser.add_argument('-pr', - '--precision', - type=str, - help='Precision', - choices=["float32", "float64"],) - + parser.add_argument('-m', + '--mesh_size', + type=int, + help='Mesh size', + required=True) + parser.add_argument('-b', + '--box_size', + type=float, + help='Box size', + required=True) + parser.add_argument('-i', + '--iterations', + type=int, + help='Number of iterations', + default=10) + parser.add_argument('-o', + '--output_path', + type=str, + help='Output path', + default=".") + parser.add_argument('-f', + '--save_fields', + action='store_true', + help='Save fields') + parser.add_argument('-s', + '--solver', + type=str, + help='Solver', + choices=["lfm", "lpt"]) + parser.add_argument( + '-pr', + '--precision', + type=str, + help='Precision', + choices=["float32", "float64"], + ) args = parser.parse_args() - + mesh_shape = [args.mesh_size] * 3 - ptcl_spacing = args.box_size /args.mesh_size + ptcl_spacing = args.box_size / args.mesh_size iterations = args.iterations solver = args.solver output_path = args.output_path if args.precision == "float32": jax.config.update("jax_enable_x64", False) - elif args.precision == "float64": + elif args.precision == "float64": jax.config.update("jax_enable_x64", True) - os.makedirs(output_path, exist_ok=True) - - final_field , chrono_fun = run_pmwd_simulation(mesh_shape, ptcl_spacing, solver, iterations) + + final_field, chrono_fun = run_pmwd_simulation(mesh_shape, ptcl_spacing, + solver, iterations) print("PMWD simulation completed.") - metadata = { - 'rank': 0, - 'function_name': f'PMWD-{solver}', - 'precision': args.precision, - 'x': str(mesh_shape[0]), - 'y': str(mesh_shape[0]), - 'z': str(mesh_shape[0]), - 'px': "1", - 'py': "1", - 'backend': 'NCCL', - 'nodes': "1" - } + 'rank': 0, + 'function_name': f'PMWD-{solver}', + 'precision': args.precision, + 'x': str(mesh_shape[0]), + 'y': str(mesh_shape[0]), + 'z': str(mesh_shape[0]), + 'px': "1", + 'py': "1", + 'backend': 'NCCL', + 'nodes': "1" + } chrono_fun.print_to_csv(f"{output_path}/pmwd.csv", **metadata) field_folder = f"{output_path}/final_field/pmwd/1/{args.mesh_size}_{int(args.box_size)}/1x1/{args.solver}/halo_0" os.makedirs(field_folder, exist_ok=True) @@ -118,14 +149,11 @@ if __name__ == "__main__": f.write(f"PMWD simulation completed.\n") f.write(f"Args : {args}\n") f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n") - for i , time in enumerate(chrono_fun.times): - f.write(f"Time {i}: {time:.4f} ms\n") + for i, time in enumerate(chrono_fun.times): + f.write(f"Time {i}: {time:.4f} ms\n") if args.save_fields: np.save(f"{field_folder}/final_field_0_0.npy", final_field) print("Fields saved.") - - + print(f"saving to {output_path}/pmwd.csv") print(f"saving field and logs to {field_folder}/pmwd.log") - - diff --git a/benchmarks/particle_mesh_a100.slurm b/benchmarks/particle_mesh_a100.slurm index 65930b2..a94019c 100644 --- a/benchmarks/particle_mesh_a100.slurm +++ b/benchmarks/particle_mesh_a100.slurm @@ -177,7 +177,3 @@ for pr in "${precisions[@]}"; do done done done - - - - diff --git a/benchmarks/particle_mesh_v100.slurm b/benchmarks/particle_mesh_v100.slurm index 9eeb610..9446b9b 100644 --- a/benchmarks/particle_mesh_v100.slurm +++ b/benchmarks/particle_mesh_v100.slurm @@ -179,6 +179,3 @@ for pr in "${precisions[@]}"; do done done done - - - diff --git a/benchmarks/pmwd_a100.slurm b/benchmarks/pmwd_a100.slurm index 99c64c7..f57f0ac 100644 --- a/benchmarks/pmwd_a100.slurm +++ b/benchmarks/pmwd_a100.slurm @@ -156,10 +156,7 @@ echo "Output dir is : $out_dir" for pr in "${precisions[@]}"; do for g in "${grid[@]}"; do for solver in "${solvers[@]}"; do - launch bench_pmwd.py -m $g -b $g -p $p -pr $pr -s $solver -i 4 -o $out_dir -f + launch bench_pmwd.py -m $g -b $g -p $p -pr $pr -s $solver -i 4 -o $out_dir -f done done done - - - diff --git a/benchmarks/pmwd_v100.slurm b/benchmarks/pmwd_v100.slurm index 4c58db5..9ca5f89 100644 --- a/benchmarks/pmwd_v100.slurm +++ b/benchmarks/pmwd_v100.slurm @@ -161,10 +161,7 @@ echo "Output dir is : $out_dir" for pr in "${precisions[@]}"; do for g in "${grid[@]}"; do for solver in "${solvers[@]}"; do - slaunch bench_pmwd.py -m $g -b $g -pr $pr -s $solver -i 4 -o $out_dir -f + slaunch bench_pmwd.py -m $g -b $g -pr $pr -s $solver -i 4 -o $out_dir -f done done done - - - diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 9fb0e15..54377cb 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -44,17 +44,20 @@ def autoshmap(f: Callable, return f else: if in_fourrier_space and 1 in mesh.devices.shape: - in_specs , out_specs = switch_specs((in_specs , out_specs)) + in_specs, out_specs = switch_specs((in_specs, out_specs)) return shard_map(f, mesh, in_specs, out_specs, check_rep, auto) + def switch_specs(specs): - if isinstance(specs, P): - new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax for ax in specs) - return P(*new_axes) - elif isinstance(specs, tuple): - return tuple(switch_specs(sub_spec) for sub_spec in specs) - else: - raise TypeError("Element must be either a PartitionSpec or a tuple") + if isinstance(specs, P): + new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax + for ax in specs) + return P(*new_axes) + elif isinstance(specs, tuple): + return tuple(switch_specs(sub_spec) for sub_spec in specs) + else: + raise TypeError("Element must be either a PartitionSpec or a tuple") + def fft3d(x): if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): @@ -105,14 +108,15 @@ def slice_unpad_impl(x, pad_width): # Apply corrections along y x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2]) x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:]) - + unpad_slice = [slice(None)] * 3 if halo_x > 0: - unpad_slice[0] = slice(halo_x , -halo_x) + unpad_slice[0] = slice(halo_x, -halo_x) if halo_y > 0: - unpad_slice[1] = slice(halo_y , -halo_y) - - return x[tuple(unpad_slice)] + unpad_slice[1] = slice(halo_y, -halo_y) + + return x[tuple(unpad_slice)] + def slice_pad(x, pad_width): mesh = mesh_lib.thread_resources.env.physical_mesh diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index bfb7e7e..d954132 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -1,3 +1,4 @@ +from enum import Enum from functools import partial import jax.numpy as jnp @@ -7,29 +8,31 @@ from jax._src import mesh as mesh_lib from jax.sharding import PartitionSpec as P from jaxpm.distributed import autoshmap -from enum import Enum + class PencilType(Enum): - NO_DECOMP = 0 - SLAB_XY = 1 - SLAB_YZ = 2 - PENCILS = 3 + NO_DECOMP = 0 + SLAB_XY = 1 + SLAB_YZ = 2 + PENCILS = 3 + def get_pencil_type(): - mesh = mesh_lib.thread_resources.env.physical_mesh - if mesh.empty: - pdims = None - else: - pdims = mesh.devices.shape[::-1] + mesh = mesh_lib.thread_resources.env.physical_mesh + if mesh.empty: + pdims = None + else: + pdims = mesh.devices.shape[::-1] + + if pdims == (1, 1) or pdims == None: + return PencilType.NO_DECOMP + elif pdims[0] == 1: + return PencilType.SLAB_XY + elif pdims[1] == 1: + return PencilType.SLAB_YZ + else: + return PencilType.PENCILS - if pdims == (1, 1) or pdims == None: - return PencilType.NO_DECOMP - elif pdims[0] == 1: - return PencilType.SLAB_XY - elif pdims[1] == 1: - return PencilType.SLAB_YZ - else: - return PencilType.PENCILS def fftk(shape, dtype=np.float32): """ @@ -46,22 +49,23 @@ def fftk(shape, dtype=np.float32): @partial(autoshmap, in_specs=(P('x'), P('y'), P(None)), - out_specs=(P('x'), P(None, 'y'), P(None)),in_fourrier_space=True) + out_specs=(P('x'), P(None, 'y'), P(None)), + in_fourrier_space=True) def get_kvec(ky, kz, kx): return (ky.reshape([-1, 1, 1]), kz.reshape([1, -1, 1]), kx.reshape([1, 1, -1])) # yapf: disable - pencil_type = get_pencil_type() + pencil_type = get_pencil_type() # YZ returns Y pencil # XY and pencils returns a Z pencil # NO_DECOMP returns a X pencil if pencil_type == PencilType.NO_DECOMP: - kx, ky, kz = get_kvec(kx, ky, kz) # Z Y X ==> X pencil + kx, ky, kz = get_kvec(kx, ky, kz) # Z Y X ==> X pencil elif pencil_type == PencilType.SLAB_YZ: - kz, kx, ky = get_kvec(kz, kx, ky) # X Z Y ==> Y pencil + kz, kx, ky = get_kvec(kz, kx, ky) # X Z Y ==> Y pencil elif pencil_type == PencilType.SLAB_XY or pencil_type == PencilType.PENCILS: - ky, kz, kx = get_kvec(ky, kz, kx) # Z X Y ==> Z pencil + ky, kz, kx = get_kvec(ky, kz, kx) # Z X Y ==> Z pencil else: raise ValueError("Unknown pencil type") @@ -73,7 +77,10 @@ def interpolate_power_spectrum(input, k, pk): pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk ).reshape(x.shape) - return autoshmap(pk_fn, in_specs=P('x', 'y'), out_specs=P('x', 'y'),in_fourrier_space=True)(input) + return autoshmap(pk_fn, + in_specs=P('x', 'y'), + out_specs=P('x', 'y'), + in_fourrier_space=True)(input) def gradient_kernel(kvec, direction, order=1): diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 7160913..975e43c 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -150,7 +150,7 @@ def cic_paint_dx_impl(displacements, halo_size): jnp.arange(particle_mesh.shape[1]), jnp.arange(particle_mesh.shape[2]), indexing='ij') - + particle_mesh = jnp.pad(particle_mesh, halo_size) pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1) pmid = pmid.reshape([-1, 3]) @@ -159,13 +159,13 @@ def cic_paint_dx_impl(displacements, halo_size): @partial(jax.jit, static_argnums=(1, )) def cic_paint_dx(displacements, halo_size=0): - + halo_size, halo_extents = get_halo_size(halo_size) - + mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size), in_specs=(P('x', 'y')), out_specs=P('x', 'y'))(displacements) - + mesh = halo_exchange(mesh, halo_extents=halo_extents, halo_periods=(True, True, True)) @@ -173,19 +173,21 @@ def cic_paint_dx(displacements, halo_size=0): return mesh -def cic_read_dx_impl(mesh , halo_size): +def cic_read_dx_impl(mesh, halo_size): halo_x, _ = halo_size[0] halo_y, _ = halo_size[1] - original_shape = [dim - 2 * halo[0] for dim , halo in zip(mesh.shape, halo_size)] + original_shape = [ + dim - 2 * halo[0] for dim, halo in zip(mesh.shape, halo_size) + ] a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]), jnp.arange(original_shape[1]), jnp.arange(original_shape[2]), indexing='ij') pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1) - + pmid = pmid.reshape([-1, 3]) return gather(pmid, jnp.zeros_like(pmid), mesh).reshape(original_shape) @@ -199,7 +201,7 @@ def cic_read_dx(mesh, halo_size=0): mesh = halo_exchange(mesh, halo_extents=halo_extents, halo_periods=(True, True, True)) - displacements = autoshmap(partial(cic_read_dx_impl , halo_size=halo_size), + displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size), in_specs=(P('x', 'y')), out_specs=P('x', 'y'))(mesh) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 3055058..79080df 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -19,10 +19,11 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0): Computes gravitational forces on particles using a PM scheme """ if mesh_shape is None: - assert(delta is not None) , "If mesh_shape is not provided, delta should be provided" + assert (delta is not None + ), "If mesh_shape is not provided, delta should be provided" mesh_shape = delta.shape kvec = fftk(mesh_shape) - + if delta is None: delta_k = fft3d(cic_paint_dx(positions, halo_size=halo_size)) else: @@ -33,8 +34,8 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0): r_split=r_split) # Computes gravitational forces forces = jnp.stack([ - cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k), halo_size=halo_size) - for i in range(3) + cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k), + halo_size=halo_size) for i in range(3) ], axis=-1) diff --git a/scripts/distributed_pm.py b/scripts/distributed_pm.py index 5f41af4..ad699c6 100644 --- a/scripts/distributed_pm.py +++ b/scripts/distributed_pm.py @@ -47,7 +47,6 @@ def run_simulation(omega_c, sigma8): pk_fn, seed=jax.random.PRNGKey(0)) - cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8) # Initial displacement diff --git a/scripts/particle_mesh.slurm b/scripts/particle_mesh.slurm index 51332a5..2585d5d 100644 --- a/scripts/particle_mesh.slurm +++ b/scripts/particle_mesh.slurm @@ -1,23 +1,23 @@ #!/bin/bash -########################################## -## SELECT EITHER tkc@a100 OR tkc@v100 ## -########################################## +########################################## +## SELECT EITHER tkc@a100 OR tkc@v100 ## +########################################## #SBATCH --account tkc@a100 -########################################## +########################################## #SBATCH --job-name=Particle-Mesh # nom du job # Il est possible d'utiliser une autre partition que celle par default # en activant l'une des 5 directives suivantes : -########################################## -## SELECT EITHER a100 or v100-32g ## -########################################## +########################################## +## SELECT EITHER a100 or v100-32g ## +########################################## #SBATCH -C a100 -########################################## +########################################## #****************************************** -########################################## +########################################## ## SELECT Number of nodes and GPUs per node ## For A100 ntasks-per-node and gres=gpu should be 8 ## For V100 ntasks-per-node and gres=gpu should be 4 -########################################## +########################################## #SBATCH --nodes=1 # nombre de noeud #SBATCH --ntasks-per-node=8 # nombre de tache MPI par noeud (= nombre de GPU par noeud) #SBATCH --gres=gpu:8 # nombre de GPU par nœud (max 8 avec gpu_p2, gpu_p5) @@ -57,7 +57,7 @@ fi # Chargement des modules module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake -module load nvidia-nsight-systems/2024.1.1.59 +module load nvidia-nsight-systems/2024.1.1.59 echo "The number of nodes allocated for this job is: $num_nodes" echo "The number of GPUs allocated for this job is: $nb_gpus" @@ -116,7 +116,7 @@ set -x # Pour la partition "gpu_p5", le code doit etre compile avec les modules compatibles # Execution du code avec binding via bind_gpu.sh : 1 GPU par tache - +