Add finite difference laplace kernel + powerspec functions from Hugo

Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>
This commit is contained in:
Wassim Kabalan 2024-12-05 18:21:24 +01:00
parent 435c7c848f
commit b32014b7ea
2 changed files with 135 additions and 58 deletions

View file

@ -67,21 +67,28 @@ def gradient_kernel(kvec, direction, order=1):
return wts return wts
def invlaplace_kernel(kvec): def invlaplace_kernel(kvec, fd=False):
""" """
Compute the inverse Laplace kernel Compute the inverse Laplace kernel.
cf. [Feng+2016](https://arxiv.org/pdf/1603.00476)
Parameters Parameters
----------- -----------
kvec: list kvec: list
List of wave-vectors List of wave-vectors
fd: bool
Finite difference kernel
Returns Returns
-------- --------
wts: array wts: array
Complex kernel values Complex kernel values
""" """
kk = sum(ki**2 for ki in kvec) if fd:
kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
else:
kk = sum(ki**2 for ki in kvec)
kk_nozeros = jnp.where(kk == 0, 1, kk) kk_nozeros = jnp.where(kk == 0, 1, kk)
return -jnp.where(kk == 0, 0, 1 / kk_nozeros) return -jnp.where(kk == 0, 0, 1 / kk_nozeros)

View file

@ -1,86 +1,156 @@
from functools import partial
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
from jax.scipy.stats import norm from jax.scipy.stats import norm
from scipy.special import legendre
__all__ = ['power_spectrum'] from jaxpm.growth import growth_factor, growth_rate
__all__ = [
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
'cross_correlation_coefficients', 'gaussian_smoothing'
]
def _initialize_pk(shape, boxsize, kmin, dk): def _initialize_pk(mesh_shape, box_shape, kedges, los):
""" """
Helper function to initialize various (fixed) values for powerspectra... not differentiable! Parameters
----------
mesh_shape : tuple of int
Shape of the mesh grid.
box_shape : tuple of float
Physical dimensions of the box.
kedges : None, int, float, or list
If None, set dk to twice the minimum.
If int, specifies number of edges.
If float, specifies dk.
los : array_like
Line-of-sight vector.
Returns
-------
dig : ndarray
Indices of the bins to which each value in input array belongs.
kcount : ndarray
Count of values in each bin.
kedges : ndarray
Edges of the bins.
mumesh : ndarray
Mu values for the mesh grid.
""" """
I = np.eye(len(shape), dtype='int') * -2 + 1 kmax = np.pi * np.min(mesh_shape / box_shape) # = knyquist
W = np.empty(shape, dtype='f4') if isinstance(kedges, None | int | float):
W[...] = 2.0 if kedges is None:
W[..., 0] = 1.0 dk = 2 * np.pi / np.min(
W[..., -1] = 1.0 box_shape) * 2 # twice the minimum wavenumber
if isinstance(kedges, int):
dk = kmax / (kedges + 1) # final number of bins will be kedges-1
elif isinstance(kedges, float):
dk = kedges
kedges = np.arange(dk, kmax, dk) + dk / 2 # from dk/2 to kmax-dk/2
kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2 kshapes = np.eye(len(mesh_shape), dtype=np.int32) * -2 + 1
kedges = np.arange(kmin, kmax, dk) kvec = [(2 * np.pi * m / l) * np.fft.fftfreq(m).reshape(kshape)
for m, l, kshape in zip(mesh_shape, box_shape, kshapes)]
kmesh = sum(ki**2 for ki in kvec)**0.5
k = [ dig = np.digitize(kmesh.reshape(-1), kedges)
np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape) kcount = np.bincount(dig, minlength=len(kedges) + 1)
for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
]
kmag = sum(ki**2 for ki in k)**0.5
xsum = np.zeros(len(kedges) + 1) # Central value of each bin
Nsum = np.zeros(len(kedges) + 1) # kavg = (kedges[1:] + kedges[:-1]) / 2
kavg = np.bincount(
dig, weights=kmesh.reshape(-1), minlength=len(kedges) + 1) / kcount
kavg = kavg[1:-1]
dig = np.digitize(kmag.flat, kedges) if los is None:
mumesh = 1.
else:
mumesh = sum(ki * losi for ki, losi in zip(kvec, los))
kmesh_nozeros = np.where(kmesh == 0, 1, kmesh)
mumesh = np.where(kmesh == 0, 0, mumesh / kmesh_nozeros)
xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size) return dig, kcount, kavg, mumesh
Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
return dig, Nsum, xsum, W, k, kedges
def power_spectrum(field, kmin=5, dk=0.5, boxsize=False): def power_spectrum(mesh,
mesh2=None,
box_shape=None,
kedges: int | float | list = None,
multipoles=0,
los=[0., 0., 1.]):
""" """
Calculate the powerspectra given real space field Compute the auto and cross spectrum of 3D fields, with multipoles.
"""
# Initialize
mesh_shape = np.array(mesh.shape)
if box_shape is None:
box_shape = mesh_shape
else:
box_shape = np.asarray(box_shape)
Args: if multipoles == 0:
los = None
else:
los = np.asarray(los)
los = los / np.linalg.norm(los)
poles = np.atleast_1d(multipoles)
dig, kcount, kavg, mumesh = _initialize_pk(mesh_shape, box_shape, kedges,
los)
n_bins = len(kavg) + 2
field: real valued field # FFTs
kmin: minimum k-value for binned powerspectra meshk = jnp.fft.fftn(mesh, norm='ortho')
dk: differential in each kbin if mesh2 is None:
boxsize: length of each boxlength (can be strangly shaped?) mmk = meshk.real**2 + meshk.imag**2
else:
mmk = meshk * jnp.fft.fftn(mesh2, norm='ortho').conj()
Returns: # Sum powers
pk = jnp.empty((len(poles), n_bins))
for i_ell, ell in enumerate(poles):
weights = (mmk * (2 * ell + 1) * legendre(ell)(mumesh)).reshape(-1)
if mesh2 is None:
psum = jnp.bincount(dig, weights=weights, length=n_bins)
else: # XXX: bincount is really slow with complex numbers
psum_real = jnp.bincount(dig, weights=weights.real, length=n_bins)
psum_imag = jnp.bincount(dig, weights=weights.imag, length=n_bins)
psum = (psum_real**2 + psum_imag**2)**.5
pk = pk.at[i_ell].set(psum)
kbins: the central value of the bins for plotting # Normalization and conversion from cell units to [Mpc/h]^3
power: real valued array of power in each bin pk = (pk / kcount)[:, 1:-1] * (box_shape / mesh_shape).prod()
""" # pk = jnp.concatenate([kavg[None], pk])
shape = field.shape if np.ndim(multipoles) == 0:
nx, ny, nz = shape return kavg, pk[0]
else:
return kavg, pk
#initialze values related to powerspectra (mode bins and weights)
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
#fast fourier transform def transfer(mesh0, mesh1, box_shape, kedges: int | float | list = None):
fft_image = jnp.fft.fftn(field) pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
ks, pk0 = pk_fn(mesh0)
ks, pk1 = pk_fn(mesh1)
return ks, (pk1 / pk0)**.5
#absolute value of fast fourier transform
pk = jnp.real(fft_image * jnp.conj(fft_image))
#calculating powerspectra def coherence(mesh0, mesh1, box_shape, kedges: int | float | list = None):
real = jnp.real(pk).reshape([-1]) pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
imag = jnp.imag(pk).reshape([-1]) ks, pk01 = pk_fn(mesh0, mesh1)
ks, pk0 = pk_fn(mesh0)
ks, pk1 = pk_fn(mesh1)
return ks, pk01 / (pk0 * pk1)**.5
Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
length=xsum.size) * 1j
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32') def pktranscoh(mesh0, mesh1, box_shape, kedges: int | float | list = None):
pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
#normalization for powerspectra ks, pk01 = pk_fn(mesh0, mesh1)
norm = np.prod(np.array(shape[:])).astype('float32')**2 ks, pk0 = pk_fn(mesh0)
ks, pk1 = pk_fn(mesh1)
#find central values of each bin return ks, pk0, pk1, (pk1 / pk0)**.5, pk01 / (pk0 * pk1)**.5
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
return kbins, P / norm
def cross_correlation_coefficients(field_a, def cross_correlation_coefficients(field_a,