update code

This commit is contained in:
Wassim Kabalan 2024-12-06 18:56:24 +01:00
parent e0c118a540
commit 21373b89ee
7 changed files with 84 additions and 100 deletions

View file

@ -25,72 +25,71 @@ def _chunk_split(ptcl_num, chunk_size, *arrays):
return remainder, chunks
def enmesh(i1, d1, a1, s1, b12, a2, s2):
def enmesh(base_indices, displacements, cell_size, base_shape, offset, new_cell_size, new_shape):
"""Multilinear enmeshing."""
i1 = jnp.asarray(i1)
d1 = jnp.asarray(d1)
base_indices = jnp.asarray(base_indices)
displacements = jnp.asarray(displacements)
with jax.experimental.enable_x64():
a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1,
dtype=d1.dtype)
if s1 is not None:
s1 = jnp.array(s1, dtype=i1.dtype)
b12 = jnp.float64(b12)
if a2 is not None:
a2 = jnp.float64(a2)
if s2 is not None:
s2 = jnp.array(s2, dtype=i1.dtype)
cell_size = jnp.float64(cell_size) if new_cell_size is not None else jnp.array(cell_size, dtype=displacements.dtype)
if base_shape is not None:
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
offset = jnp.float64(offset)
if new_cell_size is not None:
new_cell_size = jnp.float64(new_cell_size)
if new_shape is not None:
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
dim = i1.shape[1]
neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] >>
jnp.arange(dim, dtype=i1.dtype)) & 1
spatial_dim = base_indices.shape[1]
neighbor_offsets = (jnp.arange(2**spatial_dim, dtype=base_indices.dtype)[:, jnp.newaxis] >>
jnp.arange(spatial_dim, dtype=base_indices.dtype)) & 1
if a2 is not None:
P = i1 * a1 + d1 - b12
P = P[:, jnp.newaxis] # insert neighbor axis
i2 = P + neighbors * a2 # multilinear
if new_cell_size is not None:
particle_positions = base_indices * cell_size + displacements - offset
particle_positions = particle_positions[:, jnp.newaxis] # insert neighbor axis
new_indices = particle_positions + neighbor_offsets * new_cell_size # multilinear
if s1 is not None:
L = s1 * a1
i2 %= L
if base_shape is not None:
grid_length = base_shape * cell_size
new_indices %= grid_length
i2 //= a2
d2 = P - i2 * a2
new_indices //= new_cell_size
new_displacements = particle_positions - new_indices * new_cell_size
if s1 is not None:
d2 -= jnp.rint(d2 / L) * L # also abs(d2) < a2 is expected
if base_shape is not None:
new_displacements -= jnp.rint(new_displacements / grid_length) * grid_length # also abs(new_displacements) < new_cell_size is expected
i2 = i2.astype(i1.dtype)
d2 = d2.astype(d1.dtype)
a2 = a2.astype(d1.dtype)
new_indices = new_indices.astype(base_indices.dtype)
new_displacements = new_displacements.astype(displacements.dtype)
new_cell_size = new_cell_size.astype(displacements.dtype)
d2 /= a2
new_displacements /= new_cell_size
else:
i12, d12 = jnp.divmod(b12, a1)
i1 -= i12.astype(i1.dtype)
d1 -= d12.astype(d1.dtype)
offset_indices, offset_displacements = jnp.divmod(offset, cell_size)
base_indices -= offset_indices.astype(base_indices.dtype)
displacements -= offset_displacements.astype(displacements.dtype)
# insert neighbor axis
i1 = i1[:, jnp.newaxis]
d1 = d1[:, jnp.newaxis]
base_indices = base_indices[:, jnp.newaxis]
displacements = displacements[:, jnp.newaxis]
# multilinear
d1 /= a1
i2 = jnp.floor(d1).astype(i1.dtype)
i2 += neighbors
d2 = d1 - i2
i2 += i1
displacements /= cell_size
new_indices = jnp.floor(displacements).astype(base_indices.dtype)
new_indices += neighbor_offsets
new_displacements = displacements - new_indices
new_indices += base_indices
if s1 is not None:
i2 %= s1
if base_shape is not None:
new_indices %= base_shape
f2 = 1 - jnp.abs(d2)
weights = 1 - jnp.abs(new_displacements)
if s1 is None and s2 is not None: # all i2 >= 0 if s1 is not None
i2 = jnp.where(i2 < 0, s2, i2)
if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None
new_indices = jnp.where(new_indices < 0, new_shape, new_indices)
f2 = f2.prod(axis=-1)
weights = weights.prod(axis=-1)
return i2, f2
return new_indices, weights
def _scatter_chunk(carry, chunk):
@ -138,7 +137,7 @@ def _chunk_cat(remainder_array, chunked_array):
return array
def gather(pmid, disp, mesh, chunk_size=2**24, val=1, offset=0, cell_size=1.):
def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
ptcl_num, spatial_ndim = pmid.shape
mesh = jnp.asarray(mesh)