From b32014b7eaeea736a6bd1038b2237964dac84ba4 Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Thu, 5 Dec 2024 18:21:24 +0100 Subject: [PATCH] Add finite difference laplace kernel + powerspec functions from Hugo Co-authored-by: Hugo Simonfroy --- jaxpm/kernels.py | 13 +++- jaxpm/utils.py | 180 ++++++++++++++++++++++++++++++++--------------- 2 files changed, 135 insertions(+), 58 deletions(-) diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 235ddec..170f5e9 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -67,21 +67,28 @@ def gradient_kernel(kvec, direction, order=1): return wts -def invlaplace_kernel(kvec): +def invlaplace_kernel(kvec, fd=False): """ - Compute the inverse Laplace kernel + Compute the inverse Laplace kernel. + + cf. [Feng+2016](https://arxiv.org/pdf/1603.00476) Parameters ----------- kvec: list List of wave-vectors + fd: bool + Finite difference kernel Returns -------- wts: array Complex kernel values """ - kk = sum(ki**2 for ki in kvec) + if fd: + kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec) + else: + kk = sum(ki**2 for ki in kvec) kk_nozeros = jnp.where(kk == 0, 1, kk) return -jnp.where(kk == 0, 0, 1 / kk_nozeros) diff --git a/jaxpm/utils.py b/jaxpm/utils.py index 7c6af44..659ab3f 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -1,86 +1,156 @@ +from functools import partial + import jax.numpy as jnp import numpy as np from jax.scipy.stats import norm +from scipy.special import legendre -__all__ = ['power_spectrum'] +from jaxpm.growth import growth_factor, growth_rate + +__all__ = [ + 'power_spectrum', 'transfer', 'coherence', 'pktranscoh', + 'cross_correlation_coefficients', 'gaussian_smoothing' +] -def _initialize_pk(shape, boxsize, kmin, dk): +def _initialize_pk(mesh_shape, box_shape, kedges, los): """ - Helper function to initialize various (fixed) values for powerspectra... not differentiable! + Parameters + ---------- + mesh_shape : tuple of int + Shape of the mesh grid. + box_shape : tuple of float + Physical dimensions of the box. + kedges : None, int, float, or list + If None, set dk to twice the minimum. + If int, specifies number of edges. + If float, specifies dk. + los : array_like + Line-of-sight vector. + + Returns + ------- + dig : ndarray + Indices of the bins to which each value in input array belongs. + kcount : ndarray + Count of values in each bin. + kedges : ndarray + Edges of the bins. + mumesh : ndarray + Mu values for the mesh grid. """ - I = np.eye(len(shape), dtype='int') * -2 + 1 + kmax = np.pi * np.min(mesh_shape / box_shape) # = knyquist - W = np.empty(shape, dtype='f4') - W[...] = 2.0 - W[..., 0] = 1.0 - W[..., -1] = 1.0 + if isinstance(kedges, None | int | float): + if kedges is None: + dk = 2 * np.pi / np.min( + box_shape) * 2 # twice the minimum wavenumber + if isinstance(kedges, int): + dk = kmax / (kedges + 1) # final number of bins will be kedges-1 + elif isinstance(kedges, float): + dk = kedges + kedges = np.arange(dk, kmax, dk) + dk / 2 # from dk/2 to kmax-dk/2 - kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2 - kedges = np.arange(kmin, kmax, dk) + kshapes = np.eye(len(mesh_shape), dtype=np.int32) * -2 + 1 + kvec = [(2 * np.pi * m / l) * np.fft.fftfreq(m).reshape(kshape) + for m, l, kshape in zip(mesh_shape, box_shape, kshapes)] + kmesh = sum(ki**2 for ki in kvec)**0.5 - k = [ - np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape) - for N, L, kshape, pkshape in zip(shape, boxsize, I, shape) - ] - kmag = sum(ki**2 for ki in k)**0.5 + dig = np.digitize(kmesh.reshape(-1), kedges) + kcount = np.bincount(dig, minlength=len(kedges) + 1) - xsum = np.zeros(len(kedges) + 1) - Nsum = np.zeros(len(kedges) + 1) + # Central value of each bin + # kavg = (kedges[1:] + kedges[:-1]) / 2 + kavg = np.bincount( + dig, weights=kmesh.reshape(-1), minlength=len(kedges) + 1) / kcount + kavg = kavg[1:-1] - dig = np.digitize(kmag.flat, kedges) + if los is None: + mumesh = 1. + else: + mumesh = sum(ki * losi for ki, losi in zip(kvec, los)) + kmesh_nozeros = np.where(kmesh == 0, 1, kmesh) + mumesh = np.where(kmesh == 0, 0, mumesh / kmesh_nozeros) - xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size) - Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size) - return dig, Nsum, xsum, W, k, kedges + return dig, kcount, kavg, mumesh -def power_spectrum(field, kmin=5, dk=0.5, boxsize=False): +def power_spectrum(mesh, + mesh2=None, + box_shape=None, + kedges: int | float | list = None, + multipoles=0, + los=[0., 0., 1.]): """ - Calculate the powerspectra given real space field + Compute the auto and cross spectrum of 3D fields, with multipoles. + """ + # Initialize + mesh_shape = np.array(mesh.shape) + if box_shape is None: + box_shape = mesh_shape + else: + box_shape = np.asarray(box_shape) - Args: + if multipoles == 0: + los = None + else: + los = np.asarray(los) + los = los / np.linalg.norm(los) + poles = np.atleast_1d(multipoles) + dig, kcount, kavg, mumesh = _initialize_pk(mesh_shape, box_shape, kedges, + los) + n_bins = len(kavg) + 2 - field: real valued field - kmin: minimum k-value for binned powerspectra - dk: differential in each kbin - boxsize: length of each boxlength (can be strangly shaped?) + # FFTs + meshk = jnp.fft.fftn(mesh, norm='ortho') + if mesh2 is None: + mmk = meshk.real**2 + meshk.imag**2 + else: + mmk = meshk * jnp.fft.fftn(mesh2, norm='ortho').conj() - Returns: + # Sum powers + pk = jnp.empty((len(poles), n_bins)) + for i_ell, ell in enumerate(poles): + weights = (mmk * (2 * ell + 1) * legendre(ell)(mumesh)).reshape(-1) + if mesh2 is None: + psum = jnp.bincount(dig, weights=weights, length=n_bins) + else: # XXX: bincount is really slow with complex numbers + psum_real = jnp.bincount(dig, weights=weights.real, length=n_bins) + psum_imag = jnp.bincount(dig, weights=weights.imag, length=n_bins) + psum = (psum_real**2 + psum_imag**2)**.5 + pk = pk.at[i_ell].set(psum) - kbins: the central value of the bins for plotting - power: real valued array of power in each bin + # Normalization and conversion from cell units to [Mpc/h]^3 + pk = (pk / kcount)[:, 1:-1] * (box_shape / mesh_shape).prod() - """ - shape = field.shape - nx, ny, nz = shape + # pk = jnp.concatenate([kavg[None], pk]) + if np.ndim(multipoles) == 0: + return kavg, pk[0] + else: + return kavg, pk - #initialze values related to powerspectra (mode bins and weights) - dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk) - #fast fourier transform - fft_image = jnp.fft.fftn(field) +def transfer(mesh0, mesh1, box_shape, kedges: int | float | list = None): + pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges) + ks, pk0 = pk_fn(mesh0) + ks, pk1 = pk_fn(mesh1) + return ks, (pk1 / pk0)**.5 - #absolute value of fast fourier transform - pk = jnp.real(fft_image * jnp.conj(fft_image)) - #calculating powerspectra - real = jnp.real(pk).reshape([-1]) - imag = jnp.imag(pk).reshape([-1]) +def coherence(mesh0, mesh1, box_shape, kedges: int | float | list = None): + pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges) + ks, pk01 = pk_fn(mesh0, mesh1) + ks, pk0 = pk_fn(mesh0) + ks, pk1 = pk_fn(mesh1) + return ks, pk01 / (pk0 * pk1)**.5 - Psum = jnp.bincount(dig, weights=(W.flatten() * imag), - length=xsum.size) * 1j - Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size) - P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32') - - #normalization for powerspectra - norm = np.prod(np.array(shape[:])).astype('float32')**2 - - #find central values of each bin - kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2 - - return kbins, P / norm +def pktranscoh(mesh0, mesh1, box_shape, kedges: int | float | list = None): + pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges) + ks, pk01 = pk_fn(mesh0, mesh1) + ks, pk0 = pk_fn(mesh0) + ks, pk1 = pk_fn(mesh1) + return ks, pk0, pk1, (pk1 / pk0)**.5, pk01 / (pk0 * pk1)**.5 def cross_correlation_coefficients(field_a,