remove deprecated stuff

This commit is contained in:
Francois Lanusse 2024-10-24 16:36:41 -04:00 committed by Wassim KABALAN
parent 8e8e8964be
commit 0f833f0cb4
8 changed files with 96 additions and 831 deletions

View file

@ -50,6 +50,8 @@ def cic_paint_impl(grid_mesh, displacement, weight=None):
@partial(jax.jit, static_argnums=(2, 3, 4))
def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
positions = positions.reshape((*grid_mesh.shape, 3))
halo_size, halo_extents = get_halo_size(halo_size, sharding)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding)
@ -63,6 +65,8 @@ def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
halo_extents=halo_extents,
halo_periods=(True, True))
grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)
print(f"shape of grid_mesh: {grid_mesh.shape}")
return grid_mesh
@ -97,19 +101,20 @@ def cic_read_impl(mesh, displacement):
@partial(jax.jit, static_argnums=(2, 3))
def cic_read(mesh, displacement, halo_size=0, sharding=None):
def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
mesh = slice_pad(mesh, halo_size, sharding=sharding)
mesh = halo_exchange(mesh,
halo_extents=halo_extents,
halo_periods=(True, True))
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,
halo_periods=(True, True))
gpu_mesh = sharding.mesh if sharding is not None else None
spec = sharding.spec if sharding is not None else P()
displacement = autoshmap(cic_read_impl,
gpu_mesh=gpu_mesh,
in_specs=(spec, spec),
out_specs=spec)(mesh, displacement)
out_specs=spec)(grid_mesh, positions)
print(f"shape of displacement: {displacement.shape}")
return displacement