mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
revert single halo extent change
This commit is contained in:
parent
ab86699c88
commit
afecb13cde
2 changed files with 22 additions and 34 deletions
|
@ -86,13 +86,13 @@ def get_halo_size(halo_size):
|
||||||
|
|
||||||
halo_x_ext = 0 if pdims[0] == 1 else halo_size // 2
|
halo_x_ext = 0 if pdims[0] == 1 else halo_size // 2
|
||||||
halo_y_ext = 0 if pdims[1] == 1 else halo_size // 2
|
halo_y_ext = 0 if pdims[1] == 1 else halo_size // 2
|
||||||
return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext, 0))
|
return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext))
|
||||||
|
|
||||||
|
|
||||||
def halo_exchange(x, halo_extents, halo_periods=(True, True, True)):
|
def halo_exchange(x, halo_extents, halo_periods=(True, True, True)):
|
||||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||||
if distributed and not (mesh.empty) and (halo_extents > 0
|
if distributed and not (mesh.empty) and (halo_extents[0] > 0
|
||||||
or halo_extents > 0):
|
or halo_extents[1] > 0):
|
||||||
return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
|
return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
|
||||||
else:
|
else:
|
||||||
return x
|
return x
|
||||||
|
|
|
@ -50,15 +50,15 @@ def cic_paint_impl(mesh, displacement, weight=None):
|
||||||
@partial(jax.jit, static_argnums=(2, ))
|
@partial(jax.jit, static_argnums=(2, ))
|
||||||
def cic_paint(mesh, positions, halo_size=0, weight=None):
|
def cic_paint(mesh, positions, halo_size=0, weight=None):
|
||||||
|
|
||||||
halo_padding, halo_extents = get_halo_size(halo_size)
|
halo_size, halo_extents = get_halo_size(halo_size)
|
||||||
mesh = slice_pad(mesh, halo_padding)
|
mesh = slice_pad(mesh, halo_size)
|
||||||
mesh = autoshmap(cic_paint_impl,
|
mesh = autoshmap(cic_paint_impl,
|
||||||
in_specs=(P('x', 'y'), P('x', 'y'), P()),
|
in_specs=(P('x', 'y'), P('x', 'y'), P()),
|
||||||
out_specs=P('x', 'y'))(mesh, positions, weight)
|
out_specs=P('x', 'y'))(mesh, positions, weight)
|
||||||
mesh = halo_exchange(mesh,
|
mesh = halo_exchange(mesh,
|
||||||
halo_extents=halo_size // 2,
|
halo_extents=halo_extents,
|
||||||
halo_periods=True)
|
halo_periods=(True, True))
|
||||||
mesh = slice_unpad(mesh, halo_padding)
|
mesh = slice_unpad(mesh, halo_size)
|
||||||
return mesh
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
|
@ -95,11 +95,11 @@ def cic_read_impl(mesh, displacement):
|
||||||
@partial(jax.jit, static_argnums=(2, ))
|
@partial(jax.jit, static_argnums=(2, ))
|
||||||
def cic_read(mesh, displacement, halo_size=0):
|
def cic_read(mesh, displacement, halo_size=0):
|
||||||
|
|
||||||
halo_padding, halo_extents = get_halo_size(halo_size)
|
halo_size, halo_extents = get_halo_size(halo_size)
|
||||||
mesh = slice_pad(mesh, halo_padding)
|
mesh = slice_pad(mesh, halo_size)
|
||||||
mesh = halo_exchange(mesh,
|
mesh = halo_exchange(mesh,
|
||||||
halo_extents=halo_size//2,
|
halo_extents=halo_extents,
|
||||||
halo_periods=True)
|
halo_periods=(True, True))
|
||||||
displacement = autoshmap(cic_read_impl,
|
displacement = autoshmap(cic_read_impl,
|
||||||
in_specs=(P('x', 'y'), P('x', 'y')),
|
in_specs=(P('x', 'y'), P('x', 'y')),
|
||||||
out_specs=P('x', 'y'))(mesh, displacement)
|
out_specs=P('x', 'y'))(mesh, displacement)
|
||||||
|
@ -159,24 +159,17 @@ def cic_paint_dx_impl(displacements, halo_size):
|
||||||
|
|
||||||
@partial(jax.jit, static_argnums=(1, ))
|
@partial(jax.jit, static_argnums=(1, ))
|
||||||
def cic_paint_dx(displacements, halo_size=0):
|
def cic_paint_dx(displacements, halo_size=0):
|
||||||
<<<<<<< HEAD
|
|
||||||
|
|
||||||
halo_size, halo_extents = get_halo_size(halo_size)
|
halo_size, halo_extents = get_halo_size(halo_size)
|
||||||
|
|
||||||
mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size),
|
mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size),
|
||||||
=======
|
|
||||||
|
|
||||||
halo_padding, halo_extents = get_halo_size(halo_size)
|
|
||||||
|
|
||||||
mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_padding),
|
|
||||||
>>>>>>> glab/ASKabalan/jaxdecomp_proto
|
|
||||||
in_specs=(P('x', 'y')),
|
in_specs=(P('x', 'y')),
|
||||||
out_specs=P('x', 'y'))(displacements)
|
out_specs=P('x', 'y'))(displacements)
|
||||||
|
|
||||||
mesh = halo_exchange(mesh,
|
mesh = halo_exchange(mesh,
|
||||||
halo_extents=halo_size//2,
|
halo_extents=halo_extents,
|
||||||
halo_periods=True)
|
halo_periods=(True, True))
|
||||||
mesh = slice_unpad(mesh, halo_padding)
|
mesh = slice_unpad(mesh, halo_size)
|
||||||
return mesh
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
|
@ -203,18 +196,12 @@ def cic_read_dx_impl(mesh, halo_size):
|
||||||
@partial(jax.jit, static_argnums=(1, ))
|
@partial(jax.jit, static_argnums=(1, ))
|
||||||
def cic_read_dx(mesh, halo_size=0):
|
def cic_read_dx(mesh, halo_size=0):
|
||||||
# return mesh
|
# return mesh
|
||||||
halo_padding, halo_extents = get_halo_size(halo_size)
|
halo_size, halo_extents = get_halo_size(halo_size)
|
||||||
mesh = slice_pad(mesh, halo_padding)
|
mesh = slice_pad(mesh, halo_size)
|
||||||
mesh = halo_exchange(mesh,
|
mesh = halo_exchange(mesh,
|
||||||
<<<<<<< HEAD
|
|
||||||
halo_extents=halo_extents,
|
halo_extents=halo_extents,
|
||||||
halo_periods=(True, True, True))
|
halo_periods=(True, True))
|
||||||
displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
|
displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
|
||||||
=======
|
|
||||||
halo_extents=halo_size//2,
|
|
||||||
halo_periods=True)
|
|
||||||
displacements = autoshmap(partial(cic_read_dx_impl , halo_size=halo_padding),
|
|
||||||
>>>>>>> glab/ASKabalan/jaxdecomp_proto
|
|
||||||
in_specs=(P('x', 'y')),
|
in_specs=(P('x', 'y')),
|
||||||
out_specs=P('x', 'y'))(mesh)
|
out_specs=P('x', 'y'))(mesh)
|
||||||
|
|
||||||
|
@ -235,3 +222,4 @@ def compensate_cic(field):
|
||||||
delta_k = jnp.fft.rfftn(field)
|
delta_k = jnp.fft.rfftn(field)
|
||||||
delta_k = cic_compensation(kvec) * delta_k
|
delta_k = cic_compensation(kvec) * delta_k
|
||||||
return jnp.fft.irfftn(delta_k)
|
return jnp.fft.irfftn(delta_k)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue