mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-15 04:21:12 +00:00
merge with JZ
This commit is contained in:
commit
ab86699c88
7 changed files with 226 additions and 131 deletions
|
@ -50,15 +50,15 @@ def cic_paint_impl(mesh, displacement, weight=None):
|
|||
@partial(jax.jit, static_argnums=(2, ))
|
||||
def cic_paint(mesh, positions, halo_size=0, weight=None):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size)
|
||||
mesh = slice_pad(mesh, halo_size)
|
||||
halo_padding, halo_extents = get_halo_size(halo_size)
|
||||
mesh = slice_pad(mesh, halo_padding)
|
||||
mesh = autoshmap(cic_paint_impl,
|
||||
in_specs=(P('x', 'y'), P('x', 'y'), P()),
|
||||
out_specs=P('x', 'y'))(mesh, positions, weight)
|
||||
mesh = halo_exchange(mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True, True))
|
||||
mesh = slice_unpad(mesh, halo_size)
|
||||
halo_extents=halo_size // 2,
|
||||
halo_periods=True)
|
||||
mesh = slice_unpad(mesh, halo_padding)
|
||||
return mesh
|
||||
|
||||
|
||||
|
@ -95,11 +95,11 @@ def cic_read_impl(mesh, displacement):
|
|||
@partial(jax.jit, static_argnums=(2, ))
|
||||
def cic_read(mesh, displacement, halo_size=0):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size)
|
||||
mesh = slice_pad(mesh, halo_size)
|
||||
halo_padding, halo_extents = get_halo_size(halo_size)
|
||||
mesh = slice_pad(mesh, halo_padding)
|
||||
mesh = halo_exchange(mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True, True))
|
||||
halo_extents=halo_size//2,
|
||||
halo_periods=True)
|
||||
displacement = autoshmap(cic_read_impl,
|
||||
in_specs=(P('x', 'y'), P('x', 'y')),
|
||||
out_specs=P('x', 'y'))(mesh, displacement)
|
||||
|
@ -159,17 +159,24 @@ def cic_paint_dx_impl(displacements, halo_size):
|
|||
|
||||
@partial(jax.jit, static_argnums=(1, ))
|
||||
def cic_paint_dx(displacements, halo_size=0):
|
||||
<<<<<<< HEAD
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size)
|
||||
|
||||
mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size),
|
||||
=======
|
||||
|
||||
halo_padding, halo_extents = get_halo_size(halo_size)
|
||||
|
||||
mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_padding),
|
||||
>>>>>>> glab/ASKabalan/jaxdecomp_proto
|
||||
in_specs=(P('x', 'y')),
|
||||
out_specs=P('x', 'y'))(displacements)
|
||||
|
||||
mesh = halo_exchange(mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True, True))
|
||||
mesh = slice_unpad(mesh, halo_size)
|
||||
halo_extents=halo_size//2,
|
||||
halo_periods=True)
|
||||
mesh = slice_unpad(mesh, halo_padding)
|
||||
return mesh
|
||||
|
||||
|
||||
|
@ -196,12 +203,18 @@ def cic_read_dx_impl(mesh, halo_size):
|
|||
@partial(jax.jit, static_argnums=(1, ))
|
||||
def cic_read_dx(mesh, halo_size=0):
|
||||
# return mesh
|
||||
halo_size, halo_extents = get_halo_size(halo_size)
|
||||
mesh = slice_pad(mesh, halo_size)
|
||||
halo_padding, halo_extents = get_halo_size(halo_size)
|
||||
mesh = slice_pad(mesh, halo_padding)
|
||||
mesh = halo_exchange(mesh,
|
||||
<<<<<<< HEAD
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True, True))
|
||||
displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
|
||||
=======
|
||||
halo_extents=halo_size//2,
|
||||
halo_periods=True)
|
||||
displacements = autoshmap(partial(cic_read_dx_impl , halo_size=halo_padding),
|
||||
>>>>>>> glab/ASKabalan/jaxdecomp_proto
|
||||
in_specs=(P('x', 'y')),
|
||||
out_specs=P('x', 'y'))(mesh)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue