This commit is contained in:
Wassim Kabalan 2024-12-08 22:45:09 +01:00
parent 7823fdaf98
commit af29c4005d
7 changed files with 68 additions and 63 deletions

View file

@ -222,12 +222,11 @@ def cic_read_dx_impl(grid_mesh, disp, halo_size):
pmid = pmid.reshape([-1, 3])
disp = disp.reshape([-1, 3])
return gather(pmid, disp,
grid_mesh).reshape(original_shape)
return gather(pmid, disp, grid_mesh).reshape(original_shape)
@partial(jax.jit, static_argnums=(2, 3))
def cic_read_dx(grid_mesh,disp , halo_size=0, sharding=None):
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
@ -239,7 +238,7 @@ def cic_read_dx(grid_mesh,disp , halo_size=0, sharding=None):
displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
gpu_mesh=gpu_mesh,
in_specs=(spec),
out_specs=spec)(grid_mesh , disp)
out_specs=spec)(grid_mesh, disp)
return displacements