Animating Particle Mesh density fields
In this tutorial, we will animate the density field of a particle mesh simulation, using the manim library to create the animation.
The density fields are generated exactly as in the notebook 05-MultiHost_PM.ipynb, using the same script, 05-MultiHost_PM.py.
To run a multi-host simulation, you first need to allocate a job with salloc, just like in the notebook 05-MultiHost_PM.ipynb. This command requests resources on an HPC cluster:
!salloc --account=XXX@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:40:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=4 &
A few hours later, once the job has been allocated:
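Use !squeue -u $USER -o "%i %D %b" to check the job ID and verify your resource allocation (the format string prints the job ID, the number of nodes, and the generic resources, such as GPUs, requested by the job).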
In this example, we’ve been allocated 32 GPUs split across 4 nodes.
!squeue -u $USER -o "%i %D %b"
Unset the following environment variables, as they can cause issues when using JAX in a distributed setting:
import os
# Remove the proxy-related variables if present; pop(..., None) avoids a KeyError when one is not set
for var in ("VSCODE_PROXY_URI", "NO_PROXY", "no_proxy"):
    os.environ.pop(var, None)
Checking Available Compute Resources
Run the following command to initialize JAX distributed computing and display the devices available for this job:
!srun --jobid=467745 -n 32 python -c "import jax; jax.distributed.initialize(); print(jax.devices()) if jax.process_index() == 0 else None"
Multi-Host Simulation Script with Arguments (reminder)
This script is nearly identical to the single-host version; the main addition is the call to jax.distributed.initialize() at the start, which enables multi-host parallelism. Here is a breakdown of the key arguments:
- --pdims (-p): Specifies the processor grid dimensions as two integers, e.g. 16 2 for a 16 x 2 device mesh (default is [1, jax.devices()]).
- --mesh_shape (-m): Defines the simulation mesh shape as three integers (default is [512, 512, 512]).
- --box_size (-b): Sets the physical box size of the simulation as three floating-point values, e.g. 1000. 1000. 1000. (default is [500.0, 500.0, 500.0]).
- --halo_size (-H): Specifies the halo size for boundary overlap across nodes (default is 64).
- --solver (-s): Chooses the ODE solver (leapfrog or dopri8). The leapfrog solver uses a fixed step size, while dopri8 is an adaptive Runge-Kutta solver with a PID controller (default is leapfrog).
- --snapshots (-st): Number of snapshots to save (warning: this increases memory usage).
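To see how these flags fit together, here is a minimal sketch of a command-line parser that mirrors the documented options using Python's argparse. It is a hypothetical reconstruction, not the actual code of 05-MultiHost_PM.py: the pdims default is left as None (standing in for the documented [1, jax.devices()]), and the snapshot default of 1 is an assumption, since the text does not state one.

```python
import argparse


def build_parser():
    # Hypothetical parser mirroring the documented flags;
    # NOT the actual code of 05-MultiHost_PM.py.
    parser = argparse.ArgumentParser(description="Multi-host PM simulation (sketch)")
    # Processor grid; None stands in for the documented default [1, jax.devices()]
    parser.add_argument("-p", "--pdims", type=int, nargs=2, default=None)
    parser.add_argument("-m", "--mesh_shape", type=int, nargs=3, default=[512, 512, 512])
    parser.add_argument("-b", "--box_size", type=float, nargs=3,
                        default=[500.0, 500.0, 500.0])
    parser.add_argument("-H", "--halo_size", type=int, default=64)
    parser.add_argument("-s", "--solver", choices=["leapfrog", "dopri8"],
                        default="leapfrog")
    # Default snapshot count is not documented; 1 is a placeholder assumption
    parser.add_argument("-st", "--snapshots", type=int, default=1)
    return parser


args = build_parser().parse_args(
    ["--mesh_shape", "1024", "1024", "1024", "-s", "leapfrog",
     "--pdims", "16", "2", "--snapshots", "8"]
)
print(args.pdims, args.mesh_shape, args.box_size, args.snapshots)
```

argparse matches "-st" exactly before considering "-s", so the two short options can coexist.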
Running the Multi-Host Simulation Script
To create a smooth animation, we need a series of closely spaced snapshots that capture the evolution of the density field over time. In this example, we save 8 snapshots, which is enough for smooth transitions in the animation.
Spreading the simulation over a larger number of GPUs keeps this tractable: every saved snapshot is a full 3D field, so memory grows with both the mesh size and the number of snapshots. Distributing a 1024 x 1024 x 1024 mesh across 32 GPUs lets us keep both the desired snapshot frequency and the necessary resolution without excessive runtime.
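To put numbers on this, here is a back-of-the-envelope estimate of snapshot storage. It is a sketch that assumes float32 fields (4 bytes per cell, JAX's default precision) split evenly over the GPUs, and it ignores halos and solver work buffers, so real usage will be higher. The helper names are illustrative only.

```python
import math


# Hypothetical helpers for a rough memory estimate; assumes float32 (4 bytes per cell)
def field_bytes(mesh_shape, bytes_per_cell=4):
    # Size of one full 3D field stored on the mesh
    return math.prod(mesh_shape) * bytes_per_cell


def snapshot_bytes_per_gpu(mesh_shape, snapshots, num_gpus, bytes_per_cell=4):
    # All saved snapshots, split evenly across GPUs (halos and solver buffers ignored)
    return snapshots * field_bytes(mesh_shape, bytes_per_cell) / num_gpus


GiB = 2**30
print(field_bytes((1024, 1024, 1024)) / GiB)                     # 4.0
print(snapshot_bytes_per_gpu((1024, 1024, 1024), 8, 32) / GiB)   # 1.0
```

Doubling the mesh to 2048 cells per side multiplies both numbers by 8, which is why larger meshes and more snapshots push toward more GPUs.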
The command to run the multi-host simulation with these settings will look something like this:
import subprocess
# Define parameters as variables
jobid = "467745"
num_processes = 32
script_name = "05-MultiHost_PM.py"
mesh_shape = (1024, 1024, 1024)
box_size = (1000., 1000., 1000.)
halo_size = 128
solver = "leapfrog"
pdims = (16, 2)
snapshots = 8
# Build the command as a list, incorporating variables
command = [
"srun",
f"--jobid={jobid}",
"-n", str(num_processes),
"python", script_name,
"--mesh_shape", str(mesh_shape[0]), str(mesh_shape[1]), str(mesh_shape[2]),
"--box_size", str(box_size[0]), str(box_size[1]), str(box_size[2]),
"--halo_size", str(halo_size),
"-s", solver,
"--pdims", str(pdims[0]), str(pdims[1]),
"--snapshots", str(snapshots)
]
# Execute the command; check=True raises CalledProcessError if srun exits with a non-zero status
subprocess.run(command, check=True)
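The parameters above have to agree with each other: with one GPU per process (as in the allocation with --gres=gpu:8 and --ntasks-per-node=8), the 16 x 2 processor grid must contain exactly the 32 requested processes. The sketch below wraps the command construction in a helper that checks this before launching; build_srun_command is a hypothetical name, and the one-GPU-per-process rule is an assumption about how the script maps pdims onto devices.

```python
import shlex


def build_srun_command(jobid, num_processes, script_name, mesh_shape, box_size,
                       halo_size, solver, pdims, snapshots):
    # Assumption: one GPU per process, so the pdims grid must contain
    # exactly num_processes devices (16 x 2 = 32 in the example above)
    if pdims[0] * pdims[1] != num_processes:
        raise ValueError(
            f"pdims {pdims} describes {pdims[0] * pdims[1]} devices, "
            f"but {num_processes} processes were requested"
        )
    return [
        "srun", f"--jobid={jobid}", "-n", str(num_processes),
        "python", script_name,
        "--mesh_shape", *map(str, mesh_shape),
        "--box_size", *map(str, box_size),
        "--halo_size", str(halo_size),
        "-s", solver,
        "--pdims", *map(str, pdims),
        "--snapshots", str(snapshots),
    ]


cmd = build_srun_command("467745", 32, "05-MultiHost_PM.py", (1024, 1024, 1024),
                         (1000., 1000., 1000.), 128, "leapfrog", (16, 2), 8)
# Shell-quoted version, handy for pasting into a terminal or a batch script
print(shlex.join(cmd))
```

The returned list can be passed to subprocess.run exactly like the hand-built command.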
Projecting the 3D Density Fields to 2D
To visualize the 3D density fields in 2D, we need to create a projection:
- project_to_2d function: reduces the 3D array to 2D by summing over part of one axis. We sum the first one-eighth of the cells along the first axis, which captures a slab of the density field.
- Creating 2D projections: apply project_to_2d to each 3D field (initial_conditions, lpt_displacements, and each snapshot ode_solution_0 to ode_solution_7) to get 2D arrays that represent the density fields.
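Written out, for a field ρ on an N_0 x N_1 x N_2 mesh (indices starting at 0, as in NumPy), the projection computed by project_to_2d is:

```latex
P_{jk} = \sum_{i=0}^{\lfloor N_0 / 8 \rfloor - 1} \rho_{ijk}
```

With mesh_shape = (1024, 1024, 1024), this sums 1024 // 8 = 128 slices, i.e. a slab 1000 / 8 = 125 box units thick for box_size = (1000., 1000., 1000.).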
Applying the Magma Colormap
To improve visualization, apply the "magma" colormap to each 2D projection:
- apply_colormap function: maps the values in a 2D array to colors using the "magma" colormap.
  - First, normalize the array to the [0, 1] range.
  - Then apply the colormap to create RGB images, which will be used for the animation.
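In symbols, for a 2D projection a, apply_colormap computes the following. The floor reflects the conversion with astype(np.uint8), which truncates; because the colormap returns values in [0, 1], every channel lands in the range 0 to 255.

```latex
\hat{a}_{jk} = \frac{a_{jk} - \min(a)}{\max(a) - \min(a)}, \qquad
\mathrm{RGB}_{jk} = \left\lfloor 255 \, \mathrm{magma}\left(\hat{a}_{jk}\right)_{\mathrm{R,G,B}} \right\rfloor
```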
import numpy as np
from matplotlib import colormaps

# Project a 3D field to 2D by summing the first 1/8 of the cells along axis 0
def project_to_2d(field):
    sum_over = field.shape[0] // 8
    return field[:sum_over].sum(axis=0)

# Map a 2D array to an RGB image (uint8) using a matplotlib colormap
def apply_colormap(array, cmap_name="magma"):
    cmap = colormaps[cmap_name]
    # Normalize to [0, 1]; a constant array (zero range) maps to 0 instead of dividing by zero
    value_range = array.max() - array.min()
    if value_range > 0:
        normalized_array = (array - array.min()) / value_range
    else:
        normalized_array = np.zeros(array.shape)
    colored_image = cmap(normalized_array)[:, :, :3]  # Drop alpha channel for RGB
    return (colored_image * 255).astype(np.uint8)
Loading and Visualizing Results
After running the multi-host simulation, we load the saved results from the fields/ directory:
- initial_conditions.npy: initial conditions for the simulation.
- lpt_displacements.npy: Lagrangian perturbation theory (LPT) displacements.
- ode_solution_*.npy: solutions from the ODE solver at each snapshot.
We now project each field to a 2D map and apply the colormap.
import numpy as np
# Load each saved field, project it to 2D, and convert it to an RGB image
initial_conditions = apply_colormap(project_to_2d(np.load('fields/initial_conditions.npy')))
lpt_displacements = apply_colormap(project_to_2d(np.load('fields/lpt_displacements.npy')))
# One file per snapshot: ode_solution_0.npy to ode_solution_7.npy (8 snapshots)
ode_solutions = [
    apply_colormap(project_to_2d(np.load(f'fields/ode_solution_{i}.npy')))
    for i in range(8)
]
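Each 1024 x 1024 x 1024 field is several GiB on disk, yet project_to_2d only uses the first eighth along axis 0. A possible refinement, not part of the original workflow, is to memory-map the .npy files with NumPy's mmap_mode so that only the needed slab is read. The sketch below assumes the files were written in C order (np.save's default for C-contiguous arrays); load_projected_slab is a hypothetical helper name.

```python
import numpy as np


def load_projected_slab(path, fraction=8):
    # Memory-map the .npy file instead of reading the whole field into RAM
    field = np.load(path, mmap_mode="r")
    n = field.shape[0] // fraction
    # In C order, field[:n] is a contiguous leading block of the file,
    # so summing it only touches that part on disk
    return np.asarray(field[:n].sum(axis=0))
```

It returns the same projection as project_to_2d(np.load(path)), so it can be dropped in as, for example, apply_colormap(load_projected_slab('fields/initial_conditions.npy')).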
Animating with Manim
To create animations with manim in a Jupyter notebook, we start by configuring a few settings so that the output spans the full cell width, only warnings are logged, and the background is transparent.
from manim import *
config.media_width = "100%"
config.verbosity = "WARNING"
config.background_color = "#00000000" # Transparent background
Defining the Animation in Manim
The animation class FieldTransition smoothly transitions through the stages of the particle mesh density field evolution.
- Setup: Each density field (initial conditions, LPT displacements, and each ODE snapshot) is loaded as an image and placed at the same position, so consecutive stages overlap exactly.
- Animation Sequence:
  - The animation begins with a fade-in of the initial conditions.
  - It then transitions through the stages in order, from the LPT displacements through each ODE snapshot, with 0.2-second pauses in between.
To render the animation in the Jupyter notebook, run %manim -v WARNING -qm FieldTransition. The last cell below additionally passes -o anim.gif --format=gif to save the result as a GIF.
# Define the animation in Manim
class FieldTransition(Scene):
    def construct(self):
        # Build one image per stage, all scaled identically
        init_conditions_img = ImageMobject(initial_conditions).scale(4)
        lpt_img = ImageMobject(lpt_displacements).scale(4)
        snapshots_imgs = [ImageMobject(sol).scale(4) for sol in ode_solutions]
        # Place all images at the same position so each transition morphs in place
        frames = [init_conditions_img, lpt_img, *snapshots_imgs]
        for img in frames[1:]:
            img.move_to(init_conditions_img)
        # Fade in the initial conditions, then morph through each later stage.
        # ReplacementTransform swaps the on-screen image for its target,
        # so exactly one image is in the scene at any time.
        self.play(FadeIn(init_conditions_img))
        self.wait(0.2)
        for current, target in zip(frames, frames[1:]):
            self.play(ReplacementTransform(current, target))
            self.wait(0.2)
%manim -v WARNING -qm -o anim.gif --format=gif FieldTransition
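To anticipate the length of the rendered GIF, the timing of construct() can be tallied by hand. This is a minimal sketch assuming manim's default run time of 1 second per play() call and the 30 fps frame rate of the -qm (medium quality) preset; the function names are illustrative only.

```python
def animation_duration(num_snapshots, run_time=1.0, pause=0.2):
    # One FadeIn, one transition into the LPT field, and one transition
    # into each ODE snapshot; every play() is followed by a pause
    num_animations = 1 + 1 + num_snapshots
    return num_animations * (run_time + pause)


def frame_count(duration_s, fps=30):
    # -qm renders at 30 frames per second
    return round(duration_s * fps)


duration = animation_duration(8)
print(round(duration, 3), frame_count(duration))
```

With 8 snapshots, the scene runs about 12 seconds, or roughly 360 frames.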