adds utilities for simple lensing

This commit is contained in:
EiffL 2022-05-17 17:55:06 +02:00
parent da2836f698
commit e33504358d
2 changed files with 71 additions and 3 deletions

View file

@ -1,14 +1,20 @@
import jax
import jax.numpy as jnp
import jax_cosmo.constants as constants
import jax_cosmo
from jax.scipy.ndimage import map_coordinates
from jaxpm.utils import gaussian_smoothing
from jaxpm.painting import cic_paint_2d
def density_plane(positions,
box_shape,
center,
width,
plane_resolution):
plane_resolution,
smoothing_sigma=None):
""" Extacts a density plane from the simulation
"""
nx, ny, nz = box_shape
xy = positions[..., :2]
d = positions[..., 2]
@ -21,12 +27,57 @@ def density_plane(positions,
# Selecting only particles that fall inside the volume of interest
mask = (d > (center - width / 2)) & (d <= (center + width / 2))
xy = xy[mask]
# Painting density plane
density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy[mask])
density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy)
# Apply density normalization
density_plane = density_plane / ((nx / plane_resolution) *
(ny / plane_resolution) * (width))
# Apply Gaussian smoothing if requested
if smoothing_sigma is not None:
density_plane = gaussian_smoothing(density_plane,
smoothing_sigma)
return density_plane
def convergence_Born(cosmo,
density_planes,
dx, dz,
coords,
z_source):
"""
Compute the Born convergence
Args:
cosmo: `Cosmology`, cosmology object.
density_planes: list of tuples (r, a, density_plane), lens planes to use
dx: float, transverse pixel resolution of the density planes [Mpc/h]
dz: float, width of the density planes [Mpc/h]
coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2].
z_source: 1-D `Tensor` of source redshifts with shape [Nz] .
name: `string`, name of the operation.
Returns:
`Tensor` of shape [batch_size, N, Nz], of convergence values.
"""
# Compute constant prefactor:
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
# Compute comoving distance of source galaxies
r_s = jax_cosmo.background.radial_comoving_distance(cosmo, 1 / (1 + z_source))
convergence = 0
for r, a, p in density_planes:
# Normalize density planes
density_normalization = dz * r / a
p = (p - p.mean()) * constant_factor * density_normalization
# Interpolate at the density plane coordinates
im = map_coordinates(p,
coords * r / dx - 0.5,
order=1, mode="wrap")
convergence += im * jnp.clip(1. - (r / r_s), 0, 1000).reshape([-1, 1, 1])
return convergence

View file

@ -1,5 +1,6 @@
import numpy as np
import jax.numpy as jnp
from scipy.stats import norm
__all__ = ['power_spectrum']
@ -79,3 +80,19 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
return kbins, P / norm
def gaussian_smoothing(im, sigma):
"""
im: 2d image
sigma: smoothing scale in px
"""
# Compute k vector
kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]),
jnp.fft.fftfreq(im.shape[1])),
axis=-1)
k = jnp.linalg.norm(kvec, axis=-1)
# We compute the value of the filter at frequency k
filter = norm(0, 1. / (2. * np.pi * sigma)).pdf(k)
return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real