mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 08:31:11 +00:00
Update tolerance and precision settings in distributed PM tests
This commit is contained in:
parent
e666aada42
commit
2d21985279
1 changed files with 6 additions and 2 deletions
|
@ -22,10 +22,12 @@ from jaxpm.distributed import fft3d, ifft3d
|
||||||
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
|
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
|
||||||
from jaxpm.pm import lpt, make_diffrax_ode, pm_forces # noqa : E402
|
from jaxpm.pm import lpt, make_diffrax_ode, pm_forces # noqa : E402
|
||||||
|
|
||||||
_TOLERANCE = 1e-6 # 🎉🎉🎉
|
_TOLERANCE = 1e-12 # 🎉🎉🎉
|
||||||
|
|
||||||
pdims = [(1, 8), (8, 1), (4, 2), (2, 4)]
|
pdims = [(1, 8), (8, 1), (4, 2), (2, 4)]
|
||||||
|
|
||||||
|
jax.config.update("jax_enable_x64", True) # Use double precision for accuracy
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.distributed
|
@pytest.mark.distributed
|
||||||
@pytest.mark.parametrize("order", [1, 2])
|
@pytest.mark.parametrize("order", [1, 2])
|
||||||
|
@ -39,9 +41,11 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
||||||
|
|
||||||
painting_str = "absolute" if absolute_painting else "relative"
|
painting_str = "absolute" if absolute_painting else "relative"
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
print(f"Running with {painting_str} painting and pdims {pdims} ...")
|
|
||||||
|
|
||||||
mesh_shape, box_shape = simulation_config
|
mesh_shape, box_shape = simulation_config
|
||||||
|
print(
|
||||||
|
f"Running with {painting_str} painting and pdims {pdims} and order {order} and mesh shape {mesh_shape}..."
|
||||||
|
)
|
||||||
# SINGLE DEVICE RUN
|
# SINGLE DEVICE RUN
|
||||||
cosmo._workspace = {}
|
cosmo._workspace = {}
|
||||||
if absolute_painting:
|
if absolute_painting:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue