diff --git a/jaxpm/pm.py b/jaxpm/pm.py index ef4e887..3055058 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -19,11 +19,12 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0): Computes gravitational forces on particles using a PM scheme """ if mesh_shape is None: + assert(delta is not None) , "If mesh_shape is not provided, delta should be provided" mesh_shape = delta.shape kvec = fftk(mesh_shape) - + if delta is None: - delta_k = fft3d(cic_paint_dx(positions, halo_size=0)) + delta_k = fft3d(cic_paint_dx(positions, halo_size=halo_size)) else: delta_k = fft3d(delta) @@ -32,7 +33,7 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0): r_split=r_split) # Computes gravitational forces forces = jnp.stack([ - cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k), halo_size=0) + cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k), halo_size=halo_size) for i in range(3) ], axis=-1)