mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-13 11:31:11 +00:00
fix 2D painting when input is (X , Y , 2) shape
This commit is contained in:
parent
20ace41d32
commit
d4049e5db4
1 changed files with 2 additions and 1 deletions
|
@ -134,6 +134,7 @@ def cic_paint_2d(mesh, positions, weight):
|
||||||
positions: [npart, 2]
|
positions: [npart, 2]
|
||||||
weight: [npart]
|
weight: [npart]
|
||||||
"""
|
"""
|
||||||
|
positions = positions.reshape([-1, 2])
|
||||||
positions = jnp.expand_dims(positions, 1)
|
positions = jnp.expand_dims(positions, 1)
|
||||||
floor = jnp.floor(positions)
|
floor = jnp.floor(positions)
|
||||||
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
|
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
|
||||||
|
@ -142,7 +143,7 @@ def cic_paint_2d(mesh, positions, weight):
|
||||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
kernel = kernel[..., 0] * kernel[..., 1]
|
kernel = kernel[..., 0] * kernel[..., 1]
|
||||||
if weight is not None:
|
if weight is not None:
|
||||||
kernel = kernel * weight[..., jnp.newaxis]
|
kernel = kernel * weight.reshape(*positions.shape[:-1])
|
||||||
|
|
||||||
neighboor_coords = jnp.mod(
|
neighboor_coords = jnp.mod(
|
||||||
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
|
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue