This commit is contained in:
Wassim Kabalan 2024-12-21 23:27:05 +01:00
parent a924458f0d
commit bbacd45dcf

View file

@ -1,6 +1,7 @@
import jax import jax
import pytest import pytest
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve , RecursiveCheckpointAdjoint, BacksolveAdjoint from diffrax import (BacksolveAdjoint, Dopri5, ODETerm, PIDController,
RecursiveCheckpointAdjoint, SaveAt, diffeqsolve)
from helpers import MSE from helpers import MSE
from jax import numpy as jnp from jax import numpy as jnp
@ -15,15 +16,16 @@ from jaxpm.pm import lpt, make_diffrax_ode
@pytest.mark.parametrize("adjoint", ['DTO', 'OTD']) @pytest.mark.parametrize("adjoint", ['DTO', 'OTD'])
def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor, def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
nbody_from_lpt1, nbody_from_lpt2, cosmo, order, nbody_from_lpt1, nbody_from_lpt2, cosmo, order,
absolute_painting , adjoint): absolute_painting, adjoint):
mesh_shape, _ = simulation_config mesh_shape, _ = simulation_config
cosmo._workspace = {} cosmo._workspace = {}
if adjoint == 'OTD': if adjoint == 'OTD':
pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)") pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)")
adjoint = RecursiveCheckpointAdjoint() if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5()) adjoint = RecursiveCheckpointAdjoint(
) if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5())
@jax.jit @jax.jit
@jax.grad @jax.grad