update to leapfrog

This commit is contained in:
Wassim KABALAN 2024-10-22 09:03:41 -04:00
parent 5a587fd402
commit a160a3faa9

View file

@ -11,7 +11,7 @@ size = jax.device_count()
import jax.numpy as jnp import jax.numpy as jnp
import jax_cosmo as jc import jax_cosmo as jc
import numpy as np import numpy as np
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve from diffrax import Dopri5,LeapfrogMidpoint, ODETerm, ConstantStepSize, SaveAt, diffeqsolve
from jax.experimental import mesh_utils from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
@ -27,7 +27,7 @@ snapshots = jnp.linspace(0.1, 1., 4)
halo_size = 32 halo_size = 32
pdims = (1, 1) pdims = (1, 1)
if jax.device_count() > 1: if jax.device_count() > 1:
pdims = (4, 2) pdims = (8, 1)
devices = mesh_utils.create_device_mesh(pdims) devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('x', 'y')) mesh = Mesh(devices.T, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y')) sharding = NamedSharding(mesh, P('x', 'y'))
@ -51,15 +51,13 @@ def run_simulation(omega_c, sigma8):
# Initial displacement # Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size) dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size)
return initial_conditions, cic_paint_dx(dx,
halo_size=halo_size), None, None
# Evolve the simulation forward # Evolve the simulation forward
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size) ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
term = ODETerm( term = ODETerm(
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0)) lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
solver = Dopri5() solver = LeapfrogMidpoint()
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4) stepsize_controller = ConstantStepSize()
res = diffeqsolve(term, res = diffeqsolve(term,
solver, solver,
t0=0.1, t0=0.1,