diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 5498b3d..426f9f3 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -86,13 +86,13 @@ def get_halo_size(halo_size): halo_x_ext = 0 if pdims[0] == 1 else halo_size // 2 halo_y_ext = 0 if pdims[1] == 1 else halo_size // 2 - return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext, 0)) + return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext)) def halo_exchange(x, halo_extents, halo_periods=(True, True, True)): mesh = mesh_lib.thread_resources.env.physical_mesh - if distributed and not (mesh.empty) and (halo_extents > 0 - or halo_extents > 0): + if distributed and not (mesh.empty) and (halo_extents[0] > 0 + or halo_extents[1] > 0): return jaxdecomp.halo_exchange(x, halo_extents, halo_periods) else: return x diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 597f3aa..c997739 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -50,23 +50,23 @@ def cic_paint_impl(mesh, displacement, weight=None): @partial(jax.jit, static_argnums=(2, )) def cic_paint(mesh, positions, halo_size=0, weight=None): - halo_padding, halo_extents = get_halo_size(halo_size) - mesh = slice_pad(mesh, halo_padding) + halo_size, halo_extents = get_halo_size(halo_size) + mesh = slice_pad(mesh, halo_size) mesh = autoshmap(cic_paint_impl, in_specs=(P('x', 'y'), P('x', 'y'), P()), out_specs=P('x', 'y'))(mesh, positions, weight) mesh = halo_exchange(mesh, - halo_extents=halo_size // 2, - halo_periods=True) - mesh = slice_unpad(mesh, halo_padding) + halo_extents=halo_extents, + halo_periods=(True, True)) + mesh = slice_unpad(mesh, halo_size) return mesh def cic_read_impl(mesh, displacement): """ Paints positions onto mesh - mesh: [nx, ny, nz] - displacement: [nx,ny,nz, 3] - """ + mesh: [nx, ny, nz] + displacement: [nx,ny,nz, 3] + """ # Compute the position of the particles on a regular grid part_shape = displacement.shape positions = jnp.stack(jnp.meshgrid(jnp.arange(part_shape[0]), @@ -95,11 +95,11 @@ def cic_read_impl(mesh, displacement): @partial(jax.jit, static_argnums=(2, )) def cic_read(mesh, displacement, halo_size=0): - halo_padding, halo_extents = get_halo_size(halo_size) - mesh = slice_pad(mesh, halo_padding) + halo_size, halo_extents = get_halo_size(halo_size) + mesh = slice_pad(mesh, halo_size) mesh = halo_exchange(mesh, - halo_extents=halo_size//2, - halo_periods=True) + halo_extents=halo_extents, + halo_periods=(True, True)) displacement = autoshmap(cic_read_impl, in_specs=(P('x', 'y'), P('x', 'y')), out_specs=P('x', 'y'))(mesh, displacement) @@ -159,24 +159,17 @@ def cic_paint_dx_impl(displacements, halo_size): @partial(jax.jit, static_argnums=(1, )) def cic_paint_dx(displacements, halo_size=0): -<<<<<<< HEAD halo_size, halo_extents = get_halo_size(halo_size) mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size), -======= - - halo_padding, halo_extents = get_halo_size(halo_size) - - mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_padding), ->>>>>>> glab/ASKabalan/jaxdecomp_proto in_specs=(P('x', 'y')), out_specs=P('x', 'y'))(displacements) mesh = halo_exchange(mesh, - halo_extents=halo_size//2, - halo_periods=True) - mesh = slice_unpad(mesh, halo_padding) + halo_extents=halo_extents, + halo_periods=(True, True)) + mesh = slice_unpad(mesh, halo_size) return mesh @@ -203,18 +196,12 @@ def cic_read_dx_impl(mesh, halo_size): @partial(jax.jit, static_argnums=(1, )) def cic_read_dx(mesh, halo_size=0): # return mesh - halo_padding, halo_extents = get_halo_size(halo_size) - mesh = slice_pad(mesh, halo_padding) + halo_size, halo_extents = get_halo_size(halo_size) + mesh = slice_pad(mesh, halo_size) mesh = halo_exchange(mesh, -<<<<<<< HEAD halo_extents=halo_extents, - halo_periods=(True, True, True)) + halo_periods=(True, True)) displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size), -======= - halo_extents=halo_size//2, - halo_periods=True) - displacements = autoshmap(partial(cic_read_dx_impl , halo_size=halo_padding), ->>>>>>> glab/ASKabalan/jaxdecomp_proto in_specs=(P('x', 'y')), out_specs=P('x', 'y'))(mesh) @@ -235,3 +222,4 @@ def compensate_cic(field): delta_k = jnp.fft.rfftn(field) delta_k = cic_compensation(kvec) * delta_k return jnp.fft.irfftn(delta_k) +