From b08641d51d8334fbbf6c8e41c5fbbfa9ca6e4bea Mon Sep 17 00:00:00 2001 From: EiffL Date: Fri, 25 Mar 2022 21:29:32 +0100 Subject: [PATCH] Add utility to compute the power spectrum --- jaxpm/utils.py | 81 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 jaxpm/utils.py diff --git a/jaxpm/utils.py b/jaxpm/utils.py new file mode 100644 index 0000000..b541c65 --- /dev/null +++ b/jaxpm/utils.py @@ -0,0 +1,81 @@ +import numpy as np +import jax.numpy as jnp + +__all__ = ['power_spectrum'] + +def _initialize_pk(shape, boxsize, kmin, dk): + """ + Helper function to initialize various (fixed) values for powerspectra... not differentiable! + """ + I = np.eye(len(shape), dtype='int') * -2 + 1 + + W = np.empty(shape, dtype='f4') + W[...] = 2.0 + W[..., 0] = 1.0 + W[..., -1] = 1.0 + + kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2 + kedges = np.arange(kmin, kmax, dk) + + k = [ + np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape) + for N, L, kshape, pkshape in zip(shape, boxsize, I, shape) + ] + kmag = sum(ki**2 for ki in k)**0.5 + + xsum = np.zeros(len(kedges) + 1) + Nsum = np.zeros(len(kedges) + 1) + + dig = np.digitize(kmag.flat, kedges) + + xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size) + Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size) + return dig, Nsum, xsum, W, k, kedges + + +def power_spectrum(field, kmin=5, dk=0.5, boxsize=False): + """ + Calculate the powerspectra given real space field + + Args: + + field: real valued field + kmin: minimum k-value for binned powerspectra + dk: differential in each kbin + boxsize: length of each boxlength (can be strangly shaped?) + + Returns: + + kbins: the central value of the bins for plotting + power: real valued array of power in each bin + + """ + shape = field.shape + nx, ny, nz = shape + + #initialze values related to powerspectra (mode bins and weights) + dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk) + + #fast fourier transform + fft_image = jnp.fft.fftn(field) + + #absolute value of fast fourier transform + pk = jnp.real(fft_image * jnp.conj(fft_image)) + + + #calculating powerspectra + real = jnp.real(pk).reshape([-1]) + imag = jnp.imag(pk).reshape([-1]) + + Psum = jnp.bincount(dig, weights=(W.flatten() * imag), minlength=xsum.size) * 1j + Psum += jnp.bincount(dig, weights=(W.flatten() * real), minlength=xsum.size) + + P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32') + + #normalization for powerspectra + norm = np.prod(np.array(shape[:])).astype('float32')**2 + + #find central values of each bin + kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2 + + return kbins, P / norm \ No newline at end of file