diff --git a/jaxpm/painting_utils.py b/jaxpm/painting_utils.py index a27aed5..bc9cf4b 100644 --- a/jaxpm/painting_utils.py +++ b/jaxpm/painting_utils.py @@ -30,8 +30,7 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset, """Multilinear enmeshing.""" base_indices = jnp.asarray(base_indices) displacements = jnp.asarray(displacements) - cell_size = jnp.array( - cell_size, dtype=displacements.dtype) + cell_size = jnp.array(cell_size, dtype=displacements.dtype) if base_shape is not None: base_shape = jnp.array(base_shape, dtype=base_indices.dtype) offset = offset.astype(base_indices.dtype)