This commit is contained in:
Wassim Kabalan 2025-02-28 13:47:43 +01:00
parent 580387ce1c
commit 9f494da317
3 changed files with 62 additions and 75 deletions

View file

@ -30,8 +30,7 @@ def _cic_paint_impl(grid_mesh, positions, weight=1.):
if jnp.isscalar(weight):
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
else:
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
kernel)
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), kernel)
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
@ -158,7 +157,10 @@ def cic_paint_2d(mesh, positions, weight):
return mesh
def _cic_paint_dx_impl(displacements, weight=1. , halo_size=0 , chunk_size=2**24):
def _cic_paint_dx_impl(displacements,
weight=1.,
halo_size=0,
chunk_size=2**24):
halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1]
@ -203,7 +205,7 @@ def cic_paint_dx(displacements,
chunk_size=chunk_size),
gpu_mesh=gpu_mesh,
in_specs=(spec, weight_spec),
out_specs=spec)(displacements , weight)
out_specs=spec)(displacements, weight)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,