mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
update notebooks
This commit is contained in:
parent
21373b89ee
commit
36ef18e3d0
5 changed files with 88 additions and 80 deletions
File diff suppressed because one or more lines are too long
|
@ -21,17 +21,16 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import jax\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"import jax_cosmo as jc\n",
|
||||
"\n",
|
||||
"from jaxpm.painting import cic_paint , cic_paint_dx\n",
|
||||
"from jaxpm.pm import linear_field, lpt, make_ode_fn\n",
|
||||
"from jaxpm.pm import linear_field, lpt, make_diffrax_ode\n",
|
||||
"from jaxpm.distributed import uniform_particles\n",
|
||||
"from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve"
|
||||
]
|
||||
|
@ -84,13 +83,12 @@
|
|||
" dx, p, f = lpt(cosmo, initial_conditions, a=0.1,order=1)\n",
|
||||
" \n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" ode_fn = make_ode_fn(mesh_shape)\n",
|
||||
" term = ODETerm(\n",
|
||||
" lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
|
||||
" ode_fn = ODETerm(\n",
|
||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
||||
" solver = LeapfrogMidpoint()\n",
|
||||
"\n",
|
||||
" stepsize_controller = ConstantStepSize()\n",
|
||||
" res = diffeqsolve(term,\n",
|
||||
" res = diffeqsolve(ode_fn,\n",
|
||||
" solver,\n",
|
||||
" t0=0.1,\n",
|
||||
" t1=1.,\n",
|
||||
|
@ -258,13 +256,12 @@
|
|||
" dx, p, f = lpt(cosmo, initial_conditions,particles=particles,a=0.1,order=2)\n",
|
||||
" \n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" ode_fn = make_ode_fn(mesh_shape,particles=particles)\n",
|
||||
" term = ODETerm(\n",
|
||||
" lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
|
||||
" ode_fn = ODETerm(\n",
|
||||
" make_diffrax_ode(cosmo, mesh_shape))\n",
|
||||
" solver = LeapfrogMidpoint()\n",
|
||||
"\n",
|
||||
" stepsize_controller = ConstantStepSize()\n",
|
||||
" res = diffeqsolve(term,\n",
|
||||
" res = diffeqsolve(ode_fn,\n",
|
||||
" solver,\n",
|
||||
" t0=0.1,\n",
|
||||
" t1=1.,\n",
|
||||
|
@ -336,7 +333,6 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"from math import prod\n",
|
||||
"from jaxpm.plotting import plot_fields_single_projection\n",
|
||||
"\n",
|
||||
"center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4 )\n",
|
||||
|
@ -379,7 +375,6 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"from math import prod\n",
|
||||
"from jaxpm.plotting import plot_fields_single_projection\n",
|
||||
"\n",
|
||||
"center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4 )\n",
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7fb27b941602401d91542211134fc71a",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "view-in-github"
|
||||
|
@ -45,7 +46,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"id": "c5f42bbe",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
|
@ -63,11 +64,9 @@
|
|||
"import jax_cosmo as jc\n",
|
||||
"from jax.debug import visualize_array_sharding\n",
|
||||
"\n",
|
||||
"from jax.experimental.ode import odeint\n",
|
||||
"from jaxpm.kernels import interpolate_power_spectrum\n",
|
||||
"from jaxpm.painting import cic_paint , cic_paint_dx\n",
|
||||
"from jaxpm.pm import linear_field, lpt, make_ode_fn\n",
|
||||
"from jaxpm.distributed import uniform_particles\n",
|
||||
"from jaxpm.painting import cic_paint_dx\n",
|
||||
"from jaxpm.pm import linear_field, lpt, make_diffrax_ode\n",
|
||||
"from functools import partial\n",
|
||||
"from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve"
|
||||
]
|
||||
|
@ -109,7 +108,6 @@
|
|||
"from jax.experimental.multihost_utils import process_allgather\n",
|
||||
"from jax.sharding import Mesh, NamedSharding\n",
|
||||
"from jax.sharding import PartitionSpec as P\n",
|
||||
"from functools import partial\n",
|
||||
"\n",
|
||||
"all_gather = partial(process_allgather, tiled=False)\n",
|
||||
"\n",
|
||||
|
@ -131,7 +129,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": null,
|
||||
"id": "281b4d3b",
|
||||
"metadata": {
|
||||
"id": "281b4d3b"
|
||||
|
@ -181,12 +179,12 @@
|
|||
" sharding=sharding)\n",
|
||||
"\n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)\n",
|
||||
" term = ODETerm(lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
|
||||
" ode_fn = ODETerm(\n",
|
||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
||||
" solver = LeapfrogMidpoint()\n",
|
||||
"\n",
|
||||
" stepsize_controller = ConstantStepSize()\n",
|
||||
" res = diffeqsolve(term,\n",
|
||||
" res = diffeqsolve(ode_fn,\n",
|
||||
" solver,\n",
|
||||
" t0=0.1,\n",
|
||||
" t1=1.,\n",
|
||||
|
@ -364,7 +362,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": null,
|
||||
"id": "8c647b13",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
|
@ -411,12 +409,12 @@
|
|||
" sharding=sharding)\n",
|
||||
"\n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)\n",
|
||||
" term = ODETerm(lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
|
||||
" ode_fn = ODETerm(\n",
|
||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
||||
" solver = LeapfrogMidpoint()\n",
|
||||
"\n",
|
||||
" stepsize_controller = ConstantStepSize()\n",
|
||||
" res = diffeqsolve(term,\n",
|
||||
" res = diffeqsolve(ode_fn,\n",
|
||||
" solver,\n",
|
||||
" t0=0.1,\n",
|
||||
" t1=1.,\n",
|
||||
|
@ -651,7 +649,6 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"from math import prod\n",
|
||||
"from jaxpm.plotting import plot_fields_single_projection\n",
|
||||
"\n",
|
||||
"field = ode_solutions[0]\n",
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -20,13 +20,10 @@
|
|||
"import jax\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"import jax_cosmo as jc\n",
|
||||
"from jax.debug import visualize_array_sharding\n",
|
||||
"\n",
|
||||
"from jax.experimental.ode import odeint\n",
|
||||
"from jaxpm.kernels import interpolate_power_spectrum\n",
|
||||
"from jaxpm.painting import cic_paint , cic_paint_dx\n",
|
||||
"from jaxpm.pm import linear_field, lpt, make_ode_fn\n",
|
||||
"from jaxpm.distributed import uniform_particles\n",
|
||||
"from jaxpm.painting import cic_paint_dx\n",
|
||||
"from jaxpm.pm import linear_field, lpt, make_diffrax_ode\n",
|
||||
"from functools import partial\n",
|
||||
"from diffrax import ConstantStepSize, LeapfrogMidpoint,Dopri5 , PIDController , ODETerm, SaveAt, diffeqsolve"
|
||||
]
|
||||
|
@ -82,7 +79,6 @@
|
|||
"from jax.experimental.multihost_utils import process_allgather\n",
|
||||
"from jax.sharding import Mesh, NamedSharding\n",
|
||||
"from jax.sharding import PartitionSpec as P\n",
|
||||
"from functools import partial\n",
|
||||
"\n",
|
||||
"all_gather = partial(process_allgather, tiled=False)\n",
|
||||
"\n",
|
||||
|
@ -94,7 +90,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -127,12 +123,12 @@
|
|||
" sharding=sharding)\n",
|
||||
"\n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)\n",
|
||||
" term = ODETerm(lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
|
||||
" ode_fn = ODETerm(\n",
|
||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
||||
" solver = LeapfrogMidpoint()\n",
|
||||
"\n",
|
||||
" stepsize_controller = ConstantStepSize()\n",
|
||||
" res = diffeqsolve(term,\n",
|
||||
" res = diffeqsolve(ode_fn,\n",
|
||||
" solver,\n",
|
||||
" t0=0.1,\n",
|
||||
" t1=1.,\n",
|
||||
|
@ -244,7 +240,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -291,12 +287,12 @@
|
|||
" sharding=sharding)\n",
|
||||
"\n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)\n",
|
||||
" term = ODETerm(lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
|
||||
" ode_fn = ODETerm(\n",
|
||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
||||
" solver = Dopri5()\n",
|
||||
"\n",
|
||||
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",
|
||||
" res = diffeqsolve(term,\n",
|
||||
" res = diffeqsolve(ode_fn,\n",
|
||||
" solver,\n",
|
||||
" t0=0.1,\n",
|
||||
" t1=1.,\n",
|
||||
|
|
|
@ -18,14 +18,13 @@ import numpy as np
|
|||
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
||||
PIDController, SaveAt, diffeqsolve)
|
||||
from jax.experimental.mesh_utils import create_device_mesh
|
||||
from jax.experimental.multihost_utils import (process_allgather,
|
||||
sync_global_devices)
|
||||
from jax.experimental.multihost_utils import (process_allgather)
|
||||
from jax.sharding import Mesh, NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jaxpm.kernels import interpolate_power_spectrum
|
||||
from jaxpm.painting import cic_paint_dx
|
||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
|
||||
|
||||
all_gather = partial(process_allgather, tiled=True)
|
||||
|
||||
|
@ -86,7 +85,7 @@ def create_mesh_and_sharding(pdims):
|
|||
|
||||
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
|
||||
def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
||||
solver_choice,nb_snapshots, sharding):
|
||||
solver_choice, nb_snapshots, sharding):
|
||||
k = jnp.logspace(-4, 1, 128)
|
||||
pk = jc.power.linear_matter_power(
|
||||
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
||||
|
@ -106,15 +105,14 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
|||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
|
||||
term = ODETerm(
|
||||
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
|
||||
# Choose solver
|
||||
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
|
||||
stepsize_controller = ConstantStepSize(
|
||||
) if solver_choice == "leapfrog" else PIDController(rtol=1e-5, atol=1e-5)
|
||||
res = diffeqsolve(term,
|
||||
res = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
t0=0.1,
|
||||
t1=1.,
|
||||
|
@ -151,7 +149,6 @@ def main():
|
|||
print(f"[{rank}] Simulation done")
|
||||
print(f"Solver stats: {solver_stats}")
|
||||
|
||||
|
||||
# Save initial conditions
|
||||
initial_conditions_g = all_gather(initial_conditions)
|
||||
if rank == 0:
|
||||
|
|
Loading…
Add table
Reference in a new issue