JaxPM/notebooks/03-MultiHost_PM.ipynb
2024-10-26 22:49:17 +02:00

4.1 KiB

Open In Colab

In [8]:
!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git
!pip install diffrax
In [1]:
!salloc --account=tkc@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:30:00  --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=2
%pylab is deprecated, use %matplotlib inline and import the required libraries.
Populating the interactive namespace from numpy and matplotlib
In [ ]:
!squeue -u $USER
In [ ]:
export JOB_ID=123456
In [ ]:
!srun --jobid=$JOB_ID -n 16 python 03-MultiHost_PM.py
In [ ]:
import numpy as np

# Load the arrays saved in multihost_pm.npz (presumably written by the srun'd
# script). Using the NpzFile as a context manager closes the underlying file
# handle. Indexing an NpzFile reads that member into memory, so the arrays
# remain valid after the `with` block exits.
with np.load("multihost_pm.npz") as data:
    initial_conditions = data['initial_conditions']
    lpt_displacements = data['lpt_displacements']
    ode_solutions = data['ode_solutions']
    # NOTE(review): if solver_stats was saved as a dict / object array, reading
    # it requires np.load(..., allow_pickle=True) -- only enable that for
    # trusted files.
    solver_stats = data['solver_stats']
print(f"Solver stats: {solver_stats}")
In [ ]:
from jaxpm.painting import cic_paint_dx
from visualize import plot_fields

# Build the fields to plot from the arrays loaded in the previous cell.
# The saved LPT output is a displacement field, so paint it with `cic_paint_dx`,
# which paints displacements relative to the uniform particle grid.
# NOTE(review): assumes each entry of `ode_solutions` is a displacement field with
# the same layout as `lpt_displacements` -- confirm against 03-MultiHost_PM.py.
fields = {
    "Initial Conditions": initial_conditions,
    "LPT Field": cic_paint_dx(lpt_displacements),
}
for i, solution in enumerate(ode_solutions):
    fields[f"field_{i}"] = cic_paint_dx(solution)
plot_fields(fields)
shape of grid_mesh: (256, 256, 256)