mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
update code
This commit is contained in:
parent
e0c118a540
commit
21373b89ee
7 changed files with 84 additions and 100 deletions
|
@ -25,72 +25,71 @@ def _chunk_split(ptcl_num, chunk_size, *arrays):
|
|||
return remainder, chunks
|
||||
|
||||
|
||||
def enmesh(i1, d1, a1, s1, b12, a2, s2):
|
||||
def enmesh(base_indices, displacements, cell_size, base_shape, offset, new_cell_size, new_shape):
|
||||
"""Multilinear enmeshing."""
|
||||
i1 = jnp.asarray(i1)
|
||||
d1 = jnp.asarray(d1)
|
||||
base_indices = jnp.asarray(base_indices)
|
||||
displacements = jnp.asarray(displacements)
|
||||
with jax.experimental.enable_x64():
|
||||
a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1,
|
||||
dtype=d1.dtype)
|
||||
if s1 is not None:
|
||||
s1 = jnp.array(s1, dtype=i1.dtype)
|
||||
b12 = jnp.float64(b12)
|
||||
if a2 is not None:
|
||||
a2 = jnp.float64(a2)
|
||||
if s2 is not None:
|
||||
s2 = jnp.array(s2, dtype=i1.dtype)
|
||||
cell_size = jnp.float64(cell_size) if new_cell_size is not None else jnp.array(cell_size, dtype=displacements.dtype)
|
||||
if base_shape is not None:
|
||||
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
|
||||
offset = jnp.float64(offset)
|
||||
if new_cell_size is not None:
|
||||
new_cell_size = jnp.float64(new_cell_size)
|
||||
if new_shape is not None:
|
||||
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
|
||||
|
||||
dim = i1.shape[1]
|
||||
neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] >>
|
||||
jnp.arange(dim, dtype=i1.dtype)) & 1
|
||||
spatial_dim = base_indices.shape[1]
|
||||
neighbor_offsets = (jnp.arange(2**spatial_dim, dtype=base_indices.dtype)[:, jnp.newaxis] >>
|
||||
jnp.arange(spatial_dim, dtype=base_indices.dtype)) & 1
|
||||
|
||||
if a2 is not None:
|
||||
P = i1 * a1 + d1 - b12
|
||||
P = P[:, jnp.newaxis] # insert neighbor axis
|
||||
i2 = P + neighbors * a2 # multilinear
|
||||
if new_cell_size is not None:
|
||||
particle_positions = base_indices * cell_size + displacements - offset
|
||||
particle_positions = particle_positions[:, jnp.newaxis] # insert neighbor axis
|
||||
new_indices = particle_positions + neighbor_offsets * new_cell_size # multilinear
|
||||
|
||||
if s1 is not None:
|
||||
L = s1 * a1
|
||||
i2 %= L
|
||||
if base_shape is not None:
|
||||
grid_length = base_shape * cell_size
|
||||
new_indices %= grid_length
|
||||
|
||||
i2 //= a2
|
||||
d2 = P - i2 * a2
|
||||
new_indices //= new_cell_size
|
||||
new_displacements = particle_positions - new_indices * new_cell_size
|
||||
|
||||
if s1 is not None:
|
||||
d2 -= jnp.rint(d2 / L) * L # also abs(d2) < a2 is expected
|
||||
if base_shape is not None:
|
||||
new_displacements -= jnp.rint(new_displacements / grid_length) * grid_length # also abs(new_displacements) < new_cell_size is expected
|
||||
|
||||
i2 = i2.astype(i1.dtype)
|
||||
d2 = d2.astype(d1.dtype)
|
||||
a2 = a2.astype(d1.dtype)
|
||||
new_indices = new_indices.astype(base_indices.dtype)
|
||||
new_displacements = new_displacements.astype(displacements.dtype)
|
||||
new_cell_size = new_cell_size.astype(displacements.dtype)
|
||||
|
||||
d2 /= a2
|
||||
new_displacements /= new_cell_size
|
||||
else:
|
||||
i12, d12 = jnp.divmod(b12, a1)
|
||||
i1 -= i12.astype(i1.dtype)
|
||||
d1 -= d12.astype(d1.dtype)
|
||||
offset_indices, offset_displacements = jnp.divmod(offset, cell_size)
|
||||
base_indices -= offset_indices.astype(base_indices.dtype)
|
||||
displacements -= offset_displacements.astype(displacements.dtype)
|
||||
|
||||
# insert neighbor axis
|
||||
i1 = i1[:, jnp.newaxis]
|
||||
d1 = d1[:, jnp.newaxis]
|
||||
base_indices = base_indices[:, jnp.newaxis]
|
||||
displacements = displacements[:, jnp.newaxis]
|
||||
|
||||
# multilinear
|
||||
d1 /= a1
|
||||
i2 = jnp.floor(d1).astype(i1.dtype)
|
||||
i2 += neighbors
|
||||
d2 = d1 - i2
|
||||
i2 += i1
|
||||
displacements /= cell_size
|
||||
new_indices = jnp.floor(displacements).astype(base_indices.dtype)
|
||||
new_indices += neighbor_offsets
|
||||
new_displacements = displacements - new_indices
|
||||
new_indices += base_indices
|
||||
|
||||
if s1 is not None:
|
||||
i2 %= s1
|
||||
if base_shape is not None:
|
||||
new_indices %= base_shape
|
||||
|
||||
f2 = 1 - jnp.abs(d2)
|
||||
weights = 1 - jnp.abs(new_displacements)
|
||||
|
||||
if s1 is None and s2 is not None: # all i2 >= 0 if s1 is not None
|
||||
i2 = jnp.where(i2 < 0, s2, i2)
|
||||
if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None
|
||||
new_indices = jnp.where(new_indices < 0, new_shape, new_indices)
|
||||
|
||||
f2 = f2.prod(axis=-1)
|
||||
weights = weights.prod(axis=-1)
|
||||
|
||||
return i2, f2
|
||||
return new_indices, weights
|
||||
|
||||
|
||||
def _scatter_chunk(carry, chunk):
|
||||
|
@ -138,7 +137,7 @@ def _chunk_cat(remainder_array, chunked_array):
|
|||
return array
|
||||
|
||||
|
||||
def gather(pmid, disp, mesh, chunk_size=2**24, val=1, offset=0, cell_size=1.):
|
||||
def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
|
||||
ptcl_num, spatial_ndim = pmid.shape
|
||||
|
||||
mesh = jnp.asarray(mesh)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue