diff --git a/benchmarks/particle_mesh.slurm b/benchmarks/particle_mesh.slurm index 330fa1a..7e60678 100644 --- a/benchmarks/particle_mesh.slurm +++ b/benchmarks/particle_mesh.slurm @@ -1,6 +1,6 @@ #!/bin/bash ############################################################################################################################## -# USAGE:sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 benchmarks/particle_mesh_a100.slurm +# USAGE:sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 benchmarks/particle_mesh_a100.slurm ############################################################################################################################## #SBATCH --job-name=Particle-Mesh # nom du job #SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU) @@ -140,7 +140,7 @@ fi # GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100 out_dir="pm_prof/$gpu_name/$nb_gpus" trace_dir="traces/$gpu_name/$nb_gpus/bench_pm" -echo "Output dir is : $out_dir" +echo "Output dir is : $out_dir" echo "Trace dir is : $trace_dir" for pr in "${precisions[@]}"; do diff --git a/benchmarks/pmwd_pm.slurm b/benchmarks/pmwd_pm.slurm index 171a7b3..4171a51 100644 --- a/benchmarks/pmwd_pm.slurm +++ b/benchmarks/pmwd_pm.slurm @@ -1,6 +1,6 @@ #!/bin/bash ############################################################################################################################## -# USAGE:sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 benchmarks/particle_mesh_a100.slurm +# USAGE:sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 benchmarks/particle_mesh_a100.slurm ############################################################################################################################## #SBATCH --job-name=Particle-Mesh # nom du job #SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU) @@ -126,7 +126,7 @@ fi out_dir="pm_prof/$gpu_name/$nb_gpus" trace_dir="traces/$gpu_name/$nb_gpus/bench_pmwd" -echo "Output dir is : $out_dir" +echo "Output dir is : $out_dir" echo "Trace dir is : $trace_dir" for pr in "${precisions[@]}"; do diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index b4303fc..a7da0ee 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -43,7 +43,7 @@ def interpolate_power_spectrum(input, k, pk, sharding=None): def gradient_kernel(kvec, direction, order=1): """ Computes the gradient kernel in the requested direction - + Parameters ----------- kvec: list @@ -84,8 +84,8 @@ def invlaplace_kernel(kvec): Complex kernel values """ kk = sum(ki**2 for ki in kvec) - kk_nozeros = jnp.where(kk==0, 1, kk) - return - jnp.where(kk==0, 0, 1 / kk_nozeros) + kk_nozeros = jnp.where(kk == 0, 1, kk) + return -jnp.where(kk == 0, 0, 1 / kk_nozeros) def longrange_kernel(kvec, r_split): @@ -98,12 +98,12 @@ def longrange_kernel(kvec, r_split): List of wave-vectors r_split: float Splitting radius - + Returns -------- wts: array Complex kernel values - + TODO: @modichirag add documentation """ if r_split != 0: @@ -124,7 +124,7 @@ def cic_compensation(kvec): ----------- kvec: list List of wave-vectors - + Returns: -------- wts: array diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 3b18d48..3155467 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -9,8 +9,8 @@ from jaxpm.distributed import (autoshmap, fft3d, get_local_shape, ifft3d, normal_field) from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second, growth_rate, growth_rate_second) -from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, invlaplace_kernel, - longrange_kernel) +from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, + invlaplace_kernel, longrange_kernel) from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx @@ -38,11 +38,11 @@ def pm_forces(positions, kvec = fftk(delta_k) # Computes gravitational potential - pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, - r_split=r_split) + pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel( + kvec, r_split=r_split) # Computes gravitational forces forces = jnp.stack([ - cic_read_dx(ifft3d( - gradient_kernel(kvec, i) * pot_k), + cic_read_dx(ifft3d(-gradient_kernel(kvec, i) * pot_k), halo_size=halo_size, sharding=sharding) for i in range(3) ], @@ -51,9 +51,9 @@ def pm_forces(positions, return forces -def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None,order=1): +def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1): """ - Computes first and second order LPT displacement and momentum, + Computes first and second order LPT displacement and momentum, e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258) """ gpu_mesh = sharding.mesh if sharding is not None else None @@ -68,7 +68,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None,order=1): a = jnp.atleast_1d(a) - E = jnp.sqrt(jc.background.Esqr(cosmo, a)) + E = jnp.sqrt(jc.background.Esqr(cosmo, a)) delta_k = fft3d(initial_conditions) initial_force = pm_forces(displacement, delta=delta_k, @@ -76,7 +76,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None,order=1): sharding=sharding) dx = growth_factor(cosmo, a) * initial_force p = a**2 * growth_rate(cosmo, a) * E * dx - f = a**2 * E * dGfa(cosmo,a) * initial_force + f = a**2 * E * dGfa(cosmo, a) * initial_force if order == 2: kvec = fftk(delta_k) pot_k = delta_k * invlaplace_kernel(kvec) @@ -89,26 +89,30 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None,order=1): # shear_ii = jnp.fft.irfftn(- ki**2 * pot_k) nabla_i_nabla_i = gradient_kernel(kvec, i)**2 shear_ii = jnp.fft.irfftn(nabla_i_nabla_i * pot_k) - delta2 += shear_ii * shear_acc + delta2 += shear_ii * shear_acc shear_acc += shear_ii # for kj in kvec[i+1:]: - for j in range(i+1, 3): + for j in range(i + 1, 3): # Substract squared strict-up-triangle terms # delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2 - nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(kvec, j) + nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel( + kvec, j) delta2 -= jnp.fft.irfftn(nabla_i_nabla_j * pot_k)**2 - + delta_k2 = fft3d(delta2) - init_force2 = pm_forces(displacement, delta=delta_k2,halo_size=halo_size,sharding=sharding) + init_force2 = pm_forces(displacement, + delta=delta_k2, + halo_size=halo_size, + sharding=sharding) # NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second - dx2 = 3/7 * growth_factor_second(cosmo, a) * init_force2 + dx2 = 3 / 7 * growth_factor_second(cosmo, a) * init_force2 p2 = a**2 * growth_rate_second(cosmo, a) * E * dx2 f2 = a**2 * E * dGf2a(cosmo, a) * init_force2 dx += dx2 - p += p2 - f += f2 + p += p2 + f += f2 return dx, p, f @@ -153,6 +157,7 @@ def make_ode_fn(mesh_shape, halo_size=0, sharding=None): return nbody_ode + def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None): def nbody_ode(a, state, args): @@ -162,11 +167,13 @@ def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None): Compatible with [Diffrax API](https://docs.kidger.site/diffrax/) """ pos, vel = state - forces = pm_forces(pos, mesh_shape, halo_size=halo_size, sharding=sharding) * 1.5 * cosmo.Omega_m + forces = pm_forces( + pos, mesh_shape, halo_size=halo_size, + sharding=sharding) * 1.5 * cosmo.Omega_m # Computes the update of position (drift) dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel - + # Computes the update of velocity (kick) dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces @@ -177,7 +184,7 @@ def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None): def pgd_correction(pos, mesh_shape, params): """ - improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, + improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671 args: @@ -188,20 +195,24 @@ def pgd_correction(pos, mesh_shape, params): delta = cic_paint(jnp.zeros(mesh_shape), pos) alpha, kl, ks = params delta_k = jnp.fft.rfftn(delta) - PGD_range=PGD_kernel(kvec, kl, ks) - - pot_k_pgd=(delta_k * invlaplace_kernel(kvec))*PGD_range + PGD_range = PGD_kernel(kvec, kl, ks) + + pot_k_pgd = (delta_k * invlaplace_kernel(kvec)) * PGD_range + + forces_pgd = jnp.stack([ + cic_read(jnp.fft.irfftn(-gradient_kernel(kvec, i) * pot_k_pgd), pos) + for i in range(3) + ], + axis=-1) + + dpos_pgd = forces_pgd * alpha - forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i)*pot_k_pgd), pos) - for i in range(3)],axis=-1) - - dpos_pgd = forces_pgd*alpha - return dpos_pgd def make_neural_ode_fn(model, mesh_shape): - def neural_nbody_ode(state, a, cosmo:Cosmology, params): + + def neural_nbody_ode(state, a, cosmo: Cosmology, params): """ state is a tuple (position, velocities) """ @@ -213,15 +224,19 @@ def make_neural_ode_fn(model, mesh_shape): delta_k = jnp.fft.rfftn(delta) # Computes gravitational potential - pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, r_split=0) + pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, + r_split=0) # Apply a correction filter - kk = jnp.sqrt(sum((ki/jnp.pi)**2 for ki in kvec)) - pot_k = pot_k *(1. + model.apply(params, kk, jnp.atleast_1d(a))) + kk = jnp.sqrt(sum((ki / jnp.pi)**2 for ki in kvec)) + pot_k = pot_k * (1. + model.apply(params, kk, jnp.atleast_1d(a))) # Computes gravitational forces - forces = jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i)*pot_k), pos) - for i in range(3)],axis=-1) + forces = jnp.stack([ + cic_read(jnp.fft.irfftn(-gradient_kernel(kvec, i) * pot_k), pos) + for i in range(3) + ], + axis=-1) forces = forces * 1.5 * cosmo.Omega_m @@ -232,4 +247,5 @@ def make_neural_ode_fn(model, mesh_shape): dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces return dpos, dvel + return neural_nbody_ode