mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
add notebook examples
This commit is contained in:
parent
0c96a4dc10
commit
49dd18a3f8
6 changed files with 861 additions and 192 deletions
320
notebooks/01-Introduction.ipynb
Normal file
320
notebooks/01-Introduction.ipynb
Normal file
|
@ -0,0 +1,320 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "view-in-github"
|
||||
},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/Introduction.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "9Jy5BL1XiK1s",
|
||||
"metadata": {
|
||||
"id": "9Jy5BL1XiK1s"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "c5f42bbe",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "c5f42bbe",
|
||||
"outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
|
||||
"Populating the interactive namespace from numpy and matplotlib\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import jax\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"import jax_cosmo as jc\n",
|
||||
"\n",
|
||||
"from jax.experimental.ode import odeint\n",
|
||||
"\n",
|
||||
"from jaxpm.painting import cic_paint\n",
|
||||
"from jaxpm.pm import linear_field, lpt, make_ode_fn"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "281b4d3b",
|
||||
"metadata": {
|
||||
"id": "281b4d3b"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mesh_shape= [256, 256, 256]\n",
|
||||
"box_size = [256.,256.,256.]\n",
|
||||
"snapshots = jnp.linspace(0.1,1.,3)\n",
|
||||
"\n",
|
||||
"@jax.jit\n",
|
||||
"def run_simulation(omega_c, sigma8):\n",
|
||||
" # Create a small function to generate the matter power spectrum\n",
|
||||
" k = jnp.logspace(-4, 1, 128)\n",
|
||||
" pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
|
||||
" pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n",
|
||||
"\n",
|
||||
" # Create initial conditions\n",
|
||||
" initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n",
|
||||
"\n",
|
||||
" # Create particles\n",
|
||||
" particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),axis=-1).reshape([-1 , 3])\n",
|
||||
" cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
|
||||
" \n",
|
||||
" # Initial displacement\n",
|
||||
" dx, p, f = lpt(cosmo, initial_conditions, particles, 0.1)\n",
|
||||
" \n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" res = odeint(make_ode_fn(mesh_shape,particles), [particles + dx, p], snapshots, cosmo, rtol=1e-5, atol=1e-5)\n",
|
||||
" \n",
|
||||
" # Return the simulation volume at requested \n",
|
||||
"\n",
|
||||
" return initial_conditions , particles + dx , res[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "826be667",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "826be667",
|
||||
"outputId": "dc43b5c4-a004-41bf-f2c8-128c17cf4de1"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"initial_conditions , lpt_particles , ode_particles = run_simulation(0.25, 0.8)\n",
|
||||
"%timeit initial_conditions , lpt_particles , ode_particles = run_simulation(0.25, 0.8)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4e012ce8",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 323
|
||||
},
|
||||
"id": "4e012ce8",
|
||||
"outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"shape of grid_mesh: (256, 256, 256)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from visualize import plot_fields\n",
|
||||
"\n",
|
||||
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n",
|
||||
"for i , field in enumerate(ode_particles):\n",
|
||||
" fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n",
|
||||
"plot_fields(fields)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b71824ed",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mesh_shape= [256, 256, 256]\n",
|
||||
"box_size = [256.,256.,256.]\n",
|
||||
"snapshots = jnp.linspace(0.1,1.,3)\n",
|
||||
"\n",
|
||||
"@jax.jit\n",
|
||||
"def run_simulation(omega_c, sigma8):\n",
|
||||
" # Create a small function to generate the matter power spectrum\n",
|
||||
" k = jnp.logspace(-4, 1, 128)\n",
|
||||
" pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
|
||||
" pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n",
|
||||
"\n",
|
||||
" # Create initial conditions\n",
|
||||
" initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n",
|
||||
"\n",
|
||||
" # Create particles\n",
|
||||
" cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
|
||||
" \n",
|
||||
" # Initial displacement\n",
|
||||
" dx, p, f = lpt(cosmo, initial_conditions, 0.1)\n",
|
||||
" \n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" res = odeint(make_ode_fn(mesh_shape), [dx, p], snapshots, cosmo, rtol=1e-5, atol=1e-5)\n",
|
||||
" \n",
|
||||
" # Return the simulation volume at requested \n",
|
||||
"\n",
|
||||
" return initial_conditions , dx , res[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e9c9fd56",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"initial_conditions , lpt_displacements , ode_displacements = run_simulation(0.25, 0.8)\n",
|
||||
"%timeit initial_conditions , lpt_displacements , ode_displacements = run_simulation(0.25, 0.8)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "33b5e684",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from visualize import plot_fields\n",
|
||||
"\n",
|
||||
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_displacements)}\n",
|
||||
"for i , field in enumerate(ode_displacements):\n",
|
||||
" fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n",
|
||||
"plot_fields(fields)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4e050871",
|
||||
"metadata": {
|
||||
"id": "4e050871"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install diffrax\n",
|
||||
"from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "43504a1b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mesh_shape= [256, 256, 256]\n",
|
||||
"box_size = [256.,256.,256.]\n",
|
||||
"snapshots = jnp.linspace(0.1,1.,3)\n",
|
||||
"\n",
|
||||
"@jax.jit\n",
|
||||
"def run_simulation(omega_c, sigma8):\n",
|
||||
" # Create a small function to generate the matter power spectrum\n",
|
||||
" k = jnp.logspace(-4, 1, 128)\n",
|
||||
" pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
|
||||
" pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n",
|
||||
"\n",
|
||||
" # Create initial conditions\n",
|
||||
" initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n",
|
||||
"\n",
|
||||
" # Create particles\n",
|
||||
" cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
|
||||
" \n",
|
||||
" # Initial displacement\n",
|
||||
" dx, p, f = lpt(cosmo, initial_conditions, 0.1)\n",
|
||||
" \n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" ode_fn = make_ode_fn(mesh_shape)\n",
|
||||
" term = ODETerm(\n",
|
||||
" lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
|
||||
" solver = LeapfrogMidpoint()\n",
|
||||
"\n",
|
||||
" stepsize_controller = ConstantStepSize()\n",
|
||||
" res = diffeqsolve(term,\n",
|
||||
" solver,\n",
|
||||
" t0=0.1,\n",
|
||||
" t1=1.,\n",
|
||||
" dt0=0.01,\n",
|
||||
" y0=jnp.stack([dx, p], axis=0),\n",
|
||||
" args=cosmo,\n",
|
||||
" saveat=SaveAt(ts=snapshots),\n",
|
||||
" stepsize_controller=stepsize_controller)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" return initial_conditions , dx , res.ys , res.stats"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "19949ff1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n",
|
||||
"%timeit initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n",
|
||||
"print(f\"Solver Stats : {solver_stats}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "76a26e98",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from visualize import plot_fields\n",
|
||||
"\n",
|
||||
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_displacements)}\n",
|
||||
"for i , field in enumerate(ode_solutions):\n",
|
||||
" fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field[0])\n",
|
||||
"plot_fields(fields)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"include_colab_link": true,
|
||||
"name": "Introduction.ipynb",
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
242
notebooks/02-MultiGPU_PM.ipynb
Normal file
242
notebooks/02-MultiGPU_PM.ipynb
Normal file
|
@ -0,0 +1,242 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "view-in-github"
|
||||
},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/Introduction.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "9Jy5BL1XiK1s",
|
||||
"metadata": {
|
||||
"id": "9Jy5BL1XiK1s"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git\n",
|
||||
"!pip install diffrax"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "c5f42bbe",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "c5f42bbe",
|
||||
"outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
|
||||
"Populating the interactive namespace from numpy and matplotlib\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import jax\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"import jax_cosmo as jc\n",
|
||||
"\n",
|
||||
"from jax.experimental.ode import odeint\n",
|
||||
"from jaxpm.kernels import interpolate_power_spectrum\n",
|
||||
"from jaxpm.painting import cic_paint\n",
|
||||
"from jaxpm.pm import linear_field, lpt, make_ode_fn\n",
|
||||
"from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "38df34e3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"assert jax.device_count() >= 8, \"This notebook requires a TPU or GPU runtime with 8 devices\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9edd2246",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from jax.experimental.mesh_utils import create_device_mesh\n",
|
||||
"from jax.experimental.multihost_utils import process_allgather\n",
|
||||
"from jax.sharding import Mesh, NamedSharding\n",
|
||||
"from jax.sharding import PartitionSpec as P\n",
|
||||
"from functools import partial\n",
|
||||
"\n",
|
||||
"all_gather = partial(process_allgather, tiled=True)\n",
|
||||
"\n",
|
||||
"pdims = (2, 4)\n",
|
||||
"devices = create_device_mesh(pdims)\n",
|
||||
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
|
||||
"sharding = NamedSharding(mesh, P('x', 'y'))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "281b4d3b",
|
||||
"metadata": {
|
||||
"id": "281b4d3b"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mesh_shape = [1024, 1024, 1024]\n",
|
||||
"box_size = [1024., 1024., 1024.]\n",
|
||||
"halo_size = 128\n",
|
||||
"snapshots = jnp.linspace(0.1, 1., 3)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@jax.jit\n",
|
||||
"def run_simulation(omega_c, sigma8):\n",
|
||||
" # Create a small function to generate the matter power spectrum\n",
|
||||
" k = jnp.logspace(-4, 1, 128)\n",
|
||||
" pk = jc.power.linear_matter_power(\n",
|
||||
" jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
|
||||
" pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)\n",
|
||||
"\n",
|
||||
" # Create initial conditions\n",
|
||||
" initial_conditions = linear_field(mesh_shape,\n",
|
||||
" box_size,\n",
|
||||
" pk_fn,\n",
|
||||
" seed=jax.random.PRNGKey(0),\n",
|
||||
" sharding=sharding)\n",
|
||||
"\n",
|
||||
" # Create particles\n",
|
||||
" particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),\n",
|
||||
" axis=-1).reshape([-1, 3])\n",
|
||||
" cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
|
||||
"\n",
|
||||
" # Initial displacement\n",
|
||||
" dx, p, f = lpt(cosmo,\n",
|
||||
" initial_conditions,\n",
|
||||
" particles,\n",
|
||||
" 0.1,\n",
|
||||
" halo_size=halo_size,\n",
|
||||
" sharding=sharding)\n",
|
||||
"\n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)\n",
|
||||
" term = ODETerm(\n",
|
||||
" lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
|
||||
" solver = LeapfrogMidpoint()\n",
|
||||
"\n",
|
||||
" stepsize_controller = ConstantStepSize()\n",
|
||||
" res = diffeqsolve(term,\n",
|
||||
" solver,\n",
|
||||
" t0=0.1,\n",
|
||||
" t1=1.,\n",
|
||||
" dt0=0.01,\n",
|
||||
" y0=jnp.stack([dx, p], axis=0),\n",
|
||||
" args=cosmo,\n",
|
||||
" saveat=SaveAt(ts=snapshots),\n",
|
||||
" stepsize_controller=stepsize_controller)\n",
|
||||
"\n",
|
||||
" return initial_conditions, dx, res.ys, res.stats"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "826be667",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "826be667",
|
||||
"outputId": "dc43b5c4-a004-41bf-f2c8-128c17cf4de1"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n",
|
||||
"%timeit initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n",
|
||||
"print(f\"Solver Stats : {solver_stats}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "042cc55c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"initial_conditions = all_gather(initial_conditions)\n",
|
||||
"lpt_particles = all_gather(lpt_particles)\n",
|
||||
"ode_particles = [all_gather(p) for p in ode_particles]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4e012ce8",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 323
|
||||
},
|
||||
"id": "4e012ce8",
|
||||
"outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"shape of grid_mesh: (256, 256, 256)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from visualize import plot_fields\n",
|
||||
"\n",
|
||||
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n",
|
||||
"for i , field in enumerate(ode_particles):\n",
|
||||
" fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n",
|
||||
"plot_fields(fields)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"include_colab_link": true,
|
||||
"name": "Introduction.ipynb",
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
157
notebooks/03-MultiHost_PM.ipynb
Normal file
157
notebooks/03-MultiHost_PM.ipynb
Normal file
|
@ -0,0 +1,157 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "22803ddc",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "view-in-github"
|
||||
},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/Introduction.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "9Jy5BL1XiK1s",
|
||||
"metadata": {
|
||||
"id": "9Jy5BL1XiK1s"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git\n",
|
||||
"!pip install diffrax"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "c5f42bbe",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "c5f42bbe",
|
||||
"outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
|
||||
"Populating the interactive namespace from numpy and matplotlib\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!salloc --account=tkc@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:30:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7ebdfc00",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!squeue -u $USER"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c014316c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"export JOB_ID=123456"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b7eabac5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!srun --jobid=$JOB_ID -n 16 python 03-MultiHost_PM.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "472dd4bf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"data = np.load(\"multihost_pm.npz\")\n",
|
||||
"initial_conditions = data['initial_conditions']\n",
|
||||
"lpt_displacements = data['lpt_displacements']\n",
|
||||
"ode_solutions = data['ode_solutions']\n",
|
||||
"solver_stats = data['solver_stats']\n",
|
||||
"print(f\"Solver stats: {solver_stats}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4e012ce8",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 323
|
||||
},
|
||||
"id": "4e012ce8",
|
||||
"outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"shape of grid_mesh: (256, 256, 256)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from visualize import plot_fields\n",
|
||||
"\n",
|
||||
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n",
|
||||
"for i , field in enumerate(ode_particles):\n",
|
||||
" fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n",
|
||||
"plot_fields(fields)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"include_colab_link": true,
|
||||
"name": "Introduction.ipynb",
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
104
notebooks/03-MultiHost_PM.py
Normal file
104
notebooks/03-MultiHost_PM.py
Normal file
|
@ -0,0 +1,104 @@
|
|||
import os
|
||||
|
||||
os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
|
||||
import jax
|
||||
|
||||
jax.distributed.initialize()
|
||||
rank = jax.process_index()
|
||||
size = jax.process_count()
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
|
||||
from jaxpm.kernels import interpolate_power_spectrum
|
||||
from jaxpm.painting import cic_paint_dx
|
||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||
|
||||
from jax.experimental.mesh_utils import create_device_mesh
|
||||
from jax.experimental.multihost_utils import process_allgather
|
||||
from jax.sharding import Mesh, NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
|
||||
from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve
|
||||
|
||||
all_gather = partial(process_allgather, tiled=True)
|
||||
|
||||
pdims = (2, 4)
|
||||
devices = create_device_mesh(pdims)
|
||||
mesh = Mesh(devices, axis_names=('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
|
||||
mesh_shape = [2024, 1024, 1024]
|
||||
box_size = [1024., 1024., 1024.]
|
||||
halo_size = 512
|
||||
snapshots = jnp.linspace(0.1, 1., 2)
|
||||
|
||||
|
||||
@jax.jit
|
||||
def run_simulation(omega_c, sigma8):
|
||||
# Create a small function to generate the matter power spectrum
|
||||
k = jnp.logspace(-4, 1, 128)
|
||||
pk = jc.power.linear_matter_power(
|
||||
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
||||
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
|
||||
|
||||
# Create initial conditions
|
||||
initial_conditions = linear_field(mesh_shape,
|
||||
box_size,
|
||||
pk_fn,
|
||||
seed=jax.random.PRNGKey(0),
|
||||
sharding=sharding)
|
||||
|
||||
# Create particles
|
||||
particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),
|
||||
axis=-1).reshape([-1, 3])
|
||||
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
|
||||
|
||||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
particles,
|
||||
0.1,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
# Evolve the simulation forward
|
||||
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
|
||||
term = ODETerm(
|
||||
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
|
||||
solver = LeapfrogMidpoint()
|
||||
|
||||
stepsize_controller = ConstantStepSize()
|
||||
res = diffeqsolve(term,
|
||||
solver,
|
||||
t0=0.1,
|
||||
t1=1.,
|
||||
dt0=0.01,
|
||||
y0=jnp.stack([dx, p], axis=0),
|
||||
args=cosmo,
|
||||
saveat=SaveAt(ts=snapshots),
|
||||
stepsize_controller=stepsize_controller)
|
||||
|
||||
return initial_conditions, dx, res.ys, res.stats
|
||||
|
||||
|
||||
initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)
|
||||
print(f"[{rank}] Simulation completed")
|
||||
print(f"[{rank}] Solver stats: {solver_stats}")
|
||||
|
||||
# Gather the results
|
||||
initial_conditions = all_gather(initial_conditions)
|
||||
lpt_displacements = all_gather(lpt_displacements)
|
||||
ode_solutions = [all_gather(sol) for sol in ode_solutions]
|
||||
|
||||
if rank == 0:
|
||||
np.savez("multihost_pm.npz",
|
||||
initial_conditions=initial_conditions,
|
||||
lpt_displacements=lpt_displacements,
|
||||
ode_solutions=ode_solutions,
|
||||
solver_stats=solver_stats)
|
||||
|
||||
print(f"[{rank}] Simulation results saved")
|
||||
|
|
@ -1,192 +0,0 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "view-in-github"
|
||||
},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/Introduction.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "9Jy5BL1XiK1s",
|
||||
"metadata": {
|
||||
"id": "9Jy5BL1XiK1s"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git@ASKabalan/jaxdecomp_proto"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "c5f42bbe",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "c5f42bbe",
|
||||
"outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
|
||||
"Populating the interactive namespace from numpy and matplotlib\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%pylab inline\n",
|
||||
"import os\n",
|
||||
"import jax\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"import jax_cosmo as jc\n",
|
||||
"\n",
|
||||
"from jax.experimental.ode import odeint\n",
|
||||
"\n",
|
||||
"from jaxpm.painting import cic_paint\n",
|
||||
"from jaxpm.pm import linear_field, lpt, make_ode_fn"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "281b4d3b",
|
||||
"metadata": {
|
||||
"id": "281b4d3b"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mesh_shape= [256, 256, 256]\n",
|
||||
"box_size = [256.,256.,256.]\n",
|
||||
"snapshots = jnp.linspace(0.1,1.,2)\n",
|
||||
"\n",
|
||||
"@jax.jit\n",
|
||||
"def run_simulation(omega_c, sigma8):\n",
|
||||
" # Create a small function to generate the matter power spectrum\n",
|
||||
" k = jnp.logspace(-4, 1, 128)\n",
|
||||
" pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
|
||||
" pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n",
|
||||
"\n",
|
||||
" # Create initial conditions\n",
|
||||
" initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n",
|
||||
"\n",
|
||||
" # Create particles\n",
|
||||
" particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),axis=-1).reshape([-1,3])\n",
|
||||
"\n",
|
||||
" cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
|
||||
" \n",
|
||||
" # Initial displacement\n",
|
||||
" dx, p, f = lpt(cosmo, initial_conditions, particles, 0.1)\n",
|
||||
" \n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" res = odeint(make_ode_fn(mesh_shape), [particles+dx, p], snapshots, cosmo, rtol=1e-5, atol=1e-5)\n",
|
||||
" \n",
|
||||
" # Return the simulation volume at requested \n",
|
||||
" return res[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "826be667",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "826be667",
|
||||
"outputId": "dc43b5c4-a004-41bf-f2c8-128c17cf4de1"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"particles are Traced<ShapedArray(int32[16777216,3])>with<DynamicJaxprTrace(level=1/0)>\n",
|
||||
"pm_forces particles are Traced<ShapedArray(int32[16777216,3])>with<DynamicJaxprTrace(level=1/0)>\n",
|
||||
"shape of displacement: (256, 256, 256)\n",
|
||||
"pm_forces particles are Traced<ShapedArray(float32[16777216,3])>with<DynamicJaxprTrace(level=2/0)>\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"res = run_simulation(0.25, 0.8)\n",
|
||||
"#%timeit res = run_simulation(0.25, 0.8)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4e012ce8",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 323
|
||||
},
|
||||
"id": "4e012ce8",
|
||||
"outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"shape of grid_mesh: (256, 256, 256)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"figure(figsize=[10,5])\n",
|
||||
"subplot(121)\n",
|
||||
"imshow(cic_paint(jnp.zeros(mesh_shape), res[0]).sum(axis=0),cmap='magma')\n",
|
||||
"subplot(122)\n",
|
||||
"imshow(cic_paint(jnp.zeros(mesh_shape), res[1]).sum(axis=0),cmap='magma')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4e050871",
|
||||
"metadata": {
|
||||
"id": "4e050871"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"include_colab_link": true,
|
||||
"name": "Introduction.ipynb",
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
38
notebooks/visualize.py
Normal file
38
notebooks/visualize.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def plot_fields(fields_dict, sum_over=None):
|
||||
"""
|
||||
Plots sum projections of 3D fields along different axes,
|
||||
slicing only the first `sum_over` elements along each axis.
|
||||
|
||||
Args:
|
||||
- fields: list of 3D arrays representing fields to plot
|
||||
- names: list of names for each field, used in titles
|
||||
- sum_over: number of slices to sum along each axis (default: fields[0].shape[0] // 8)
|
||||
"""
|
||||
sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
|
||||
nb_rows = len(fields_dict)
|
||||
nb_cols = 3
|
||||
fig, axes = plt.subplots(nb_rows, nb_cols, figsize=(15, 5 * nb_rows))
|
||||
|
||||
def plot_subplots(proj_axis, field, row, title):
|
||||
slicing = [slice(None)] * field.ndim
|
||||
slicing[proj_axis] = slice(None, sum_over)
|
||||
slicing = tuple(slicing)
|
||||
|
||||
# Sum projection over the specified axis and plot
|
||||
axes[row, proj_axis].imshow(field[slicing].sum(axis=proj_axis) + 1,
|
||||
cmap='magma', extent=[0, field.shape[proj_axis], 0, field.shape[proj_axis]])
|
||||
axes[row, proj_axis].set_xlabel('Mpc/h')
|
||||
axes[row, proj_axis].set_ylabel('Mpc/h')
|
||||
axes[row, proj_axis].set_title(title)
|
||||
|
||||
# Plot each field across the three axes
|
||||
for i, (name, field) in enumerate(fields_dict.items()):
|
||||
for proj_axis in range(3):
|
||||
plot_subplots(proj_axis, field, i, f"{name} projection {proj_axis}")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
Loading…
Add table
Reference in a new issue