2024-12-20 11:44:02 +01:00
|
|
|
from functools import partial
|
|
|
|
|
2022-03-25 21:29:32 +01:00
|
|
|
import jax.numpy as jnp
|
2024-07-09 14:54:34 -04:00
|
|
|
import numpy as np
|
2022-05-17 23:42:57 +02:00
|
|
|
from jax.scipy.stats import norm
|
2024-12-20 11:44:02 +01:00
|
|
|
from scipy.special import legendre
|
2022-03-25 21:29:32 +01:00
|
|
|
|
2024-12-20 11:44:02 +01:00
|
|
|
# Public API of this module.
__all__ = [
    'power_spectrum',
    'transfer',
    'coherence',
    'pktranscoh',
    'cross_correlation_coefficients',
    'gaussian_smoothing',
]
|
2022-03-25 21:29:32 +01:00
|
|
|
|
2024-07-09 14:54:34 -04:00
|
|
|
|
2024-12-20 11:44:02 +01:00
|
|
|
def _initialize_pk(mesh_shape, box_shape, kedges, los):
    """
    Precompute the k-space binning shared by the power spectrum estimators.

    Parameters
    ----------
    mesh_shape : array_like of int
        Shape of the mesh grid.
    box_shape : array_like of float
        Physical dimensions of the box, one per axis of ``mesh_shape``.
    kedges : None, int, float, or array_like
        If None, the bin width ``dk`` is twice the fundamental wavenumber
        ``2 * pi / min(box_shape)``.
        If int (including NumPy integers), ``dk = kmax / (kedges + 1)``,
        which yields about ``kedges`` edges, i.e. ``kedges - 1`` bins.
        If float, specifies ``dk`` directly.
        Otherwise, interpreted as the increasing bin edges themselves.
        In the generated cases the edges are ``1.5 dk, 2.5 dk, ...`` below
        ``kmax + dk / 2``, where ``kmax = pi * min(mesh_shape / box_shape)``
        is the Nyquist wavenumber.
    los : array_like or None
        Unit line-of-sight vector used to compute ``mu``. If None, ``mu`` is
        not computed.

    Returns
    -------
    dig : ndarray
        For each flattened mesh mode, its bin index as returned by
        ``np.digitize`` (0 below the first edge, ``len(kedges)`` above the
        last).
    kcount : ndarray
        Number of modes in each of the ``len(kedges) + 1`` bins.
    kavg : ndarray
        Mean ``|k|`` of the modes in each interior bin (between consecutive
        edges), of length ``len(kedges) - 1``. Empty bins give NaN.
    mumesh : ndarray or float
        Cosine between each wavevector and ``los`` (0 at ``k = 0``), or the
        scalar ``1.`` when ``los`` is None.
    """
    # Accept tuples/lists as documented; the arithmetic below needs arrays.
    mesh_shape = np.asarray(mesh_shape)
    box_shape = np.asarray(box_shape, dtype=float)
    kmax = np.pi * np.min(mesh_shape / box_shape)  # = knyquist

    if kedges is None or isinstance(kedges, (int, float, np.integer, np.floating)):
        if kedges is None:
            # twice the minimum (fundamental) wavenumber
            dk = 2 * np.pi / np.min(box_shape) * 2
        elif isinstance(kedges, (int, np.integer)):
            dk = kmax / (kedges + 1)  # final number of bins will be kedges-1
        else:
            dk = float(kedges)
        # Edges at 1.5 dk, 2.5 dk, ... (the first bin excludes the lowest modes)
        kedges = np.arange(dk, kmax, dk) + dk / 2

    # One wavevector component per axis, shaped to broadcast over the mesh
    # (e.g. (-1, 1, 1), (1, -1, 1), (1, 1, -1) in 3D).
    kshapes = np.eye(len(mesh_shape), dtype=np.int32) * -2 + 1
    kvec = [(2 * np.pi * m / l) * np.fft.fftfreq(m).reshape(kshape)
            for m, l, kshape in zip(mesh_shape, box_shape, kshapes)]
    kmesh = sum(ki**2 for ki in kvec)**0.5

    dig = np.digitize(kmesh.reshape(-1), kedges)
    kcount = np.bincount(dig, minlength=len(kedges) + 1)

    # Mean |k| of the modes in each bin; empty bins yield NaN (silently).
    ksum = np.bincount(dig, weights=kmesh.reshape(-1),
                       minlength=len(kedges) + 1)
    with np.errstate(divide='ignore', invalid='ignore'):
        kavg = ksum / kcount
    # Drop the under- and overflow bins outside the edges.
    kavg = kavg[1:-1]

    if los is None:
        mumesh = 1.
    else:
        mumesh = sum(ki * losi for ki, losi in zip(kvec, los))
        # Avoid 0/0 at the k = 0 mode, where mu is set to 0.
        kmesh_nozeros = np.where(kmesh == 0, 1, kmesh)
        mumesh = np.where(kmesh == 0, 0, mumesh / kmesh_nozeros)

    return dig, kcount, kavg, mumesh
|
|
|
|
|
|
|
|
|
|
|
|
def power_spectrum(mesh,
                   mesh2=None,
                   box_shape=None,
                   kedges: int | float | list | None = None,
                   multipoles=0,
                   los=(0., 0., 1.)):
    """
    Compute the auto and cross spectrum of 3D fields, with multipoles.

    Parameters
    ----------
    mesh : array_like
        Field sampled on a regular mesh.
    mesh2 : array_like, optional
        Second field of the same shape. If given, the cross spectrum of
        ``mesh`` and ``mesh2`` is computed, otherwise the auto spectrum.
    box_shape : array_like of float, optional
        Physical size of the box along each axis. Defaults to ``mesh.shape``
        (unit cell size).
    kedges : None, int, float or list, optional
        Binning specification, see ``_initialize_pk``.
    multipoles : int or sequence of int, optional
        Legendre multipole order(s) ``ell`` to compute.
    los : array_like, optional
        Line-of-sight direction, normalized internally. Only used when a
        non-zero multipole is requested.

    Returns
    -------
    kavg : ndarray
        Mean wavenumber of the modes in each bin.
    pk : jax.Array
        Binned power, multiplied by the cell volume
        ``prod(box_shape / mesh_shape)``. Shape ``(n_k,)`` if ``multipoles``
        is a scalar, else ``(len(multipoles), n_k)``. For cross spectra this
        is the modulus of the binned complex cross power.
    """
    # Initialize
    mesh_shape = np.array(mesh.shape)
    box_shape = mesh_shape if box_shape is None else np.asarray(box_shape)

    poles = np.atleast_1d(multipoles)
    # Elementwise test so NumPy-array multipoles do not hit an ambiguous
    # truth value; P_0(mu) = 1, so mu is not needed for the monopole alone.
    if np.all(poles == 0):
        los = None
    else:
        los = np.asarray(los, dtype=float)
        los = los / np.linalg.norm(los)

    dig, kcount, kavg, mumesh = _initialize_pk(mesh_shape, box_shape, kedges,
                                               los)
    # Interior bins plus the under- and overflow bins.
    n_bins = len(kavg) + 2

    # FFTs
    meshk = jnp.fft.fftn(mesh, norm='ortho')
    if mesh2 is None:
        mmk = meshk.real**2 + meshk.imag**2
    else:
        mmk = meshk * jnp.fft.fftn(mesh2, norm='ortho').conj()

    # Sum powers, one row per multipole
    rows = []
    for ell in poles:
        weights = (mmk * (2 * ell + 1) * legendre(ell)(mumesh)).reshape(-1)
        if mesh2 is None:
            psum = jnp.bincount(dig, weights=weights, length=n_bins)
        else:
            # XXX: bincount is really slow with complex numbers, so the real
            # and imaginary parts are binned separately; the modulus is kept.
            psum_real = jnp.bincount(dig, weights=weights.real, length=n_bins)
            psum_imag = jnp.bincount(dig, weights=weights.imag, length=n_bins)
            psum = (psum_real**2 + psum_imag**2)**.5
        rows.append(psum)
    pk = jnp.stack(rows) if rows else jnp.empty((0, n_bins))

    # Normalization and conversion from cell units to [Mpc/h]^3
    pk = (pk / kcount)[:, 1:-1] * (box_shape / mesh_shape).prod()

    if np.ndim(multipoles) == 0:
        return kavg, pk[0]
    return kavg, pk
|
|
|
|
|
|
|
|
|
|
|
|
def transfer(mesh0, mesh1, box_shape, kedges: int | float | list = None):
    """
    Transfer function between two fields.

    Computes the auto power spectra ``P0`` of ``mesh0`` and ``P1`` of
    ``mesh1`` with ``power_spectrum`` (same ``box_shape`` and ``kedges``) and
    returns the wavenumbers of the second call together with ``sqrt(P1 / P0)``.
    """
    spectra = [power_spectrum(field, box_shape=box_shape, kedges=kedges)
               for field in (mesh0, mesh1)]
    ks = spectra[-1][0]
    pk_first, pk_second = spectra[0][1], spectra[1][1]
    return ks, (pk_second / pk_first)**.5
|
|
|
|
|
|
|
|
|
|
|
|
def coherence(mesh0, mesh1, box_shape, kedges: int | float | list = None):
    """
    Coherence (cross-correlation coefficient) between two fields.

    Returns the wavenumbers and ``P01 / sqrt(P0 * P1)``, where ``P01`` is the
    ``power_spectrum`` of (mesh0, mesh1) and ``P0``, ``P1`` are the auto
    spectra, all computed with the same ``box_shape`` and ``kedges``.
    """
    def spectrum(*fields):
        return power_spectrum(*fields, box_shape=box_shape, kedges=kedges)

    _, pk_cross = spectrum(mesh0, mesh1)
    _, pk_auto0 = spectrum(mesh0)
    ks, pk_auto1 = spectrum(mesh1)
    return ks, pk_cross / (pk_auto0 * pk_auto1)**.5
|
|
|
|
|
|
|
|
|
|
|
|
def pktranscoh(mesh0, mesh1, box_shape, kedges: int | float | list = None):
    """
    Auto spectra, transfer function and coherence of two fields in one call.

    Returns ``(ks, P0, P1, sqrt(P1 / P0), P01 / sqrt(P0 * P1))`` where the
    spectra come from ``power_spectrum`` with the same ``box_shape`` and
    ``kedges``.
    """
    opts = dict(box_shape=box_shape, kedges=kedges)
    pk01 = power_spectrum(mesh0, mesh1, **opts)[1]
    pk0 = power_spectrum(mesh0, **opts)[1]
    ks, pk1 = power_spectrum(mesh1, **opts)
    trans = (pk1 / pk0)**.5
    coh = pk01 / (pk0 * pk1)**.5
    return ks, pk0, pk1, trans, coh
|
|
|
|
|
|
|
|
|
|
|
|
def cross_correlation_coefficients(field_a,
                                   field_b,
                                   kmin=5,
                                   dk=0.5,
                                   boxsize=False):
    """
    Calculate the binned cross power spectrum of two real space fields.

    Modes are binned in ``|k|`` with edges ``kmin, kmin + dk, ...`` up to
    the Nyquist wavenumber ``pi * min(shape / boxsize)`` (inclusive within
    ``dk / 2``).

    Args:
        field_a: real valued field
        field_b: real valued field with the same shape as ``field_a``
        kmin: first bin edge (minimum k-value for binned powerspectra)
        dk: width of each k bin
        boxsize: physical length of the box, either one value per axis
            (axes may differ) or a single scalar. ``False`` or ``None`` means
            unit cell size, i.e. ``boxsize = field_a.shape``.

    Returns:
        kbins: the central value of the bins for plotting
        P / norm: float32 real part of the binned ``a(k) b*(k)`` (unnormalized
            FFTs), averaged per bin, multiplied by the box volume and divided
            by the squared number of cells. This is not normalized by the
            auto spectra. Empty bins give NaN.
    """
    shape = np.array(field_a.shape)
    if boxsize is None or boxsize is False:
        boxsize = shape
    # Allow a scalar box length for cubic boxes.
    boxsize = np.broadcast_to(np.asarray(boxsize, dtype=float), shape.shape)

    # Bin edges from kmin up to (and including, within dk/2) the Nyquist k.
    knyquist = np.pi * np.min(shape / boxsize)
    kedges = np.arange(kmin, knyquist + dk / 2, dk)
    # Central value of each bin, for plotting.
    kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
    if len(kedges) < 2:
        # No complete bin fits below the Nyquist wavenumber.
        return kbins, jnp.zeros(0, dtype='float32')

    # Mode binning shared with power_spectrum (no line of sight needed).
    dig, kcount, _, _ = _initialize_pk(shape, boxsize, kedges, None)
    n_bins = len(kedges) + 1

    # Cross power of the unnormalized FFTs.
    pk = (jnp.fft.fftn(field_a) * jnp.conj(jnp.fft.fftn(field_b))).reshape(-1)

    # Only the real part is returned; for real inputs the sum over a |k|
    # shell is real up to round-off, and complex bincount is slow.
    psum = jnp.bincount(dig, weights=pk.real, length=n_bins)

    # Average per bin, drop under/overflow bins, and normalize by V / N^2.
    ncells = float(np.prod(shape))
    scale = float(boxsize.prod()) / ncells**2
    return kbins, ((psum / kcount)[1:-1] * scale).astype('float32')
|
2022-03-25 21:29:32 +01:00
|
|
|
|
2022-05-17 17:55:06 +02:00
|
|
|
|
|
|
|
def gaussian_smoothing(im, sigma):
    """
    Smooth a 2D image with a Gaussian kernel, applied in Fourier space.

    The image spectrum is multiplied by ``exp(-2 pi^2 sigma^2 |k|^2)``
    (``k`` in cycles per pixel), the Fourier transform of a unit-sum Gaussian
    of standard deviation ``sigma`` pixels. Because it is FFT-based, the
    convolution wraps around the image borders.

    im: 2d image (square or not)
    sigma: smoothing scale in px; ``sigma = 0`` returns the image unchanged
    """
    # Compute k vector. 'ij' indexing makes the grid match im.shape; the
    # default 'xy' indexing transposes it and breaks non-square images.
    kx = jnp.fft.fftfreq(im.shape[0])
    ky = jnp.fft.fftfreq(im.shape[1])
    kvec = jnp.stack(jnp.meshgrid(kx, ky, indexing='ij'), axis=-1)
    k = jnp.linalg.norm(kvec, axis=-1)
    # Gaussian transfer function, equal to 1 at k = 0 (preserves the mean).
    # Same as norm.pdf(k, 0, 1 / (2 pi sigma)) normalized at k = 0, but
    # finite for sigma = 0.
    kernel_k = jnp.exp(-0.5 * (2. * np.pi * sigma * k)**2)
    return jnp.fft.ifft2(jnp.fft.fft2(im) * kernel_k).real
|