Update tolerance and precision settings in distributed PM tests

This commit is contained in:
Wassim Kabalan 2025-06-28 12:08:16 +02:00
parent e666aada42
commit 2d21985279

View file

@ -22,10 +22,12 @@ from jaxpm.distributed import fft3d, ifft3d
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode, pm_forces # noqa : E402
_TOLERANCE = 1e-6 # 🎉🎉🎉
_TOLERANCE = 1e-12 # 🎉🎉🎉
pdims = [(1, 8), (8, 1), (4, 2), (2, 4)]
jax.config.update("jax_enable_x64", True) # Use double precision for accuracy
@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@ -39,9 +41,11 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
painting_str = "absolute" if absolute_painting else "relative"
print("=" * 50)
print(f"Running with {painting_str} painting and pdims {pdims} ...")
mesh_shape, box_shape = simulation_config
print(
f"Running with {painting_str} painting and pdims {pdims} and order {order} and mesh shape {mesh_shape}..."
)
# SINGLE DEVICE RUN
cosmo._workspace = {}
if absolute_painting: