mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 10:00:54 +00:00
format
This commit is contained in:
parent
a924458f0d
commit
bbacd45dcf
1 changed files with 6 additions and 4 deletions
|
@ -1,6 +1,7 @@
|
||||||
import jax
|
import jax
|
||||||
import pytest
|
import pytest
|
||||||
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve , RecursiveCheckpointAdjoint, BacksolveAdjoint
|
from diffrax import (BacksolveAdjoint, Dopri5, ODETerm, PIDController,
|
||||||
|
RecursiveCheckpointAdjoint, SaveAt, diffeqsolve)
|
||||||
from helpers import MSE
|
from helpers import MSE
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
|
|
||||||
|
@ -23,7 +24,8 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
|
||||||
if adjoint == 'OTD':
|
if adjoint == 'OTD':
|
||||||
pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)")
|
pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)")
|
||||||
|
|
||||||
adjoint = RecursiveCheckpointAdjoint() if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5())
|
adjoint = RecursiveCheckpointAdjoint(
|
||||||
|
) if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5())
|
||||||
|
|
||||||
@jax.jit
|
@jax.jit
|
||||||
@jax.grad
|
@jax.grad
|
||||||
|
|
Loading…
Add table
Reference in a new issue