mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
use float64 for enmeshing
This commit is contained in:
parent
2ad035a75b
commit
b3a264ad53
1 changed files with 10 additions and 8 deletions
|
@ -29,14 +29,16 @@ def enmesh(i1, d1, a1, s1, b12, a2, s2):
|
|||
"""Multilinear enmeshing."""
|
||||
i1 = jnp.asarray(i1)
|
||||
d1 = jnp.asarray(d1)
|
||||
a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1, dtype=d1.dtype)
|
||||
if s1 is not None:
|
||||
s1 = jnp.array(s1, dtype=i1.dtype)
|
||||
b12 = jnp.float64(b12)
|
||||
if a2 is not None:
|
||||
a2 = jnp.float64(a2)
|
||||
if s2 is not None:
|
||||
s2 = jnp.array(s2, dtype=i1.dtype)
|
||||
with jax.experimental.enable_x64():
|
||||
a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1,
|
||||
dtype=d1.dtype)
|
||||
if s1 is not None:
|
||||
s1 = jnp.array(s1, dtype=i1.dtype)
|
||||
b12 = jnp.float64(b12)
|
||||
if a2 is not None:
|
||||
a2 = jnp.float64(a2)
|
||||
if s2 is not None:
|
||||
s2 = jnp.array(s2, dtype=i1.dtype)
|
||||
|
||||
dim = i1.shape[1]
|
||||
neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] >>
|
||||
|
|
Loading…
Add table
Reference in a new issue