mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-22 17:47:11 +00:00
format
This commit is contained in:
parent
a924458f0d
commit
bbacd45dcf
1 changed files with 6 additions and 4 deletions
|
@ -1,6 +1,7 @@
|
|||
import jax
|
||||
import pytest
|
||||
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve , RecursiveCheckpointAdjoint, BacksolveAdjoint
|
||||
from diffrax import (BacksolveAdjoint, Dopri5, ODETerm, PIDController,
|
||||
RecursiveCheckpointAdjoint, SaveAt, diffeqsolve)
|
||||
from helpers import MSE
|
||||
from jax import numpy as jnp
|
||||
|
||||
|
@ -15,15 +16,16 @@ from jaxpm.pm import lpt, make_diffrax_ode
|
|||
@pytest.mark.parametrize("adjoint", ['DTO', 'OTD'])
|
||||
def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
|
||||
nbody_from_lpt1, nbody_from_lpt2, cosmo, order,
|
||||
absolute_painting , adjoint):
|
||||
absolute_painting, adjoint):
|
||||
|
||||
mesh_shape, _ = simulation_config
|
||||
cosmo._workspace = {}
|
||||
|
||||
if adjoint == 'OTD':
|
||||
pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)")
|
||||
pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)")
|
||||
|
||||
adjoint = RecursiveCheckpointAdjoint() if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5())
|
||||
adjoint = RecursiveCheckpointAdjoint(
|
||||
) if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5())
|
||||
|
||||
@jax.jit
|
||||
@jax.grad
|
||||
|
|
Loading…
Add table
Reference in a new issue