Mirror of https://github.com/DifferentiableUniverseInitiative/JaxPM.git
Synced 2025-04-08 04:40:53 +00:00
4.1 KiB
In [8]:
!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git
!pip install diffrax
In [1]:
!salloc --account=tkc@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:30:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=2
In [ ]:
!squeue -u $USER
In [ ]:
export JOB_ID=123456
In [ ]:
!srun --jobid=$JOB_ID -n 16 python 03-MultiHost_PM.py
In [ ]:
# Load the simulation outputs written by the multi-host run.
import numpy as np

# np.load on an .npz returns an NpzFile that keeps the archive open. Indexing
# reads each array fully into memory, so the file can be closed as soon as the
# arrays are extracted.
# NOTE(review): if any entry (e.g. solver_stats) was saved from a Python dict,
# it is an object array and needs allow_pickle=True. Only enable that for
# trusted files.
with np.load("multihost_pm.npz") as data:
    initial_conditions = data["initial_conditions"]
    lpt_displacements = data["lpt_displacements"]
    ode_solutions = data["ode_solutions"]
    solver_stats = data["solver_stats"]

print(f"Solver stats: {solver_stats}")
In [ ]:
# Paint the saved displacement fields onto the mesh and plot them next to the
# initial conditions. Every name used here is imported in this cell or loaded
# from multihost_pm.npz in the previous cell, so the cell works on a fresh kernel.
from jaxpm.painting import cic_paint_dx
from visualize import plot_fields

# NOTE(review): assumes `lpt_displacements` and the ODE snapshots hold
# grid-relative particle displacements (hence `cic_paint_dx`, not `cic_paint`
# with absolute positions). Confirm against what 03-MultiHost_PM.py saves.


def snapshot_displacement(snapshot):
    """Return the displacement part of one saved ODE snapshot.

    A snapshot shaped exactly like `lpt_displacements` is returned unchanged.
    A snapshot with one extra leading axis is assumed to be a stacked state
    (presumably displacement then momentum), and its first entry is returned.
    Any other shape raises ValueError instead of painting garbage.
    """
    if snapshot.shape == lpt_displacements.shape:
        return snapshot
    if snapshot.shape[1:] == lpt_displacements.shape:
        return snapshot[0]
    raise ValueError(
        f"Unexpected ODE snapshot shape {snapshot.shape} "
        f"(LPT displacements have shape {lpt_displacements.shape})"
    )


fields = {
    "Initial Conditions": initial_conditions,
    "LPT Field": cic_paint_dx(lpt_displacements),
}
for i, snapshot in enumerate(ode_solutions):
    fields[f"field_{i}"] = cic_paint_dx(snapshot_displacement(snapshot))
plot_fields(fields)