mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
update notebooks
This commit is contained in:
parent
21373b89ee
commit
36ef18e3d0
5 changed files with 88 additions and 80 deletions
File diff suppressed because one or more lines are too long
|
@ -21,17 +21,16 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import os\n",
|
|
||||||
"import jax\n",
|
"import jax\n",
|
||||||
"import jax.numpy as jnp\n",
|
"import jax.numpy as jnp\n",
|
||||||
"import jax_cosmo as jc\n",
|
"import jax_cosmo as jc\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from jaxpm.painting import cic_paint , cic_paint_dx\n",
|
"from jaxpm.painting import cic_paint , cic_paint_dx\n",
|
||||||
"from jaxpm.pm import linear_field, lpt, make_ode_fn\n",
|
"from jaxpm.pm import linear_field, lpt, make_diffrax_ode\n",
|
||||||
"from jaxpm.distributed import uniform_particles\n",
|
"from jaxpm.distributed import uniform_particles\n",
|
||||||
"from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve"
|
"from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve"
|
||||||
]
|
]
|
||||||
|
@ -84,13 +83,12 @@
|
||||||
" dx, p, f = lpt(cosmo, initial_conditions, a=0.1,order=1)\n",
|
" dx, p, f = lpt(cosmo, initial_conditions, a=0.1,order=1)\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Evolve the simulation forward\n",
|
" # Evolve the simulation forward\n",
|
||||||
" ode_fn = make_ode_fn(mesh_shape)\n",
|
" ode_fn = ODETerm(\n",
|
||||||
" term = ODETerm(\n",
|
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
||||||
" lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
|
|
||||||
" solver = LeapfrogMidpoint()\n",
|
" solver = LeapfrogMidpoint()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" stepsize_controller = ConstantStepSize()\n",
|
" stepsize_controller = ConstantStepSize()\n",
|
||||||
" res = diffeqsolve(term,\n",
|
" res = diffeqsolve(ode_fn,\n",
|
||||||
" solver,\n",
|
" solver,\n",
|
||||||
" t0=0.1,\n",
|
" t0=0.1,\n",
|
||||||
" t1=1.,\n",
|
" t1=1.,\n",
|
||||||
|
@ -258,13 +256,12 @@
|
||||||
" dx, p, f = lpt(cosmo, initial_conditions,particles=particles,a=0.1,order=2)\n",
|
" dx, p, f = lpt(cosmo, initial_conditions,particles=particles,a=0.1,order=2)\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Evolve the simulation forward\n",
|
" # Evolve the simulation forward\n",
|
||||||
" ode_fn = make_ode_fn(mesh_shape,particles=particles)\n",
|
" ode_fn = ODETerm(\n",
|
||||||
" term = ODETerm(\n",
|
" make_diffrax_ode(cosmo, mesh_shape))\n",
|
||||||
" lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
|
|
||||||
" solver = LeapfrogMidpoint()\n",
|
" solver = LeapfrogMidpoint()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" stepsize_controller = ConstantStepSize()\n",
|
" stepsize_controller = ConstantStepSize()\n",
|
||||||
" res = diffeqsolve(term,\n",
|
" res = diffeqsolve(ode_fn,\n",
|
||||||
" solver,\n",
|
" solver,\n",
|
||||||
" t0=0.1,\n",
|
" t0=0.1,\n",
|
||||||
" t1=1.,\n",
|
" t1=1.,\n",
|
||||||
|
@ -336,7 +333,6 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from math import prod\n",
|
|
||||||
"from jaxpm.plotting import plot_fields_single_projection\n",
|
"from jaxpm.plotting import plot_fields_single_projection\n",
|
||||||
"\n",
|
"\n",
|
||||||
"center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4 )\n",
|
"center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4 )\n",
|
||||||
|
@ -379,7 +375,6 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from math import prod\n",
|
|
||||||
"from jaxpm.plotting import plot_fields_single_projection\n",
|
"from jaxpm.plotting import plot_fields_single_projection\n",
|
||||||
"\n",
|
"\n",
|
||||||
"center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4 )\n",
|
"center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4 )\n",
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"id": "7fb27b941602401d91542211134fc71a",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab_type": "text",
|
"colab_type": "text",
|
||||||
"id": "view-in-github"
|
"id": "view-in-github"
|
||||||
|
@ -45,7 +46,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": null,
|
||||||
"id": "c5f42bbe",
|
"id": "c5f42bbe",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
|
@ -63,11 +64,9 @@
|
||||||
"import jax_cosmo as jc\n",
|
"import jax_cosmo as jc\n",
|
||||||
"from jax.debug import visualize_array_sharding\n",
|
"from jax.debug import visualize_array_sharding\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from jax.experimental.ode import odeint\n",
|
|
||||||
"from jaxpm.kernels import interpolate_power_spectrum\n",
|
"from jaxpm.kernels import interpolate_power_spectrum\n",
|
||||||
"from jaxpm.painting import cic_paint , cic_paint_dx\n",
|
"from jaxpm.painting import cic_paint_dx\n",
|
||||||
"from jaxpm.pm import linear_field, lpt, make_ode_fn\n",
|
"from jaxpm.pm import linear_field, lpt, make_diffrax_ode\n",
|
||||||
"from jaxpm.distributed import uniform_particles\n",
|
|
||||||
"from functools import partial\n",
|
"from functools import partial\n",
|
||||||
"from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve"
|
"from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve"
|
||||||
]
|
]
|
||||||
|
@ -109,7 +108,6 @@
|
||||||
"from jax.experimental.multihost_utils import process_allgather\n",
|
"from jax.experimental.multihost_utils import process_allgather\n",
|
||||||
"from jax.sharding import Mesh, NamedSharding\n",
|
"from jax.sharding import Mesh, NamedSharding\n",
|
||||||
"from jax.sharding import PartitionSpec as P\n",
|
"from jax.sharding import PartitionSpec as P\n",
|
||||||
"from functools import partial\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"all_gather = partial(process_allgather, tiled=False)\n",
|
"all_gather = partial(process_allgather, tiled=False)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -131,7 +129,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": null,
|
||||||
"id": "281b4d3b",
|
"id": "281b4d3b",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "281b4d3b"
|
"id": "281b4d3b"
|
||||||
|
@ -181,12 +179,12 @@
|
||||||
" sharding=sharding)\n",
|
" sharding=sharding)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Evolve the simulation forward\n",
|
" # Evolve the simulation forward\n",
|
||||||
" ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)\n",
|
" ode_fn = ODETerm(\n",
|
||||||
" term = ODETerm(lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
|
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
||||||
" solver = LeapfrogMidpoint()\n",
|
" solver = LeapfrogMidpoint()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" stepsize_controller = ConstantStepSize()\n",
|
" stepsize_controller = ConstantStepSize()\n",
|
||||||
" res = diffeqsolve(term,\n",
|
" res = diffeqsolve(ode_fn,\n",
|
||||||
" solver,\n",
|
" solver,\n",
|
||||||
" t0=0.1,\n",
|
" t0=0.1,\n",
|
||||||
" t1=1.,\n",
|
" t1=1.,\n",
|
||||||
|
@ -364,7 +362,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": null,
|
||||||
"id": "8c647b13",
|
"id": "8c647b13",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
@ -411,12 +409,12 @@
|
||||||
" sharding=sharding)\n",
|
" sharding=sharding)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Evolve the simulation forward\n",
|
" # Evolve the simulation forward\n",
|
||||||
" ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)\n",
|
" ode_fn = ODETerm(\n",
|
||||||
" term = ODETerm(lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
|
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
||||||
" solver = LeapfrogMidpoint()\n",
|
" solver = LeapfrogMidpoint()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" stepsize_controller = ConstantStepSize()\n",
|
" stepsize_controller = ConstantStepSize()\n",
|
||||||
" res = diffeqsolve(term,\n",
|
" res = diffeqsolve(ode_fn,\n",
|
||||||
" solver,\n",
|
" solver,\n",
|
||||||
" t0=0.1,\n",
|
" t0=0.1,\n",
|
||||||
" t1=1.,\n",
|
" t1=1.,\n",
|
||||||
|
@ -651,7 +649,6 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from math import prod\n",
|
|
||||||
"from jaxpm.plotting import plot_fields_single_projection\n",
|
"from jaxpm.plotting import plot_fields_single_projection\n",
|
||||||
"\n",
|
"\n",
|
||||||
"field = ode_solutions[0]\n",
|
"field = ode_solutions[0]\n",
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -20,13 +20,10 @@
|
||||||
"import jax\n",
|
"import jax\n",
|
||||||
"import jax.numpy as jnp\n",
|
"import jax.numpy as jnp\n",
|
||||||
"import jax_cosmo as jc\n",
|
"import jax_cosmo as jc\n",
|
||||||
"from jax.debug import visualize_array_sharding\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"from jax.experimental.ode import odeint\n",
|
|
||||||
"from jaxpm.kernels import interpolate_power_spectrum\n",
|
"from jaxpm.kernels import interpolate_power_spectrum\n",
|
||||||
"from jaxpm.painting import cic_paint , cic_paint_dx\n",
|
"from jaxpm.painting import cic_paint_dx\n",
|
||||||
"from jaxpm.pm import linear_field, lpt, make_ode_fn\n",
|
"from jaxpm.pm import linear_field, lpt, make_diffrax_ode\n",
|
||||||
"from jaxpm.distributed import uniform_particles\n",
|
|
||||||
"from functools import partial\n",
|
"from functools import partial\n",
|
||||||
"from diffrax import ConstantStepSize, LeapfrogMidpoint,Dopri5 , PIDController , ODETerm, SaveAt, diffeqsolve"
|
"from diffrax import ConstantStepSize, LeapfrogMidpoint,Dopri5 , PIDController , ODETerm, SaveAt, diffeqsolve"
|
||||||
]
|
]
|
||||||
|
@ -82,7 +79,6 @@
|
||||||
"from jax.experimental.multihost_utils import process_allgather\n",
|
"from jax.experimental.multihost_utils import process_allgather\n",
|
||||||
"from jax.sharding import Mesh, NamedSharding\n",
|
"from jax.sharding import Mesh, NamedSharding\n",
|
||||||
"from jax.sharding import PartitionSpec as P\n",
|
"from jax.sharding import PartitionSpec as P\n",
|
||||||
"from functools import partial\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"all_gather = partial(process_allgather, tiled=False)\n",
|
"all_gather = partial(process_allgather, tiled=False)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -94,7 +90,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -127,12 +123,12 @@
|
||||||
" sharding=sharding)\n",
|
" sharding=sharding)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Evolve the simulation forward\n",
|
" # Evolve the simulation forward\n",
|
||||||
" ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)\n",
|
" ode_fn = ODETerm(\n",
|
||||||
" term = ODETerm(lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
|
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
||||||
" solver = LeapfrogMidpoint()\n",
|
" solver = LeapfrogMidpoint()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" stepsize_controller = ConstantStepSize()\n",
|
" stepsize_controller = ConstantStepSize()\n",
|
||||||
" res = diffeqsolve(term,\n",
|
" res = diffeqsolve(ode_fn,\n",
|
||||||
" solver,\n",
|
" solver,\n",
|
||||||
" t0=0.1,\n",
|
" t0=0.1,\n",
|
||||||
" t1=1.,\n",
|
" t1=1.,\n",
|
||||||
|
@ -244,7 +240,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 9,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -291,12 +287,12 @@
|
||||||
" sharding=sharding)\n",
|
" sharding=sharding)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Evolve the simulation forward\n",
|
" # Evolve the simulation forward\n",
|
||||||
" ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)\n",
|
" ode_fn = ODETerm(\n",
|
||||||
" term = ODETerm(lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
|
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
||||||
" solver = Dopri5()\n",
|
" solver = Dopri5()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",
|
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",
|
||||||
" res = diffeqsolve(term,\n",
|
" res = diffeqsolve(ode_fn,\n",
|
||||||
" solver,\n",
|
" solver,\n",
|
||||||
" t0=0.1,\n",
|
" t0=0.1,\n",
|
||||||
" t1=1.,\n",
|
" t1=1.,\n",
|
||||||
|
|
|
@ -18,14 +18,13 @@ import numpy as np
|
||||||
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
||||||
PIDController, SaveAt, diffeqsolve)
|
PIDController, SaveAt, diffeqsolve)
|
||||||
from jax.experimental.mesh_utils import create_device_mesh
|
from jax.experimental.mesh_utils import create_device_mesh
|
||||||
from jax.experimental.multihost_utils import (process_allgather,
|
from jax.experimental.multihost_utils import (process_allgather)
|
||||||
sync_global_devices)
|
|
||||||
from jax.sharding import Mesh, NamedSharding
|
from jax.sharding import Mesh, NamedSharding
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
|
|
||||||
from jaxpm.kernels import interpolate_power_spectrum
|
from jaxpm.kernels import interpolate_power_spectrum
|
||||||
from jaxpm.painting import cic_paint_dx
|
from jaxpm.painting import cic_paint_dx
|
||||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
|
||||||
|
|
||||||
all_gather = partial(process_allgather, tiled=True)
|
all_gather = partial(process_allgather, tiled=True)
|
||||||
|
|
||||||
|
@ -86,7 +85,7 @@ def create_mesh_and_sharding(pdims):
|
||||||
|
|
||||||
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
|
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
|
||||||
def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
||||||
solver_choice,nb_snapshots, sharding):
|
solver_choice, nb_snapshots, sharding):
|
||||||
k = jnp.logspace(-4, 1, 128)
|
k = jnp.logspace(-4, 1, 128)
|
||||||
pk = jc.power.linear_matter_power(
|
pk = jc.power.linear_matter_power(
|
||||||
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
||||||
|
@ -106,15 +105,14 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
||||||
halo_size=halo_size,
|
halo_size=halo_size,
|
||||||
sharding=sharding)
|
sharding=sharding)
|
||||||
|
|
||||||
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
|
ode_fn = ODETerm(
|
||||||
term = ODETerm(
|
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||||
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
|
|
||||||
|
|
||||||
# Choose solver
|
# Choose solver
|
||||||
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
|
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
|
||||||
stepsize_controller = ConstantStepSize(
|
stepsize_controller = ConstantStepSize(
|
||||||
) if solver_choice == "leapfrog" else PIDController(rtol=1e-5, atol=1e-5)
|
) if solver_choice == "leapfrog" else PIDController(rtol=1e-5, atol=1e-5)
|
||||||
res = diffeqsolve(term,
|
res = diffeqsolve(ode_fn,
|
||||||
solver,
|
solver,
|
||||||
t0=0.1,
|
t0=0.1,
|
||||||
t1=1.,
|
t1=1.,
|
||||||
|
@ -151,7 +149,6 @@ def main():
|
||||||
print(f"[{rank}] Simulation done")
|
print(f"[{rank}] Simulation done")
|
||||||
print(f"Solver stats: {solver_stats}")
|
print(f"Solver stats: {solver_stats}")
|
||||||
|
|
||||||
|
|
||||||
# Save initial conditions
|
# Save initial conditions
|
||||||
initial_conditions_g = all_gather(initial_conditions)
|
initial_conditions_g = all_gather(initial_conditions)
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
|
|
Loading…
Add table
Reference in a new issue