update notebooks

This commit is contained in:
Wassim Kabalan 2024-12-06 18:56:35 +01:00
parent 21373b89ee
commit 36ef18e3d0
5 changed files with 88 additions and 80 deletions

File diff suppressed because one or more lines are too long

View file

@ -21,17 +21,16 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jax_cosmo as jc\n",
"\n",
"from jaxpm.painting import cic_paint , cic_paint_dx\n",
"from jaxpm.pm import linear_field, lpt, make_ode_fn\n",
"from jaxpm.pm import linear_field, lpt, make_diffrax_ode\n",
"from jaxpm.distributed import uniform_particles\n",
"from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve"
]
@ -84,13 +83,12 @@
" dx, p, f = lpt(cosmo, initial_conditions, a=0.1,order=1)\n",
" \n",
" # Evolve the simulation forward\n",
" ode_fn = make_ode_fn(mesh_shape)\n",
" term = ODETerm(\n",
" lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" solver = LeapfrogMidpoint()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
" res = diffeqsolve(term,\n",
" res = diffeqsolve(ode_fn,\n",
" solver,\n",
" t0=0.1,\n",
" t1=1.,\n",
@ -258,13 +256,12 @@
" dx, p, f = lpt(cosmo, initial_conditions,particles=particles,a=0.1,order=2)\n",
" \n",
" # Evolve the simulation forward\n",
" ode_fn = make_ode_fn(mesh_shape,particles=particles)\n",
" term = ODETerm(\n",
" lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape))\n",
" solver = LeapfrogMidpoint()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
" res = diffeqsolve(term,\n",
" res = diffeqsolve(ode_fn,\n",
" solver,\n",
" t0=0.1,\n",
" t1=1.,\n",
@ -336,7 +333,6 @@
}
],
"source": [
"from math import prod\n",
"from jaxpm.plotting import plot_fields_single_projection\n",
"\n",
"center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4 )\n",
@ -379,7 +375,6 @@
}
],
"source": [
"from math import prod\n",
"from jaxpm.plotting import plot_fields_single_projection\n",
"\n",
"center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4 )\n",

View file

@ -2,6 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "7fb27b941602401d91542211134fc71a",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
@ -45,7 +46,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "c5f42bbe",
"metadata": {
"colab": {
@ -63,11 +64,9 @@
"import jax_cosmo as jc\n",
"from jax.debug import visualize_array_sharding\n",
"\n",
"from jax.experimental.ode import odeint\n",
"from jaxpm.kernels import interpolate_power_spectrum\n",
"from jaxpm.painting import cic_paint , cic_paint_dx\n",
"from jaxpm.pm import linear_field, lpt, make_ode_fn\n",
"from jaxpm.distributed import uniform_particles\n",
"from jaxpm.painting import cic_paint_dx\n",
"from jaxpm.pm import linear_field, lpt, make_diffrax_ode\n",
"from functools import partial\n",
"from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve"
]
@ -109,7 +108,6 @@
"from jax.experimental.multihost_utils import process_allgather\n",
"from jax.sharding import Mesh, NamedSharding\n",
"from jax.sharding import PartitionSpec as P\n",
"from functools import partial\n",
"\n",
"all_gather = partial(process_allgather, tiled=False)\n",
"\n",
@ -131,7 +129,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "281b4d3b",
"metadata": {
"id": "281b4d3b"
@ -181,12 +179,12 @@
" sharding=sharding)\n",
"\n",
" # Evolve the simulation forward\n",
" ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)\n",
" term = ODETerm(lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" solver = LeapfrogMidpoint()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
" res = diffeqsolve(term,\n",
" res = diffeqsolve(ode_fn,\n",
" solver,\n",
" t0=0.1,\n",
" t1=1.,\n",
@ -364,7 +362,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "8c647b13",
"metadata": {},
"outputs": [
@ -411,12 +409,12 @@
" sharding=sharding)\n",
"\n",
" # Evolve the simulation forward\n",
" ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)\n",
" term = ODETerm(lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" solver = LeapfrogMidpoint()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
" res = diffeqsolve(term,\n",
" res = diffeqsolve(ode_fn,\n",
" solver,\n",
" t0=0.1,\n",
" t1=1.,\n",
@ -651,7 +649,6 @@
}
],
"source": [
"from math import prod\n",
"from jaxpm.plotting import plot_fields_single_projection\n",
"\n",
"field = ode_solutions[0]\n",

View file

@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -20,13 +20,10 @@
"import jax\n",
"import jax.numpy as jnp\n",
"import jax_cosmo as jc\n",
"from jax.debug import visualize_array_sharding\n",
"\n",
"from jax.experimental.ode import odeint\n",
"from jaxpm.kernels import interpolate_power_spectrum\n",
"from jaxpm.painting import cic_paint , cic_paint_dx\n",
"from jaxpm.pm import linear_field, lpt, make_ode_fn\n",
"from jaxpm.distributed import uniform_particles\n",
"from jaxpm.painting import cic_paint_dx\n",
"from jaxpm.pm import linear_field, lpt, make_diffrax_ode\n",
"from functools import partial\n",
"from diffrax import ConstantStepSize, LeapfrogMidpoint,Dopri5 , PIDController , ODETerm, SaveAt, diffeqsolve"
]
@ -82,7 +79,6 @@
"from jax.experimental.multihost_utils import process_allgather\n",
"from jax.sharding import Mesh, NamedSharding\n",
"from jax.sharding import PartitionSpec as P\n",
"from functools import partial\n",
"\n",
"all_gather = partial(process_allgather, tiled=False)\n",
"\n",
@ -94,7 +90,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -127,12 +123,12 @@
" sharding=sharding)\n",
"\n",
" # Evolve the simulation forward\n",
" ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)\n",
" term = ODETerm(lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" solver = LeapfrogMidpoint()\n",
"\n",
" stepsize_controller = ConstantStepSize()\n",
" res = diffeqsolve(term,\n",
" res = diffeqsolve(ode_fn,\n",
" solver,\n",
" t0=0.1,\n",
" t1=1.,\n",
@ -244,7 +240,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -291,12 +287,12 @@
" sharding=sharding)\n",
"\n",
" # Evolve the simulation forward\n",
" ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)\n",
" term = ODETerm(lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
" ode_fn = ODETerm(\n",
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
" solver = Dopri5()\n",
"\n",
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",
" res = diffeqsolve(term,\n",
" res = diffeqsolve(ode_fn,\n",
" solver,\n",
" t0=0.1,\n",
" t1=1.,\n",

View file

@ -18,14 +18,13 @@ import numpy as np
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
PIDController, SaveAt, diffeqsolve)
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import (process_allgather,
sync_global_devices)
from jax.experimental.multihost_utils import (process_allgather)
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
all_gather = partial(process_allgather, tiled=True)
@ -86,7 +85,7 @@ def create_mesh_and_sharding(pdims):
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
solver_choice,nb_snapshots, sharding):
solver_choice, nb_snapshots, sharding):
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
@ -106,15 +105,14 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
halo_size=halo_size,
sharding=sharding)
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
term = ODETerm(
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
# Choose solver
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
stepsize_controller = ConstantStepSize(
) if solver_choice == "leapfrog" else PIDController(rtol=1e-5, atol=1e-5)
res = diffeqsolve(term,
res = diffeqsolve(ode_fn,
solver,
t0=0.1,
t1=1.,
@ -151,7 +149,6 @@ def main():
print(f"[{rank}] Simulation done")
print(f"Solver stats: {solver_stats}")
# Save initial conditions
initial_conditions_g = all_gather(initial_conditions)
if rank == 0: