From 0008f8549b4c01b916f1db80973666cf3b927fec Mon Sep 17 00:00:00 2001 From: EiffL Date: Tue, 17 May 2022 17:55:06 +0200 Subject: [PATCH] adds utilities for simple lensing --- jaxpm/lensing.py | 57 +++++++++++++++++++++++++++++++++++++++++++++--- jaxpm/utils.py | 17 +++++++++++++++ 2 files changed, 71 insertions(+), 3 deletions(-) diff --git a/jaxpm/lensing.py b/jaxpm/lensing.py index 723b71e..5d3d1e2 100644 --- a/jaxpm/lensing.py +++ b/jaxpm/lensing.py @@ -1,14 +1,20 @@ import jax import jax.numpy as jnp +import jax_cosmo.constants as constants +import jax_cosmo +from jax.scipy.ndimage import map_coordinates +from jaxpm.utils import gaussian_smoothing from jaxpm.painting import cic_paint_2d def density_plane(positions, box_shape, center, width, - plane_resolution): - + plane_resolution, + smoothing_sigma=None): + """ Extacts a density plane from the simulation + """ nx, ny, nz = box_shape xy = positions[..., :2] d = positions[..., 2] @@ -21,12 +27,57 @@ def density_plane(positions, # Selecting only particles that fall inside the volume of interest mask = (d > (center - width / 2)) & (d <= (center + width / 2)) + xy = xy[mask] # Painting density plane - density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy[mask]) + density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy) # Apply density normalization density_plane = density_plane / ((nx / plane_resolution) * (ny / plane_resolution) * (width)) + # Apply Gaussian smoothing if requested + if smoothing_sigma is not None: + density_plane = gaussian_smoothing(density_plane, + smoothing_sigma) + return density_plane + + +def convergence_Born(cosmo, + density_planes, + dx, dz, + coords, + z_source): + """ + Compute the Born convergence + Args: + cosmo: `Cosmology`, cosmology object. + density_planes: list of tuples (r, a, density_plane), lens planes to use + dx: float, transverse pixel resolution of the density planes [Mpc/h] + dz: float, width of the density planes [Mpc/h] + coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2]. + z_source: 1-D `Tensor` of source redshifts with shape [Nz] . + name: `string`, name of the operation. + Returns: + `Tensor` of shape [batch_size, N, Nz], of convergence values. + """ + # Compute constant prefactor: + constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2 + # Compute comoving distance of source galaxies + r_s = jax_cosmo.background.radial_comoving_distance(cosmo, 1 / (1 + z_source)) + + convergence = 0 + for r, a, p in density_planes: + # Normalize density planes + density_normalization = dz * r / a + p = (p - p.mean()) * constant_factor * density_normalization + + # Interpolate at the density plane coordinates + im = map_coordinates(p, + coords * r / dx - 0.5, + order=1, mode="wrap") + + convergence += im * jnp.clip(1. - (r / r_s), 0, 1000).reshape([-1, 1, 1]) + + return convergence diff --git a/jaxpm/utils.py b/jaxpm/utils.py index 312f73f..189ab6c 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -1,5 +1,6 @@ import numpy as np import jax.numpy as jnp +from scipy.stats import norm __all__ = ['power_spectrum'] @@ -79,3 +80,19 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False): kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2 return kbins, P / norm + +def gaussian_smoothing(im, sigma): + """ + im: 2d image + sigma: smoothing scale in px + """ + # Compute k vector + kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]), + jnp.fft.fftfreq(im.shape[1])), + axis=-1) + k = jnp.linalg.norm(kvec, axis=-1) + # We compute the value of the filter at frequency k + filter = norm(0, 1. / (2. * np.pi * sigma)).pdf(k) + + return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real +