Add finite difference laplace kernel + powerspec functions from Hugo

Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>
This commit is contained in:
Wassim Kabalan 2024-12-05 18:21:24 +01:00
parent 435c7c848f
commit b32014b7ea
2 changed files with 135 additions and 58 deletions

View file

@ -67,20 +67,27 @@ def gradient_kernel(kvec, direction, order=1):
return wts return wts
def invlaplace_kernel(kvec): def invlaplace_kernel(kvec, fd=False):
""" """
Compute the inverse Laplace kernel Compute the inverse Laplace kernel.
cf. [Feng+2016](https://arxiv.org/pdf/1603.00476)
Parameters Parameters
----------- -----------
kvec: list kvec: list
List of wave-vectors List of wave-vectors
fd: bool
Finite difference kernel
Returns Returns
-------- --------
wts: array wts: array
Complex kernel values Complex kernel values
""" """
if fd:
kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
else:
kk = sum(ki**2 for ki in kvec) kk = sum(ki**2 for ki in kvec)
kk_nozeros = jnp.where(kk == 0, 1, kk) kk_nozeros = jnp.where(kk == 0, 1, kk)
return -jnp.where(kk == 0, 0, 1 / kk_nozeros) return -jnp.where(kk == 0, 0, 1 / kk_nozeros)

View file

@ -1,86 +1,156 @@
from functools import partial
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
from jax.scipy.stats import norm from jax.scipy.stats import norm
from scipy.special import legendre
__all__ = ['power_spectrum'] from jaxpm.growth import growth_factor, growth_rate
__all__ = [
def _initialize_pk(shape, boxsize, kmin, dk): 'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
""" 'cross_correlation_coefficients', 'gaussian_smoothing'
Helper function to initialize various (fixed) values for powerspectra... not differentiable!
"""
I = np.eye(len(shape), dtype='int') * -2 + 1
W = np.empty(shape, dtype='f4')
W[...] = 2.0
W[..., 0] = 1.0
W[..., -1] = 1.0
kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
kedges = np.arange(kmin, kmax, dk)
k = [
np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
] ]
kmag = sum(ki**2 for ki in k)**0.5
xsum = np.zeros(len(kedges) + 1)
Nsum = np.zeros(len(kedges) + 1)
dig = np.digitize(kmag.flat, kedges)
xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
return dig, Nsum, xsum, W, k, kedges
def power_spectrum(field, kmin=5, dk=0.5, boxsize=False): def _initialize_pk(mesh_shape, box_shape, kedges, los):
""" """
Calculate the powerspectra given real space field Parameters
----------
Args: mesh_shape : tuple of int
Shape of the mesh grid.
field: real valued field box_shape : tuple of float
kmin: minimum k-value for binned powerspectra Physical dimensions of the box.
dk: differential in each kbin kedges : None, int, float, or list
boxsize: length of each boxlength (can be strangly shaped?) If None, set dk to twice the minimum.
If int, specifies number of edges.
Returns: If float, specifies dk.
los : array_like
kbins: the central value of the bins for plotting Line-of-sight vector.
power: real valued array of power in each bin
Returns
-------
dig : ndarray
Indices of the bins to which each value in input array belongs.
kcount : ndarray
Count of values in each bin.
kedges : ndarray
Edges of the bins.
mumesh : ndarray
Mu values for the mesh grid.
""" """
shape = field.shape kmax = np.pi * np.min(mesh_shape / box_shape) # = knyquist
nx, ny, nz = shape
#initialze values related to powerspectra (mode bins and weights) if isinstance(kedges, None | int | float):
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk) if kedges is None:
dk = 2 * np.pi / np.min(
box_shape) * 2 # twice the minimum wavenumber
if isinstance(kedges, int):
dk = kmax / (kedges + 1) # final number of bins will be kedges-1
elif isinstance(kedges, float):
dk = kedges
kedges = np.arange(dk, kmax, dk) + dk / 2 # from dk/2 to kmax-dk/2
#fast fourier transform kshapes = np.eye(len(mesh_shape), dtype=np.int32) * -2 + 1
fft_image = jnp.fft.fftn(field) kvec = [(2 * np.pi * m / l) * np.fft.fftfreq(m).reshape(kshape)
for m, l, kshape in zip(mesh_shape, box_shape, kshapes)]
kmesh = sum(ki**2 for ki in kvec)**0.5
#absolute value of fast fourier transform dig = np.digitize(kmesh.reshape(-1), kedges)
pk = jnp.real(fft_image * jnp.conj(fft_image)) kcount = np.bincount(dig, minlength=len(kedges) + 1)
#calculating powerspectra # Central value of each bin
real = jnp.real(pk).reshape([-1]) # kavg = (kedges[1:] + kedges[:-1]) / 2
imag = jnp.imag(pk).reshape([-1]) kavg = np.bincount(
dig, weights=kmesh.reshape(-1), minlength=len(kedges) + 1) / kcount
kavg = kavg[1:-1]
Psum = jnp.bincount(dig, weights=(W.flatten() * imag), if los is None:
length=xsum.size) * 1j mumesh = 1.
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size) else:
mumesh = sum(ki * losi for ki, losi in zip(kvec, los))
kmesh_nozeros = np.where(kmesh == 0, 1, kmesh)
mumesh = np.where(kmesh == 0, 0, mumesh / kmesh_nozeros)
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32') return dig, kcount, kavg, mumesh
#normalization for powerspectra
norm = np.prod(np.array(shape[:])).astype('float32')**2
#find central values of each bin def power_spectrum(mesh,
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2 mesh2=None,
box_shape=None,
kedges: int | float | list = None,
multipoles=0,
los=[0., 0., 1.]):
"""
Compute the auto and cross spectrum of 3D fields, with multipoles.
"""
# Initialize
mesh_shape = np.array(mesh.shape)
if box_shape is None:
box_shape = mesh_shape
else:
box_shape = np.asarray(box_shape)
return kbins, P / norm if multipoles == 0:
los = None
else:
los = np.asarray(los)
los = los / np.linalg.norm(los)
poles = np.atleast_1d(multipoles)
dig, kcount, kavg, mumesh = _initialize_pk(mesh_shape, box_shape, kedges,
los)
n_bins = len(kavg) + 2
# FFTs
meshk = jnp.fft.fftn(mesh, norm='ortho')
if mesh2 is None:
mmk = meshk.real**2 + meshk.imag**2
else:
mmk = meshk * jnp.fft.fftn(mesh2, norm='ortho').conj()
# Sum powers
pk = jnp.empty((len(poles), n_bins))
for i_ell, ell in enumerate(poles):
weights = (mmk * (2 * ell + 1) * legendre(ell)(mumesh)).reshape(-1)
if mesh2 is None:
psum = jnp.bincount(dig, weights=weights, length=n_bins)
else: # XXX: bincount is really slow with complex numbers
psum_real = jnp.bincount(dig, weights=weights.real, length=n_bins)
psum_imag = jnp.bincount(dig, weights=weights.imag, length=n_bins)
psum = (psum_real**2 + psum_imag**2)**.5
pk = pk.at[i_ell].set(psum)
# Normalization and conversion from cell units to [Mpc/h]^3
pk = (pk / kcount)[:, 1:-1] * (box_shape / mesh_shape).prod()
# pk = jnp.concatenate([kavg[None], pk])
if np.ndim(multipoles) == 0:
return kavg, pk[0]
else:
return kavg, pk
def transfer(mesh0, mesh1, box_shape, kedges: int | float | list = None):
pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
ks, pk0 = pk_fn(mesh0)
ks, pk1 = pk_fn(mesh1)
return ks, (pk1 / pk0)**.5
def coherence(mesh0, mesh1, box_shape, kedges: int | float | list = None):
pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
ks, pk01 = pk_fn(mesh0, mesh1)
ks, pk0 = pk_fn(mesh0)
ks, pk1 = pk_fn(mesh1)
return ks, pk01 / (pk0 * pk1)**.5
def pktranscoh(mesh0, mesh1, box_shape, kedges: int | float | list = None):
pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
ks, pk01 = pk_fn(mesh0, mesh1)
ks, pk0 = pk_fn(mesh0)
ks, pk1 = pk_fn(mesh1)
return ks, pk0, pk1, (pk1 / pk0)**.5, pk01 / (pk0 * pk1)**.5
def cross_correlation_coefficients(field_a, def cross_correlation_coefficients(field_a,