updating tests

This commit is contained in:
Francois Lanusse 2025-06-28 01:36:17 +02:00
parent 7623e60581
commit 627be7a764
2 changed files with 20 additions and 4 deletions

View file

@ -2,3 +2,4 @@ pfft-python @ git+https://github.com/MP-Gadget/pfft-python
pmesh @ git+https://github.com/MP-Gadget/pmesh pmesh @ git+https://github.com/MP-Gadget/pmesh
fastpm @ git+https://github.com/ASKabalan/fastpm-python fastpm @ git+https://github.com/ASKabalan/fastpm-python
numpy==2.2.6 numpy==2.2.6
diffrax

View file

@ -2,6 +2,7 @@ import pytest
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
from helpers import MSE, MSRE from helpers import MSE, MSRE
from jax import numpy as jnp from jax import numpy as jnp
from numpy.testing import assert_allclose
from jaxpm.distributed import uniform_particles from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_dx from jaxpm.painting import cic_paint, cic_paint_dx
@ -10,6 +11,8 @@ from jaxpm.utils import power_spectrum
_TOLERANCE = 1e-4 _TOLERANCE = 1e-4
_PM_TOLERANCE = 1e-3 _PM_TOLERANCE = 1e-3
_FIELD_RTOL = 1e-2
_FIELD_ATOL = 1e-1
@pytest.mark.single_device @pytest.mark.single_device
@ -34,7 +37,10 @@ def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape) _, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE assert_allclose(lpt_field,
fpm_ref_field,
rtol=_FIELD_RTOL,
atol=_FIELD_ATOL)
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
@ -55,7 +61,10 @@ def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape) _, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE assert_allclose(lpt_field,
fpm_ref_field,
rtol=_FIELD_RTOL,
atol=_FIELD_ATOL)
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
@ -106,7 +115,10 @@ def test_nbody_absolute(simulation_config, initial_conditions,
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape) _, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE assert_allclose(final_field,
fpm_ref_field,
rtol=_FIELD_RTOL,
atol=_FIELD_ATOL)
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
@ -152,5 +164,8 @@ def test_nbody_relative(simulation_config, initial_conditions,
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape) _, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE assert_allclose(final_field,
fpm_ref_field,
rtol=_FIELD_RTOL,
atol=_FIELD_ATOL)
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE