Update for jaxDecomp pure JAX

This commit is contained in:
Wassim KABALAN 2024-08-07 23:52:13 +02:00
parent 831291c1f9
commit 2ea05a1cd6
9 changed files with 214 additions and 532 deletions

View file

@ -50,15 +50,15 @@ def cic_paint_impl(mesh, displacement, weight=None):
@partial(jax.jit, static_argnums=(2, ))
def cic_paint(mesh, positions, halo_size=0, weight=None):
halo_size, halo_extents = get_halo_size(halo_size)
mesh = slice_pad(mesh, halo_size)
halo_padding, halo_extents = get_halo_size(halo_size)
mesh = slice_pad(mesh, halo_padding)
mesh = autoshmap(cic_paint_impl,
in_specs=(P('x', 'y'), P('x', 'y'), P()),
out_specs=P('x', 'y'))(mesh, positions, weight)
mesh = halo_exchange(mesh,
halo_extents=halo_extents,
halo_periods=(True, True, True))
mesh = slice_unpad(mesh, halo_size)
halo_extents=halo_size // 2,
halo_periods=True)
mesh = slice_unpad(mesh, halo_padding)
return mesh
@ -95,11 +95,11 @@ def cic_read_impl(mesh, displacement):
@partial(jax.jit, static_argnums=(2, ))
def cic_read(mesh, displacement, halo_size=0):
halo_size, halo_extents = get_halo_size(halo_size)
mesh = slice_pad(mesh, halo_size)
halo_padding, halo_extents = get_halo_size(halo_size)
mesh = slice_pad(mesh, halo_padding)
mesh = halo_exchange(mesh,
halo_extents=halo_extents,
halo_periods=(True, True, True))
halo_extents=halo_size//2,
halo_periods=True)
displacement = autoshmap(cic_read_impl,
in_specs=(P('x', 'y'), P('x', 'y')),
out_specs=P('x', 'y'))(mesh, displacement)
@ -160,16 +160,16 @@ def cic_paint_dx_impl(displacements, halo_size):
@partial(jax.jit, static_argnums=(1, ))
def cic_paint_dx(displacements, halo_size=0):
halo_size, halo_extents = get_halo_size(halo_size)
halo_padding, halo_extents = get_halo_size(halo_size)
mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size),
mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_padding),
in_specs=(P('x', 'y')),
out_specs=P('x', 'y'))(displacements)
mesh = halo_exchange(mesh,
halo_extents=halo_extents,
halo_periods=(True, True, True))
mesh = slice_unpad(mesh, halo_size)
halo_extents=halo_size//2,
halo_periods=True)
mesh = slice_unpad(mesh, halo_padding)
return mesh
@ -194,12 +194,12 @@ def cic_read_dx_impl(mesh , halo_size):
@partial(jax.jit, static_argnums=(1, ))
def cic_read_dx(mesh, halo_size=0):
# return mesh
halo_size, halo_extents = get_halo_size(halo_size)
mesh = slice_pad(mesh, halo_size)
halo_padding, halo_extents = get_halo_size(halo_size)
mesh = slice_pad(mesh, halo_padding)
mesh = halo_exchange(mesh,
halo_extents=halo_extents,
halo_periods=(True, True, True))
displacements = autoshmap(partial(cic_read_dx_impl , halo_size=halo_size),
halo_extents=halo_size//2,
halo_periods=True)
displacements = autoshmap(partial(cic_read_dx_impl , halo_size=halo_padding),
in_specs=(P('x', 'y')),
out_specs=P('x', 'y'))(mesh)