update tests

This commit is contained in:
Wassim Kabalan 2025-06-08 15:27:32 +02:00
parent e0ba85fb58
commit c6a7dd4e4e
2 changed files with 10 additions and 12 deletions

View file

@ -30,17 +30,15 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
"""Multilinear enmeshing."""
base_indices = jnp.asarray(base_indices)
displacements = jnp.asarray(displacements)
with jax.experimental.enable_x64():
cell_size = jnp.float64(
cell_size) if new_cell_size is not None else jnp.array(
cell_size, dtype=displacements.dtype)
if base_shape is not None:
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
offset = jnp.float64(offset)
if new_cell_size is not None:
new_cell_size = jnp.float64(new_cell_size)
if new_shape is not None:
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
cell_size = jnp.array(
cell_size, dtype=displacements.dtype)
if base_shape is not None:
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
offset = offset.astype(base_indices.dtype)
if new_cell_size is not None:
new_cell_size = jnp.array(new_cell_size, dtype=displacements.dtype)
if new_shape is not None:
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
spatial_dim = base_indices.shape[1]
neighbor_offsets = (

View file

@ -22,7 +22,7 @@ from jaxpm.distributed import fft3d, ifft3d
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode, pm_forces # noqa : E402
_TOLERANCE = 1e-8 # 🎉🎉🎉
_TOLERANCE = 1e-6 # 🎉🎉🎉
pdims = [(1, 8), (8, 1), (4, 2), (2, 4)]