mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 01:57:10 +00:00
format
This commit is contained in:
parent
a924458f0d
commit
bbacd45dcf
1 changed files with 6 additions and 4 deletions
|
@ -1,6 +1,7 @@
|
||||||
import jax
|
import jax
|
||||||
import pytest
|
import pytest
|
||||||
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve , RecursiveCheckpointAdjoint, BacksolveAdjoint
|
from diffrax import (BacksolveAdjoint, Dopri5, ODETerm, PIDController,
|
||||||
|
RecursiveCheckpointAdjoint, SaveAt, diffeqsolve)
|
||||||
from helpers import MSE
|
from helpers import MSE
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
|
|
||||||
|
@ -15,15 +16,16 @@ from jaxpm.pm import lpt, make_diffrax_ode
|
||||||
@pytest.mark.parametrize("adjoint", ['DTO', 'OTD'])
|
@pytest.mark.parametrize("adjoint", ['DTO', 'OTD'])
|
||||||
def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
|
def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
|
||||||
nbody_from_lpt1, nbody_from_lpt2, cosmo, order,
|
nbody_from_lpt1, nbody_from_lpt2, cosmo, order,
|
||||||
absolute_painting , adjoint):
|
absolute_painting, adjoint):
|
||||||
|
|
||||||
mesh_shape, _ = simulation_config
|
mesh_shape, _ = simulation_config
|
||||||
cosmo._workspace = {}
|
cosmo._workspace = {}
|
||||||
|
|
||||||
if adjoint == 'OTD':
|
if adjoint == 'OTD':
|
||||||
pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)")
|
pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)")
|
||||||
|
|
||||||
adjoint = RecursiveCheckpointAdjoint() if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5())
|
adjoint = RecursiveCheckpointAdjoint(
|
||||||
|
) if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5())
|
||||||
|
|
||||||
@jax.jit
|
@jax.jit
|
||||||
@jax.grad
|
@jax.grad
|
||||||
|
|
Loading…
Add table
Reference in a new issue