From a160a3faa9404e991814906c284a671c1739b627 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Tue, 22 Oct 2024 09:03:41 -0400 Subject: [PATCH] update to leapfrog --- scripts/distributed_pm.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/scripts/distributed_pm.py b/scripts/distributed_pm.py index c7c48ac..5411930 100644 --- a/scripts/distributed_pm.py +++ b/scripts/distributed_pm.py @@ -11,7 +11,7 @@ size = jax.device_count() import jax.numpy as jnp import jax_cosmo as jc import numpy as np -from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve +from diffrax import Dopri5,LeapfrogMidpoint, ODETerm, ConstantStepSize, SaveAt, diffeqsolve from jax.experimental import mesh_utils from jax.sharding import Mesh, NamedSharding from jax.sharding import PartitionSpec as P @@ -27,7 +27,7 @@ snapshots = jnp.linspace(0.1, 1., 4) halo_size = 32 pdims = (1, 1) if jax.device_count() > 1: - pdims = (4, 2) + pdims = (8, 1) devices = mesh_utils.create_device_mesh(pdims) mesh = Mesh(devices.T, axis_names=('x', 'y')) sharding = NamedSharding(mesh, P('x', 'y')) @@ -51,15 +51,13 @@ def run_simulation(omega_c, sigma8): # Initial displacement dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size) - return initial_conditions, cic_paint_dx(dx, - halo_size=halo_size), None, None # Evolve the simulation forward ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size) term = ODETerm( lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0)) - solver = Dopri5() + solver = LeapfrogMidpoint() - stepsize_controller = PIDController(rtol=1e-4, atol=1e-4) + stepsize_controller = ConstantStepSize() res = diffeqsolve(term, solver, t0=0.1,