mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
add assert to make pyright happy
This commit is contained in:
parent
75604d2396
commit
ccbfee3615
1 changed files with 4 additions and 3 deletions
|
@ -19,11 +19,12 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
|
||||||
Computes gravitational forces on particles using a PM scheme
|
Computes gravitational forces on particles using a PM scheme
|
||||||
"""
|
"""
|
||||||
if mesh_shape is None:
|
if mesh_shape is None:
|
||||||
|
assert(delta is not None) , "If mesh_shape is not provided, delta should be provided"
|
||||||
mesh_shape = delta.shape
|
mesh_shape = delta.shape
|
||||||
kvec = fftk(mesh_shape)
|
kvec = fftk(mesh_shape)
|
||||||
|
|
||||||
if delta is None:
|
if delta is None:
|
||||||
delta_k = fft3d(cic_paint_dx(positions, halo_size=0))
|
delta_k = fft3d(cic_paint_dx(positions, halo_size=halo_size))
|
||||||
else:
|
else:
|
||||||
delta_k = fft3d(delta)
|
delta_k = fft3d(delta)
|
||||||
|
|
||||||
|
@ -32,7 +33,7 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
|
||||||
r_split=r_split)
|
r_split=r_split)
|
||||||
# Computes gravitational forces
|
# Computes gravitational forces
|
||||||
forces = jnp.stack([
|
forces = jnp.stack([
|
||||||
cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k), halo_size=0)
|
cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k), halo_size=halo_size)
|
||||||
for i in range(3)
|
for i in range(3)
|
||||||
],
|
],
|
||||||
axis=-1)
|
axis=-1)
|
||||||
|
|
Loading…
Add table
Reference in a new issue