mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 12:20:54 +00:00
add assert to make pyright happy
This commit is contained in:
parent
75604d2396
commit
ccbfee3615
1 changed files with 4 additions and 3 deletions
|
@ -19,11 +19,12 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
|
|||
Computes gravitational forces on particles using a PM scheme
|
||||
"""
|
||||
if mesh_shape is None:
|
||||
assert(delta is not None) , "If mesh_shape is not provided, delta should be provided"
|
||||
mesh_shape = delta.shape
|
||||
kvec = fftk(mesh_shape)
|
||||
|
||||
|
||||
if delta is None:
|
||||
delta_k = fft3d(cic_paint_dx(positions, halo_size=0))
|
||||
delta_k = fft3d(cic_paint_dx(positions, halo_size=halo_size))
|
||||
else:
|
||||
delta_k = fft3d(delta)
|
||||
|
||||
|
@ -32,7 +33,7 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
|
|||
r_split=r_split)
|
||||
# Computes gravitational forces
|
||||
forces = jnp.stack([
|
||||
cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k), halo_size=0)
|
||||
cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k), halo_size=halo_size)
|
||||
for i in range(3)
|
||||
],
|
||||
axis=-1)
|
||||
|
|
Loading…
Add table
Reference in a new issue