mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-28 08:01:12 +00:00
updating tests
This commit is contained in:
parent
7623e60581
commit
627be7a764
2 changed files with 20 additions and 4 deletions
|
@ -2,3 +2,4 @@ pfft-python @ git+https://github.com/MP-Gadget/pfft-python
|
||||||
pmesh @ git+https://github.com/MP-Gadget/pmesh
|
pmesh @ git+https://github.com/MP-Gadget/pmesh
|
||||||
fastpm @ git+https://github.com/ASKabalan/fastpm-python
|
fastpm @ git+https://github.com/ASKabalan/fastpm-python
|
||||||
numpy==2.2.6
|
numpy==2.2.6
|
||||||
|
diffrax
|
||||||
|
|
|
@ -2,6 +2,7 @@ import pytest
|
||||||
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
|
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
|
||||||
from helpers import MSE, MSRE
|
from helpers import MSE, MSRE
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
|
from numpy.testing import assert_allclose
|
||||||
|
|
||||||
from jaxpm.distributed import uniform_particles
|
from jaxpm.distributed import uniform_particles
|
||||||
from jaxpm.painting import cic_paint, cic_paint_dx
|
from jaxpm.painting import cic_paint, cic_paint_dx
|
||||||
|
@ -10,6 +11,8 @@ from jaxpm.utils import power_spectrum
|
||||||
|
|
||||||
_TOLERANCE = 1e-4
|
_TOLERANCE = 1e-4
|
||||||
_PM_TOLERANCE = 1e-3
|
_PM_TOLERANCE = 1e-3
|
||||||
|
_FIELD_RTOL = 1e-2
|
||||||
|
_FIELD_ATOL = 1e-1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.single_device
|
@pytest.mark.single_device
|
||||||
|
@ -34,7 +37,10 @@ def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
|
||||||
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
|
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
|
||||||
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
||||||
|
|
||||||
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
|
assert_allclose(lpt_field,
|
||||||
|
fpm_ref_field,
|
||||||
|
rtol=_FIELD_RTOL,
|
||||||
|
atol=_FIELD_ATOL)
|
||||||
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
|
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
|
||||||
|
|
||||||
|
|
||||||
|
@ -55,7 +61,10 @@ def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
|
||||||
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
|
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
|
||||||
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
||||||
|
|
||||||
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
|
assert_allclose(lpt_field,
|
||||||
|
fpm_ref_field,
|
||||||
|
rtol=_FIELD_RTOL,
|
||||||
|
atol=_FIELD_ATOL)
|
||||||
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
|
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
|
||||||
|
|
||||||
|
|
||||||
|
@ -106,7 +115,10 @@ def test_nbody_absolute(simulation_config, initial_conditions,
|
||||||
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
|
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
|
||||||
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
||||||
|
|
||||||
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
|
assert_allclose(final_field,
|
||||||
|
fpm_ref_field,
|
||||||
|
rtol=_FIELD_RTOL,
|
||||||
|
atol=_FIELD_ATOL)
|
||||||
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
|
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
|
||||||
|
|
||||||
|
|
||||||
|
@ -152,5 +164,8 @@ def test_nbody_relative(simulation_config, initial_conditions,
|
||||||
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
|
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
|
||||||
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
||||||
|
|
||||||
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
|
assert_allclose(final_field,
|
||||||
|
fpm_ref_field,
|
||||||
|
rtol=_FIELD_RTOL,
|
||||||
|
atol=_FIELD_ATOL)
|
||||||
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
|
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue