From b3a264ad53703bad36318aa42a28a46cfd1d3e0d Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Wed, 30 Oct 2024 01:56:29 +0100 Subject: [PATCH] use float64 for enmeshing --- jaxpm/painting_utils.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/jaxpm/painting_utils.py b/jaxpm/painting_utils.py index 1d929ea..a0319a5 100644 --- a/jaxpm/painting_utils.py +++ b/jaxpm/painting_utils.py @@ -29,14 +29,16 @@ def enmesh(i1, d1, a1, s1, b12, a2, s2): """Multilinear enmeshing.""" i1 = jnp.asarray(i1) d1 = jnp.asarray(d1) - a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1, dtype=d1.dtype) - if s1 is not None: - s1 = jnp.array(s1, dtype=i1.dtype) - b12 = jnp.float64(b12) - if a2 is not None: - a2 = jnp.float64(a2) - if s2 is not None: - s2 = jnp.array(s2, dtype=i1.dtype) + with jax.experimental.enable_x64(): + a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1, + dtype=d1.dtype) + if s1 is not None: + s1 = jnp.array(s1, dtype=i1.dtype) + b12 = jnp.float64(b12) + if a2 is not None: + a2 = jnp.float64(a2) + if s2 is not None: + s2 = jnp.array(s2, dtype=i1.dtype) dim = i1.shape[1] neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] >>