From c6a7dd4e4ed141792669718ef4a1aac3320a6cad Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Sun, 8 Jun 2025 15:27:32 +0200 Subject: [PATCH] update tests --- jaxpm/painting_utils.py | 20 +++++++++----------- tests/test_distributed_pm.py | 2 +- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/jaxpm/painting_utils.py b/jaxpm/painting_utils.py index cf68f9d..a27aed5 100644 --- a/jaxpm/painting_utils.py +++ b/jaxpm/painting_utils.py @@ -30,17 +30,15 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset, """Multilinear enmeshing.""" base_indices = jnp.asarray(base_indices) displacements = jnp.asarray(displacements) - with jax.experimental.enable_x64(): - cell_size = jnp.float64( - cell_size) if new_cell_size is not None else jnp.array( - cell_size, dtype=displacements.dtype) - if base_shape is not None: - base_shape = jnp.array(base_shape, dtype=base_indices.dtype) - offset = jnp.float64(offset) - if new_cell_size is not None: - new_cell_size = jnp.float64(new_cell_size) - if new_shape is not None: - new_shape = jnp.array(new_shape, dtype=base_indices.dtype) + cell_size = jnp.array( + cell_size, dtype=displacements.dtype) + if base_shape is not None: + base_shape = jnp.array(base_shape, dtype=base_indices.dtype) + offset = offset.astype(base_indices.dtype) + if new_cell_size is not None: + new_cell_size = jnp.array(new_cell_size, dtype=displacements.dtype) + if new_shape is not None: + new_shape = jnp.array(new_shape, dtype=base_indices.dtype) spatial_dim = base_indices.shape[1] neighbor_offsets = ( diff --git a/tests/test_distributed_pm.py b/tests/test_distributed_pm.py index d6bc2e6..f1d4b7f 100644 --- a/tests/test_distributed_pm.py +++ b/tests/test_distributed_pm.py @@ -22,7 +22,7 @@ from jaxpm.distributed import fft3d, ifft3d from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402 from jaxpm.pm import lpt, make_diffrax_ode, pm_forces # noqa : E402 -_TOLERANCE = 1e-8 # 🎉🎉🎉 +_TOLERANCE = 1e-6 # 🎉🎉🎉 pdims = [(1, 8), (8, 1), (4, 2), (2, 4)]