diff --git a/jaxpm/distributed_ops.py b/jaxpm/distributed_ops.py index 9adcfc9..6417f35 100644 --- a/jaxpm/distributed_ops.py +++ b/jaxpm/distributed_ops.py @@ -147,24 +147,24 @@ def cic_paint(pos, mesh_shape, halo_size=0): # Perform halo exchange # Halo exchange along x - left = lax.pshuffle(mesh[-halo_size:], + left = lax.pshuffle(mesh[-2*halo_size:], perm=range(mesh_size['nx'])[::-1], axis_name='x') - right = lax.pshuffle(mesh[:halo_size], + right = lax.pshuffle(mesh[:2*halo_size], perm=range(mesh_size['nx'])[::-1], axis_name='x') - mesh = mesh.at[:halo_size].add(left) - mesh = mesh.at[-halo_size:].add(right) + mesh = mesh.at[:2*halo_size].add(left) + mesh = mesh.at[-2*halo_size:].add(right) # Halo exchange along y - left = lax.pshuffle(mesh[:, -halo_size:], + left = lax.pshuffle(mesh[:, -2*halo_size:], perm=range(mesh_size['ny'])[::-1], axis_name='y') - right = lax.pshuffle(mesh[:, :halo_size], + right = lax.pshuffle(mesh[:, :2*halo_size], perm=range(mesh_size['ny'])[::-1], axis_name='y') - mesh = mesh.at[:, :halo_size].add(left) - mesh = mesh.at[:, -halo_size:].add(right) + mesh = mesh.at[:, :2*halo_size].add(left) + mesh = mesh.at[:, -2*halo_size:].add(right) # removing halo and returning mesh return mesh[halo_size:-halo_size, halo_size:-halo_size] diff --git a/jaxpm/distributed_pm.py b/jaxpm/distributed_pm.py index 7944fd2..69ea2c0 100644 --- a/jaxpm/distributed_pm.py +++ b/jaxpm/distributed_pm.py @@ -64,7 +64,6 @@ def lpt(cosmo, initial_conditions, positions, a): Computes first order LPT displacement """ initial_force = pm_forces(positions, delta_k=initial_conditions) - print(initial_force.shape) a = jnp.atleast_1d(a) dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a)) p = dops.scalar_multiply(dx, a**2 * growth_rate(cosmo, a) *