mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-17 00:20:55 +00:00
242 lines
7.9 KiB
Text
242 lines
7.9 KiB
Text
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "view-in-github"
|
|
},
|
|
"source": [
|
|
"<a href=\"https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/Introduction.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "9Jy5BL1XiK1s",
|
|
"metadata": {
|
|
"id": "9Jy5BL1XiK1s"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git\n",
|
|
"!pip install diffrax"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "c5f42bbe",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "c5f42bbe",
|
|
"outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
|
|
"Populating the interactive namespace from numpy and matplotlib\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import os\n",
|
|
"import jax\n",
|
|
"import jax.numpy as jnp\n",
|
|
"import jax_cosmo as jc\n",
|
|
"\n",
|
|
"from jax.experimental.ode import odeint\n",
|
|
"from jaxpm.kernels import interpolate_power_spectrum\n",
|
|
"from jaxpm.painting import cic_paint\n",
|
|
"from jaxpm.pm import linear_field, lpt, make_ode_fn\n",
|
|
"from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "38df34e3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"assert jax.device_count() >= 8, \"This notebook requires a TPU or GPU runtime with 8 devices\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "9edd2246",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from jax.experimental.mesh_utils import create_device_mesh\n",
|
|
"from jax.experimental.multihost_utils import process_allgather\n",
|
|
"from jax.sharding import Mesh, NamedSharding\n",
|
|
"from jax.sharding import PartitionSpec as P\n",
|
|
"from functools import partial\n",
|
|
"\n",
|
|
"all_gather = partial(process_allgather, tiled=True)\n",
|
|
"\n",
|
|
"pdims = (2, 4)\n",
|
|
"devices = create_device_mesh(pdims)\n",
|
|
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
|
|
"sharding = NamedSharding(mesh, P('x', 'y'))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "281b4d3b",
|
|
"metadata": {
|
|
"id": "281b4d3b"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"mesh_shape = [1024, 1024, 1024]\n",
|
|
"box_size = [1024., 1024., 1024.]\n",
|
|
"halo_size = 128\n",
|
|
"snapshots = jnp.linspace(0.1, 1., 3)\n",
|
|
"\n",
|
|
"\n",
|
|
"@jax.jit\n",
|
|
"def run_simulation(omega_c, sigma8):\n",
|
|
" # Create a small function to generate the matter power spectrum\n",
|
|
" k = jnp.logspace(-4, 1, 128)\n",
|
|
" pk = jc.power.linear_matter_power(\n",
|
|
" jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
|
|
" pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)\n",
|
|
"\n",
|
|
" # Create initial conditions\n",
|
|
" initial_conditions = linear_field(mesh_shape,\n",
|
|
" box_size,\n",
|
|
" pk_fn,\n",
|
|
" seed=jax.random.PRNGKey(0),\n",
|
|
" sharding=sharding)\n",
|
|
"\n",
|
|
" # Create particles\n",
|
|
" particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),\n",
|
|
" axis=-1).reshape([-1, 3])\n",
|
|
" cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
|
|
"\n",
|
|
" # Initial displacement\n",
|
|
" dx, p, f = lpt(cosmo,\n",
|
|
" initial_conditions,\n",
|
|
" particles,\n",
|
|
" 0.1,\n",
|
|
" halo_size=halo_size,\n",
|
|
" sharding=sharding)\n",
|
|
"\n",
|
|
" # Evolve the simulation forward\n",
|
|
" ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)\n",
|
|
" term = ODETerm(\n",
|
|
" lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
|
|
" solver = LeapfrogMidpoint()\n",
|
|
"\n",
|
|
" stepsize_controller = ConstantStepSize()\n",
|
|
" res = diffeqsolve(term,\n",
|
|
" solver,\n",
|
|
" t0=0.1,\n",
|
|
" t1=1.,\n",
|
|
" dt0=0.01,\n",
|
|
" y0=jnp.stack([dx, p], axis=0),\n",
|
|
" args=cosmo,\n",
|
|
" saveat=SaveAt(ts=snapshots),\n",
|
|
" stepsize_controller=stepsize_controller)\n",
|
|
"\n",
|
|
" return initial_conditions, dx, res.ys, res.stats"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "826be667",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "826be667",
|
|
"outputId": "dc43b5c4-a004-41bf-f2c8-128c17cf4de1"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n",
|
|
"%timeit jax.block_until_ready(run_simulation(0.25, 0.8))\n",
|
|
"print(f\"Solver Stats : {solver_stats}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "042cc55c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"initial_conditions = all_gather(initial_conditions)\n",
|
|
"lpt_particles = all_gather(lpt_displacements)\n",
|
|
"ode_particles = [all_gather(p) for p in ode_solutions]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "4e012ce8",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 323
|
|
},
|
|
"id": "4e012ce8",
|
|
"outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"shape of grid_mesh: (256, 256, 256)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from visualize import plot_fields\n",
|
|
"\n",
|
|
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n",
|
|
"for i , field in enumerate(ode_particles):\n",
|
|
" fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n",
|
|
"plot_fields(fields)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"accelerator": "GPU",
|
|
"colab": {
|
|
"include_colab_link": true,
|
|
"name": "Introduction.ipynb",
|
|
"provenance": []
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.4"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|