add assert to make pyright happy

This commit is contained in:
Wassim KABALAN 2024-08-02 21:21:59 +02:00
parent 75604d2396
commit ccbfee3615

View file

@ -19,11 +19,12 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
Computes gravitational forces on particles using a PM scheme
"""
if mesh_shape is None:
assert(delta is not None) , "If mesh_shape is not provided, delta should be provided"
mesh_shape = delta.shape
kvec = fftk(mesh_shape)
if delta is None:
delta_k = fft3d(cic_paint_dx(positions, halo_size=0))
delta_k = fft3d(cic_paint_dx(positions, halo_size=halo_size))
else:
delta_k = fft3d(delta)
@ -32,7 +33,7 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
r_split=r_split)
# Computes gravitational forces
forces = jnp.stack([
cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k), halo_size=0)
cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k), halo_size=halo_size)
for i in range(3)
],
axis=-1)