update code

This commit is contained in:
Wassim Kabalan 2024-12-06 18:56:24 +01:00
parent e0c118a540
commit 21373b89ee
7 changed files with 84 additions and 100 deletions

View file

@ -204,7 +204,7 @@ def cic_paint_dx(displacements,
return grid_mesh
def cic_read_dx_impl(grid_mesh, halo_size):
def cic_read_dx_impl(grid_mesh, disp, halo_size):
halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1]
@ -220,14 +220,15 @@ def cic_read_dx_impl(grid_mesh, halo_size):
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
pmid = pmid.reshape([-1, 3])
disp = disp.reshape([-1, 3])
return gather(pmid, jnp.zeros_like(pmid),
return gather(pmid, disp,
grid_mesh).reshape(original_shape)
@partial(jax.jit, static_argnums=(1, 2))
def cic_read_dx(grid_mesh, halo_size=0, sharding=None):
# return mesh
@partial(jax.jit, static_argnums=(2, 3))
def cic_read_dx(grid_mesh,disp , halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
grid_mesh = halo_exchange(grid_mesh,
@ -238,7 +239,7 @@ def cic_read_dx(grid_mesh, halo_size=0, sharding=None):
displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
gpu_mesh=gpu_mesh,
in_specs=(spec),
out_specs=spec)(grid_mesh)
out_specs=spec)(grid_mesh , disp)
return displacements