diff --git a/jaxpm/utils.py b/jaxpm/utils.py index b541c65..d315b00 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -67,8 +67,8 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False): real = jnp.real(pk).reshape([-1]) imag = jnp.imag(pk).reshape([-1]) - Psum = jnp.bincount(dig, weights=(W.flatten() * imag), minlength=xsum.size) * 1j - Psum += jnp.bincount(dig, weights=(W.flatten() * real), minlength=xsum.size) + Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j + Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size) P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')