mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 01:57:10 +00:00
format
This commit is contained in:
parent
a924458f0d
commit
bbacd45dcf
1 changed files with 6 additions and 4 deletions
|
@ -1,6 +1,7 @@
|
|||
import jax
|
||||
import pytest
|
||||
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve , RecursiveCheckpointAdjoint, BacksolveAdjoint
|
||||
from diffrax import (BacksolveAdjoint, Dopri5, ODETerm, PIDController,
|
||||
RecursiveCheckpointAdjoint, SaveAt, diffeqsolve)
|
||||
from helpers import MSE
|
||||
from jax import numpy as jnp
|
||||
|
||||
|
@ -23,7 +24,8 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
|
|||
if adjoint == 'OTD':
|
||||
pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)")
|
||||
|
||||
adjoint = RecursiveCheckpointAdjoint() if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5())
|
||||
adjoint = RecursiveCheckpointAdjoint(
|
||||
) if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5())
|
||||
|
||||
@jax.jit
|
||||
@jax.grad
|
||||
|
|
Loading…
Add table
Reference in a new issue