updating tests

This commit is contained in:
Francois Lanusse 2025-06-28 01:36:17 +02:00
parent 7623e60581
commit 627be7a764
2 changed files with 20 additions and 4 deletions

View file

@ -2,6 +2,7 @@ import pytest
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
from helpers import MSE, MSRE
from jax import numpy as jnp
from numpy.testing import assert_allclose
from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_dx
@ -10,6 +11,8 @@ from jaxpm.utils import power_spectrum
_TOLERANCE = 1e-4
_PM_TOLERANCE = 1e-3
_FIELD_RTOL = 1e-2
_FIELD_ATOL = 1e-1
@pytest.mark.single_device
@ -34,7 +37,10 @@ def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
assert_allclose(lpt_field,
fpm_ref_field,
rtol=_FIELD_RTOL,
atol=_FIELD_ATOL)
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
@ -55,7 +61,10 @@ def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
assert_allclose(lpt_field,
fpm_ref_field,
rtol=_FIELD_RTOL,
atol=_FIELD_ATOL)
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
@ -106,7 +115,10 @@ def test_nbody_absolute(simulation_config, initial_conditions,
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
assert_allclose(final_field,
fpm_ref_field,
rtol=_FIELD_RTOL,
atol=_FIELD_ATOL)
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
@ -152,5 +164,8 @@ def test_nbody_relative(simulation_config, initial_conditions,
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
assert_allclose(final_field,
fpm_ref_field,
rtol=_FIELD_RTOL,
atol=_FIELD_ATOL)
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE