update to leapfrog

This commit is contained in:
Wassim KABALAN 2024-10-22 09:03:41 -04:00
parent 5a587fd402
commit a160a3faa9

View file

@ -11,7 +11,7 @@ size = jax.device_count()
import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
from diffrax import Dopri5,LeapfrogMidpoint, ODETerm, ConstantStepSize, SaveAt, diffeqsolve
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
@ -27,7 +27,7 @@ snapshots = jnp.linspace(0.1, 1., 4)
halo_size = 32
pdims = (1, 1)
if jax.device_count() > 1:
pdims = (4, 2)
pdims = (8, 1)
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
@ -51,15 +51,13 @@ def run_simulation(omega_c, sigma8):
# Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size)
return initial_conditions, cic_paint_dx(dx,
halo_size=halo_size), None, None
# Evolve the simulation forward
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
term = ODETerm(
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
solver = Dopri5()
solver = LeapfrogMidpoint()
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4)
stepsize_controller = ConstantStepSize()
res = diffeqsolve(term,
solver,
t0=0.1,