add notebook examples

This commit is contained in:
Wassim KABALAN 2024-10-26 22:47:26 +02:00
parent 0c96a4dc10
commit 49dd18a3f8
6 changed files with 861 additions and 192 deletions

View file

@ -0,0 +1,320 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/Introduction.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9Jy5BL1XiK1s",
"metadata": {
"id": "9Jy5BL1XiK1s"
},
"outputs": [],
"source": [
"!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c5f42bbe",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "c5f42bbe",
"outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
"Populating the interactive namespace from numpy and matplotlib\n"
]
}
],
"source": [
"import os\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jax_cosmo as jc\n",
"\n",
"from jax.experimental.ode import odeint\n",
"\n",
"from jaxpm.painting import cic_paint\n",
"from jaxpm.pm import linear_field, lpt, make_ode_fn"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "281b4d3b",
"metadata": {
"id": "281b4d3b"
},
"outputs": [],
"source": [
"mesh_shape= [256, 256, 256]\n",
"box_size = [256.,256.,256.]\n",
"snapshots = jnp.linspace(0.1,1.,3)\n",
"\n",
"@jax.jit\n",
"def run_simulation(omega_c, sigma8):\n",
" # Create a small function to generate the matter power spectrum\n",
" k = jnp.logspace(-4, 1, 128)\n",
" pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
" pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n",
"\n",
" # Create initial conditions\n",
" initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n",
"\n",
" # Create particles\n",
" particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),axis=-1).reshape([-1 , 3])\n",
" cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
" \n",
" # Initial displacement\n",
" dx, p, f = lpt(cosmo, initial_conditions, particles, 0.1)\n",
" \n",
" # Evolve the simulation forward\n",
" res = odeint(make_ode_fn(mesh_shape,particles), [particles + dx, p], snapshots, cosmo, rtol=1e-5, atol=1e-5)\n",
" \n",
" # Return the simulation volume at requested \n",
"\n",
" return initial_conditions , particles + dx , res[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "826be667",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "826be667",
"outputId": "dc43b5c4-a004-41bf-f2c8-128c17cf4de1"
},
"outputs": [],
"source": [
"initial_conditions , lpt_particles , ode_particles = run_simulation(0.25, 0.8)\n",
"%timeit initial_conditions , lpt_particles , ode_particles = run_simulation(0.25, 0.8)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e012ce8",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 323
},
"id": "4e012ce8",
"outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"shape of grid_mesh: (256, 256, 256)\n"
]
}
],
"source": [
"from visualize import plot_fields\n",
"\n",
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n",
"for i , field in enumerate(ode_particles):\n",
" fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n",
"plot_fields(fields)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b71824ed",
"metadata": {},
"outputs": [],
"source": [
"mesh_shape= [256, 256, 256]\n",
"box_size = [256.,256.,256.]\n",
"snapshots = jnp.linspace(0.1,1.,3)\n",
"\n",
"@jax.jit\n",
"def run_simulation(omega_c, sigma8):\n",
" # Create a small function to generate the matter power spectrum\n",
" k = jnp.logspace(-4, 1, 128)\n",
" pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
" pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n",
"\n",
" # Create initial conditions\n",
" initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n",
"\n",
" # Create particles\n",
" cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
" \n",
" # Initial displacement\n",
" dx, p, f = lpt(cosmo, initial_conditions, 0.1)\n",
" \n",
" # Evolve the simulation forward\n",
" res = odeint(make_ode_fn(mesh_shape), [dx, p], snapshots, cosmo, rtol=1e-5, atol=1e-5)\n",
" \n",
" # Return the simulation volume at requested \n",
"\n",
" return initial_conditions , dx , res[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e9c9fd56",
"metadata": {},
"outputs": [],
"source": [
"initial_conditions , lpt_displacements , ode_displacements = run_simulation(0.25, 0.8)\n",
"%timeit initial_conditions , lpt_displacements , ode_displacements = run_simulation(0.25, 0.8)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "33b5e684",
"metadata": {},
"outputs": [],
"source": [
"from visualize import plot_fields\n",
"\n",
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_displacements)}\n",
"for i , field in enumerate(ode_displacements):\n",
" fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n",
"plot_fields(fields)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e050871",
"metadata": {
"id": "4e050871"
},
"outputs": [],
"source": [
"!pip install diffrax\n",
"from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "43504a1b",
"metadata": {},
"outputs": [],
"source": [
"mesh_shape= [256, 256, 256]\n",
"box_size = [256.,256.,256.]\n",
"snapshots = jnp.linspace(0.1,1.,3)\n",
"\n",
"@jax.jit\n",
"def run_simulation(omega_c, sigma8):\n",
" # Create a small function to generate the matter power spectrum\n",
" k = jnp.logspace(-4, 1, 128)\n",
" pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
" pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n",
"\n",
" # Create initial conditions\n",
" initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n",
"\n",
" # Create particles\n",
" cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
" \n",
" # Initial displacement\n",
" dx, p, f = lpt(cosmo, initial_conditions, 0.1)\n",
" \n",
" # Evolve the simulation forward\n",
" ode_fn = make_ode_fn(mesh_shape)\n",
" term = ODETerm(\n",
" lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
" solver = LeapfrogMidpoint()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
" res = diffeqsolve(term,\n",
" solver,\n",
" t0=0.1,\n",
" t1=1.,\n",
" dt0=0.01,\n",
" y0=jnp.stack([dx, p], axis=0),\n",
" args=cosmo,\n",
" saveat=SaveAt(ts=snapshots),\n",
" stepsize_controller=stepsize_controller)\n",
"\n",
"\n",
" return initial_conditions , dx , res.ys , res.stats"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19949ff1",
"metadata": {},
"outputs": [],
"source": [
"initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n",
"%timeit initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n",
"print(f\"Solver Stats : {solver_stats}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76a26e98",
"metadata": {},
"outputs": [],
"source": [
"from visualize import plot_fields\n",
"\n",
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_displacements)}\n",
"for i , field in enumerate(ode_solutions):\n",
" fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field[0])\n",
"plot_fields(fields)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"include_colab_link": true,
"name": "Introduction.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View file

@ -0,0 +1,242 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/Introduction.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9Jy5BL1XiK1s",
"metadata": {
"id": "9Jy5BL1XiK1s"
},
"outputs": [],
"source": [
"!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git\n",
"!pip install diffrax"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c5f42bbe",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "c5f42bbe",
"outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
"Populating the interactive namespace from numpy and matplotlib\n"
]
}
],
"source": [
"import os\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jax_cosmo as jc\n",
"\n",
"from jax.experimental.ode import odeint\n",
"from jaxpm.kernels import interpolate_power_spectrum\n",
"from jaxpm.painting import cic_paint\n",
"from jaxpm.pm import linear_field, lpt, make_ode_fn\n",
"from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "38df34e3",
"metadata": {},
"outputs": [],
"source": [
"assert jax.device_count() >= 8, \"This notebook requires a TPU or GPU runtime with 8 devices\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9edd2246",
"metadata": {},
"outputs": [],
"source": [
"from jax.experimental.mesh_utils import create_device_mesh\n",
"from jax.experimental.multihost_utils import process_allgather\n",
"from jax.sharding import Mesh, NamedSharding\n",
"from jax.sharding import PartitionSpec as P\n",
"from functools import partial\n",
"\n",
"all_gather = partial(process_allgather, tiled=True)\n",
"\n",
"pdims = (2, 4)\n",
"devices = create_device_mesh(pdims)\n",
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
"sharding = NamedSharding(mesh, P('x', 'y'))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "281b4d3b",
"metadata": {
"id": "281b4d3b"
},
"outputs": [],
"source": [
"mesh_shape = [1024, 1024, 1024]\n",
"box_size = [1024., 1024., 1024.]\n",
"halo_size = 128\n",
"snapshots = jnp.linspace(0.1, 1., 3)\n",
"\n",
"\n",
"@jax.jit\n",
"def run_simulation(omega_c, sigma8):\n",
" # Create a small function to generate the matter power spectrum\n",
" k = jnp.logspace(-4, 1, 128)\n",
" pk = jc.power.linear_matter_power(\n",
" jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
" pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)\n",
"\n",
" # Create initial conditions\n",
" initial_conditions = linear_field(mesh_shape,\n",
" box_size,\n",
" pk_fn,\n",
" seed=jax.random.PRNGKey(0),\n",
" sharding=sharding)\n",
"\n",
" # Create particles\n",
" particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),\n",
" axis=-1).reshape([-1, 3])\n",
" cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
"\n",
" # Initial displacement\n",
" dx, p, f = lpt(cosmo,\n",
" initial_conditions,\n",
" particles,\n",
" 0.1,\n",
" halo_size=halo_size,\n",
" sharding=sharding)\n",
"\n",
" # Evolve the simulation forward\n",
" ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)\n",
" term = ODETerm(\n",
" lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
" solver = LeapfrogMidpoint()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
" res = diffeqsolve(term,\n",
" solver,\n",
" t0=0.1,\n",
" t1=1.,\n",
" dt0=0.01,\n",
" y0=jnp.stack([dx, p], axis=0),\n",
" args=cosmo,\n",
" saveat=SaveAt(ts=snapshots),\n",
" stepsize_controller=stepsize_controller)\n",
"\n",
" return initial_conditions, dx, res.ys, res.stats"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "826be667",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "826be667",
"outputId": "dc43b5c4-a004-41bf-f2c8-128c17cf4de1"
},
"outputs": [],
"source": [
"initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n",
"%timeit initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n",
"print(f\"Solver Stats : {solver_stats}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "042cc55c",
"metadata": {},
"outputs": [],
"source": [
"initial_conditions = all_gather(initial_conditions)\n",
"lpt_particles = all_gather(lpt_particles)\n",
"ode_particles = [all_gather(p) for p in ode_particles]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e012ce8",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 323
},
"id": "4e012ce8",
"outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"shape of grid_mesh: (256, 256, 256)\n"
]
}
],
"source": [
"from visualize import plot_fields\n",
"\n",
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n",
"for i , field in enumerate(ode_particles):\n",
" fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n",
"plot_fields(fields)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"include_colab_link": true,
"name": "Introduction.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View file

@ -0,0 +1,157 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "22803ddc",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/Introduction.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9Jy5BL1XiK1s",
"metadata": {
"id": "9Jy5BL1XiK1s"
},
"outputs": [],
"source": [
"!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git\n",
"!pip install diffrax"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c5f42bbe",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "c5f42bbe",
"outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
"Populating the interactive namespace from numpy and matplotlib\n"
]
}
],
"source": [
"!salloc --account=tkc@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:30:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=2"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ebdfc00",
"metadata": {},
"outputs": [],
"source": [
"!squeue -u $USER"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c014316c",
"metadata": {},
"outputs": [],
"source": [
"export JOB_ID=123456"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7eabac5",
"metadata": {},
"outputs": [],
"source": [
"!srun --jobid=$JOB_ID -n 16 python 03-MultiHost_PM.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "472dd4bf",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"data = np.load(\"multihost_pm.npz\")\n",
"initial_conditions = data['initial_conditions']\n",
"lpt_displacements = data['lpt_displacements']\n",
"ode_solutions = data['ode_solutions']\n",
"solver_stats = data['solver_stats']\n",
"print(f\"Solver stats: {solver_stats}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e012ce8",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 323
},
"id": "4e012ce8",
"outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"shape of grid_mesh: (256, 256, 256)\n"
]
}
],
"source": [
"from visualize import plot_fields\n",
"\n",
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n",
"for i , field in enumerate(ode_particles):\n",
" fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n",
"plot_fields(fields)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"include_colab_link": true,
"name": "Introduction.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View file

@ -0,0 +1,104 @@
import os
os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
import jax
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()
import jax.numpy as jnp
import jax_cosmo as jc
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from functools import partial
import numpy as np
from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve
all_gather = partial(process_allgather, tiled=True)
pdims = (2, 4)
devices = create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
mesh_shape = [2024, 1024, 1024]
box_size = [1024., 1024., 1024.]
halo_size = 512
snapshots = jnp.linspace(0.1, 1., 2)
@jax.jit
def run_simulation(omega_c, sigma8):
# Create a small function to generate the matter power spectrum
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
# Create initial conditions
initial_conditions = linear_field(mesh_shape,
box_size,
pk_fn,
seed=jax.random.PRNGKey(0),
sharding=sharding)
# Create particles
particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),
axis=-1).reshape([-1, 3])
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
# Initial displacement
dx, p, _ = lpt(cosmo,
initial_conditions,
particles,
0.1,
halo_size=halo_size,
sharding=sharding)
# Evolve the simulation forward
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
term = ODETerm(
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
solver = LeapfrogMidpoint()
stepsize_controller = ConstantStepSize()
res = diffeqsolve(term,
solver,
t0=0.1,
t1=1.,
dt0=0.01,
y0=jnp.stack([dx, p], axis=0),
args=cosmo,
saveat=SaveAt(ts=snapshots),
stepsize_controller=stepsize_controller)
return initial_conditions, dx, res.ys, res.stats
initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)
print(f"[{rank}] Simulation completed")
print(f"[{rank}] Solver stats: {solver_stats}")
# Gather the results
initial_conditions = all_gather(initial_conditions)
lpt_displacements = all_gather(lpt_displacements)
ode_solutions = [all_gather(sol) for sol in ode_solutions]
if rank == 0:
np.savez("multihost_pm.npz",
initial_conditions=initial_conditions,
lpt_displacements=lpt_displacements,
ode_solutions=ode_solutions,
solver_stats=solver_stats)
print(f"[{rank}] Simulation results saved")

View file

@ -1,192 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/Introduction.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9Jy5BL1XiK1s",
"metadata": {
"id": "9Jy5BL1XiK1s"
},
"outputs": [],
"source": [
"!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git@ASKabalan/jaxdecomp_proto"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c5f42bbe",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "c5f42bbe",
"outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
"Populating the interactive namespace from numpy and matplotlib\n"
]
}
],
"source": [
"%pylab inline\n",
"import os\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jax_cosmo as jc\n",
"\n",
"from jax.experimental.ode import odeint\n",
"\n",
"from jaxpm.painting import cic_paint\n",
"from jaxpm.pm import linear_field, lpt, make_ode_fn"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "281b4d3b",
"metadata": {
"id": "281b4d3b"
},
"outputs": [],
"source": [
"mesh_shape= [256, 256, 256]\n",
"box_size = [256.,256.,256.]\n",
"snapshots = jnp.linspace(0.1,1.,2)\n",
"\n",
"@jax.jit\n",
"def run_simulation(omega_c, sigma8):\n",
" # Create a small function to generate the matter power spectrum\n",
" k = jnp.logspace(-4, 1, 128)\n",
" pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
" pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n",
"\n",
" # Create initial conditions\n",
" initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n",
"\n",
" # Create particles\n",
" particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),axis=-1).reshape([-1,3])\n",
"\n",
" cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
" \n",
" # Initial displacement\n",
" dx, p, f = lpt(cosmo, initial_conditions, particles, 0.1)\n",
" \n",
" # Evolve the simulation forward\n",
" res = odeint(make_ode_fn(mesh_shape), [particles+dx, p], snapshots, cosmo, rtol=1e-5, atol=1e-5)\n",
" \n",
" # Return the simulation volume at requested \n",
" return res[0]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "826be667",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "826be667",
"outputId": "dc43b5c4-a004-41bf-f2c8-128c17cf4de1"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"particles are Traced<ShapedArray(int32[16777216,3])>with<DynamicJaxprTrace(level=1/0)>\n",
"pm_forces particles are Traced<ShapedArray(int32[16777216,3])>with<DynamicJaxprTrace(level=1/0)>\n",
"shape of displacement: (256, 256, 256)\n",
"pm_forces particles are Traced<ShapedArray(float32[16777216,3])>with<DynamicJaxprTrace(level=2/0)>\n"
]
}
],
"source": [
"res = run_simulation(0.25, 0.8)\n",
"#%timeit res = run_simulation(0.25, 0.8)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e012ce8",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 323
},
"id": "4e012ce8",
"outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"shape of grid_mesh: (256, 256, 256)\n"
]
}
],
"source": [
"figure(figsize=[10,5])\n",
"subplot(121)\n",
"imshow(cic_paint(jnp.zeros(mesh_shape), res[0]).sum(axis=0),cmap='magma')\n",
"subplot(122)\n",
"imshow(cic_paint(jnp.zeros(mesh_shape), res[1]).sum(axis=0),cmap='magma')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e050871",
"metadata": {
"id": "4e050871"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"include_colab_link": true,
"name": "Introduction.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

38
notebooks/visualize.py Normal file
View file

@ -0,0 +1,38 @@
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
def plot_fields(fields_dict, sum_over=None):
"""
Plots sum projections of 3D fields along different axes,
slicing only the first `sum_over` elements along each axis.
Args:
- fields: list of 3D arrays representing fields to plot
- names: list of names for each field, used in titles
- sum_over: number of slices to sum along each axis (default: fields[0].shape[0] // 8)
"""
sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
nb_rows = len(fields_dict)
nb_cols = 3
fig, axes = plt.subplots(nb_rows, nb_cols, figsize=(15, 5 * nb_rows))
def plot_subplots(proj_axis, field, row, title):
slicing = [slice(None)] * field.ndim
slicing[proj_axis] = slice(None, sum_over)
slicing = tuple(slicing)
# Sum projection over the specified axis and plot
axes[row, proj_axis].imshow(field[slicing].sum(axis=proj_axis) + 1,
cmap='magma', extent=[0, field.shape[proj_axis], 0, field.shape[proj_axis]])
axes[row, proj_axis].set_xlabel('Mpc/h')
axes[row, proj_axis].set_ylabel('Mpc/h')
axes[row, proj_axis].set_title(title)
# Plot each field across the three axes
for i, (name, field) in enumerate(fields_dict.items()):
for proj_axis in range(3):
plot_subplots(proj_axis, field, i, f"{name} projection {proj_axis}")
plt.tight_layout()
plt.show()