From 627be7a7641f65a1fdf0b6574545dd3aa229aa7e Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Sat, 28 Jun 2025 01:36:17 +0200 Subject: [PATCH] updating tests --- requirements-test.txt | 1 + tests/test_against_fpm.py | 23 +++++++++++++++++++---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/requirements-test.txt b/requirements-test.txt index 70a7229..e415cad 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -2,3 +2,4 @@ pfft-python @ git+https://github.com/MP-Gadget/pfft-python pmesh @ git+https://github.com/MP-Gadget/pmesh fastpm @ git+https://github.com/ASKabalan/fastpm-python numpy==2.2.6 +diffrax diff --git a/tests/test_against_fpm.py b/tests/test_against_fpm.py index 5ebcbc2..1530733 100644 --- a/tests/test_against_fpm.py +++ b/tests/test_against_fpm.py @@ -2,6 +2,7 @@ import pytest from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve from helpers import MSE, MSRE from jax import numpy as jnp +from numpy.testing import assert_allclose from jaxpm.distributed import uniform_particles from jaxpm.painting import cic_paint, cic_paint_dx @@ -10,6 +11,8 @@ from jaxpm.utils import power_spectrum _TOLERANCE = 1e-4 _PM_TOLERANCE = 1e-3 +_FIELD_RTOL = 1e-2 +_FIELD_ATOL = 1e-1 @pytest.mark.single_device @@ -34,7 +37,10 @@ def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor, _, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) - assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE + assert_allclose(lpt_field, + fpm_ref_field, + rtol=_FIELD_RTOL, + atol=_FIELD_ATOL) assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE @@ -55,7 +61,10 @@ def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor, _, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) - assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE + assert_allclose(lpt_field, + fpm_ref_field, + rtol=_FIELD_RTOL, + atol=_FIELD_ATOL) assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE @@ -106,7 +115,10 @@ def test_nbody_absolute(simulation_config, initial_conditions, _, jpm_ps = power_spectrum(final_field, box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) - assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE + assert_allclose(final_field, + fpm_ref_field, + rtol=_FIELD_RTOL, + atol=_FIELD_ATOL) assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE @@ -152,5 +164,8 @@ def test_nbody_relative(simulation_config, initial_conditions, _, jpm_ps = power_spectrum(final_field, box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) - assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE + assert_allclose(final_field, + fpm_ref_field, + rtol=_FIELD_RTOL, + atol=_FIELD_ATOL) assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE