diff --git a/tests/test_distributed_pm.py b/tests/test_distributed_pm.py index f1d4b7f..0c81e01 100644 --- a/tests/test_distributed_pm.py +++ b/tests/test_distributed_pm.py @@ -22,10 +22,12 @@ from jaxpm.distributed import fft3d, ifft3d from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402 from jaxpm.pm import lpt, make_diffrax_ode, pm_forces # noqa : E402 -_TOLERANCE = 1e-6 # 🎉🎉🎉 +_TOLERANCE = 1e-12 # 🎉🎉🎉 pdims = [(1, 8), (8, 1), (4, 2), (2, 4)] +jax.config.update("jax_enable_x64", True) # Use double precision for accuracy + @pytest.mark.distributed @pytest.mark.parametrize("order", [1, 2]) @@ -39,9 +41,11 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, painting_str = "absolute" if absolute_painting else "relative" print("=" * 50) - print(f"Running with {painting_str} painting and pdims {pdims} ...") mesh_shape, box_shape = simulation_config + print( + f"Running with {painting_str} painting and pdims {pdims} and order {order} and mesh shape {mesh_shape}..." + ) # SINGLE DEVICE RUN cosmo._workspace = {} if absolute_painting: