This commit is contained in:
Wassim Kabalan 2024-12-21 23:27:05 +01:00
parent a924458f0d
commit bbacd45dcf

View file

@ -1,6 +1,7 @@
import jax import jax
import pytest import pytest
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve , RecursiveCheckpointAdjoint, BacksolveAdjoint from diffrax import (BacksolveAdjoint, Dopri5, ODETerm, PIDController,
RecursiveCheckpointAdjoint, SaveAt, diffeqsolve)
from helpers import MSE from helpers import MSE
from jax import numpy as jnp from jax import numpy as jnp
@ -23,7 +24,8 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
if adjoint == 'OTD': if adjoint == 'OTD':
pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)") pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)")
adjoint = RecursiveCheckpointAdjoint() if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5()) adjoint = RecursiveCheckpointAdjoint(
) if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5())
@jax.jit @jax.jit
@jax.grad @jax.grad