mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-15 10:21:11 +00:00
update tests
This commit is contained in:
parent
e0ba85fb58
commit
c6a7dd4e4e
2 changed files with 10 additions and 12 deletions
|
@ -30,17 +30,15 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
|
|||
"""Multilinear enmeshing."""
|
||||
base_indices = jnp.asarray(base_indices)
|
||||
displacements = jnp.asarray(displacements)
|
||||
with jax.experimental.enable_x64():
|
||||
cell_size = jnp.float64(
|
||||
cell_size) if new_cell_size is not None else jnp.array(
|
||||
cell_size, dtype=displacements.dtype)
|
||||
if base_shape is not None:
|
||||
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
|
||||
offset = jnp.float64(offset)
|
||||
if new_cell_size is not None:
|
||||
new_cell_size = jnp.float64(new_cell_size)
|
||||
if new_shape is not None:
|
||||
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
|
||||
cell_size = jnp.array(
|
||||
cell_size, dtype=displacements.dtype)
|
||||
if base_shape is not None:
|
||||
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
|
||||
offset = offset.astype(base_indices.dtype)
|
||||
if new_cell_size is not None:
|
||||
new_cell_size = jnp.array(new_cell_size, dtype=displacements.dtype)
|
||||
if new_shape is not None:
|
||||
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
|
||||
|
||||
spatial_dim = base_indices.shape[1]
|
||||
neighbor_offsets = (
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue