diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py
index 235ddec..170f5e9 100644
--- a/jaxpm/kernels.py
+++ b/jaxpm/kernels.py
@@ -67,21 +67,28 @@ def gradient_kernel(kvec, direction, order=1):
         return wts
 
 
-def invlaplace_kernel(kvec):
+def invlaplace_kernel(kvec, fd=False):
     """
-    Compute the inverse Laplace kernel
+    Compute the inverse Laplace kernel, -1 / k**2 (set to 0 where k**2 == 0).
+
+    cf. [Feng+2016](https://arxiv.org/pdf/1603.00476)
 
     Parameters
     -----------
     kvec: list
         List of wave-vectors
+    fd: bool
+        If True, use the finite-difference form 2 * sin(ki / 2) in place of ki
 
     Returns
     --------
     wts: array
         Complex kernel values
     """
-    kk = sum(ki**2 for ki in kvec)
+    if fd:
+        kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
+    else:
+        kk = sum(ki**2 for ki in kvec)
     kk_nozeros = jnp.where(kk == 0, 1, kk)
     return -jnp.where(kk == 0, 0, 1 / kk_nozeros)
 
diff --git a/jaxpm/utils.py b/jaxpm/utils.py
index 7c6af44..659ab3f 100644
--- a/jaxpm/utils.py
+++ b/jaxpm/utils.py
@@ -1,86 +1,156 @@
+from functools import partial
+
 import jax.numpy as jnp
 import numpy as np
 from jax.scipy.stats import norm
+from scipy.special import legendre
 
-__all__ = ['power_spectrum']
+from jaxpm.growth import growth_factor, growth_rate
+
+__all__ = [
+    'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
+    'cross_correlation_coefficients', 'gaussian_smoothing'
+]
 
 
-def _initialize_pk(shape, boxsize, kmin, dk):
+def _initialize_pk(mesh_shape, box_shape, kedges, los):
     """
-       Helper function to initialize various (fixed) values for powerspectra... not differentiable!
+    Precompute the k-binning of the Fourier modes of a mesh (NumPy only).
+
+    Parameters
+    ----------
+    mesh_shape : sequence of int
+        Shape of the mesh grid.
+    box_shape : sequence of float
+        Physical dimensions of the box.
+    kedges : None, int, float, or list
+        If None, dk is twice 2 * pi / min(box_shape).
+        If int, number of bin edges; if float, the bin width dk.
+    los : array_like or None
+        Line-of-sight vector; if None, mumesh is the constant 1.
+
+    Returns
+    -------
+    dig : ndarray
+        Flattened bin index of every mode, as given by np.digitize.
+    kcount : ndarray
+        Number of modes per bin, including the two outer bins.
+    kavg : ndarray
+        Mean |k| of the modes in each bin, outer bins dropped.
+    mumesh : ndarray or float
+        k . los / |k| for every mode (0 where |k| == 0).
     """
-    I = np.eye(len(shape), dtype='int') * -2 + 1
+    kmax = np.pi * np.min(np.divide(mesh_shape, box_shape))  # = knyquist
 
-    W = np.empty(shape, dtype='f4')
-    W[...] = 2.0
-    W[..., 0] = 1.0
-    W[..., -1] = 1.0
+    if isinstance(kedges, None | int | float):
+        if kedges is None:
+            dk = 2 * np.pi / np.min(
+                box_shape) * 2  # twice the fundamental of the shortest side
+        elif isinstance(kedges, int):
+            dk = kmax / (kedges + 1)  # final number of bins will be kedges-1
+        else:  # float: used directly as the bin width
+            dk = kedges
+        kedges = np.arange(dk, kmax, dk) + dk / 2  # 3*dk/2, 5*dk/2, ...
 
-    kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
-    kedges = np.arange(kmin, kmax, dk)
+    kshapes = np.eye(len(mesh_shape), dtype=np.int32) * -2 + 1
+    kvec = [(2 * np.pi * m / l) * np.fft.fftfreq(m).reshape(kshape)
+            for m, l, kshape in zip(mesh_shape, box_shape, kshapes)]
+    kmesh = sum(ki**2 for ki in kvec)**0.5
 
-    k = [
-        np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
-        for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
-    ]
-    kmag = sum(ki**2 for ki in k)**0.5
+    dig = np.digitize(kmesh.reshape(-1), kedges)
+    kcount = np.bincount(dig, minlength=len(kedges) + 1)
 
-    xsum = np.zeros(len(kedges) + 1)
-    Nsum = np.zeros(len(kedges) + 1)
+    # Mean |k| of the modes in each bin (a bin with no mode gives nan)
+    kavg = np.bincount(
+        dig, weights=kmesh.reshape(-1), minlength=len(kedges) + 1) / kcount
+    kavg = kavg[1:-1]
 
-    dig = np.digitize(kmag.flat, kedges)
+    if los is None:
+        mumesh = 1.
+    else:
+        mumesh = sum(ki * losi for ki, losi in zip(kvec, los))
+        kmesh_nozeros = np.where(kmesh == 0, 1, kmesh)
+        mumesh = np.where(kmesh == 0, 0, mumesh / kmesh_nozeros)
 
-    xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
-    Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
-    return dig, Nsum, xsum, W, k, kedges
+    return dig, kcount, kavg, mumesh
 
 
-def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
+def power_spectrum(mesh,
+                   mesh2=None,
+                   box_shape=None,
+                   kedges: int | float | list | None = None,
+                   multipoles=0,
+                   los=(0., 0., 1.)):
     """
-    Calculate the powerspectra given real space field
+    Compute the auto and cross spectrum of 3D fields, with multipoles.
+
+    box_shape defaults to the mesh shape; kedges is forwarded unchanged to
+    _initialize_pk. los is only used when multipoles is not the scalar 0.
+    With mesh2, the modulus of each binned cross-spectrum is returned.
+    Returns (kavg, pk); pk has one row per pole if multipoles is array-like.
+    """
+    # Initialize
+    mesh_shape = np.array(mesh.shape)
+    box_shape = mesh_shape if box_shape is None else np.asarray(box_shape)
 
-    Args:
+    if np.ndim(multipoles) == 0 and multipoles == 0:  # scalar monopole only
+        los = None
+    else:
+        los = np.asarray(los)
+        los = los / np.linalg.norm(los)
+    poles = np.atleast_1d(multipoles)
+    dig, kcount, kavg, mumesh = _initialize_pk(mesh_shape, box_shape, kedges,
+                                               los)
+    n_bins = len(kavg) + 2  # kavg omits the two outer digitize bins
 
-        field: real valued field
-        kmin: minimum k-value for binned powerspectra
-        dk: differential in each kbin
-        boxsize: length of each boxlength (can be strangly shaped?)
+    # FFTs
+    meshk = jnp.fft.fftn(mesh, norm='ortho')
+    if mesh2 is None:
+        mmk = meshk.real**2 + meshk.imag**2
+    else:
+        mmk = meshk * jnp.fft.fftn(mesh2, norm='ortho').conj()
 
-    Returns:
+    # Sum powers
+    pk = jnp.empty((len(poles), n_bins))
+    for i_ell, ell in enumerate(poles):
+        weights = (mmk * (2 * ell + 1) * legendre(ell)(mumesh)).reshape(-1)
+        if mesh2 is None:
+            psum = jnp.bincount(dig, weights=weights, length=n_bins)
+        else:  # XXX: bincount is really slow with complex numbers
+            psum_real = jnp.bincount(dig, weights=weights.real, length=n_bins)
+            psum_imag = jnp.bincount(dig, weights=weights.imag, length=n_bins)
+            psum = (psum_real**2 + psum_imag**2)**.5
+        pk = pk.at[i_ell].set(psum)
 
-        kbins: the central value of the bins for plotting
-        power: real valued array of power in each bin
+    # Normalization and conversion from cell units to [Mpc/h]^3
+    pk = (pk / kcount)[:, 1:-1] * (box_shape / mesh_shape).prod()
 
-  """
-    shape = field.shape
-    nx, ny, nz = shape
+    if np.ndim(multipoles) == 0:
+        return kavg, pk[0]
+    return kavg, pk
 
-    #initialze values related to powerspectra (mode bins and weights)
-    dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
 
-    #fast fourier transform
-    fft_image = jnp.fft.fftn(field)
+def transfer(mesh0, mesh1, box_shape, kedges: int | float | list = None):
+    """Return (ks, (pk1 / pk0)**.5) from the auto-spectra of the two meshes."""
+    spectrum = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
+    (_, auto0), (ks, auto1) = spectrum(mesh0), spectrum(mesh1)
+    return ks, (auto1 / auto0)**.5
 
-    #absolute value of fast fourier transform
-    pk = jnp.real(fft_image * jnp.conj(fft_image))
 
-    #calculating powerspectra
-    real = jnp.real(pk).reshape([-1])
-    imag = jnp.imag(pk).reshape([-1])
+def coherence(mesh0, mesh1, box_shape, kedges: int | float | list = None):
+    """Return (ks, pk01 / (pk0 * pk1)**.5) for the cross and auto spectra."""
+    spectrum = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
+    cross = spectrum(mesh0, mesh1)[1]
+    (_, auto0), (ks, auto1) = spectrum(mesh0), spectrum(mesh1)
+    return ks, cross / (auto0 * auto1)**.5
 
-    Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
-                        length=xsum.size) * 1j
-    Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
 
-    P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
-
-    #normalization for powerspectra
-    norm = np.prod(np.array(shape[:])).astype('float32')**2
-
-    #find central values of each bin
-    kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
-
-    return kbins, P / norm
+def pktranscoh(mesh0, mesh1, box_shape, kedges: int | float | list = None):
+    """Return ks, pk0, pk1, (pk1 / pk0)**.5 and pk01 / (pk0 * pk1)**.5."""
+    spectrum = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
+    cross = spectrum(mesh0, mesh1)[1]
+    (_, auto0), (ks, auto1) = spectrum(mesh0), spectrum(mesh1)
+    return ks, auto0, auto1, (auto1 / auto0)**.5, cross / (auto0 * auto1)**.5
 
 
 def cross_correlation_coefficients(field_a,