diff --git a/jaxpm/lensing.py b/jaxpm/lensing.py new file mode 100644 index 0000000..d6c4a93 --- /dev/null +++ b/jaxpm/lensing.py @@ -0,0 +1,81 @@ +import jax +import jax.numpy as jnp +import jax_cosmo.constants as constants +import jax_cosmo + +from jax.scipy.ndimage import map_coordinates +from jaxpm.utils import gaussian_smoothing +from jaxpm.painting import cic_paint_2d + +def density_plane(positions, + box_shape, + center, + width, + plane_resolution, + smoothing_sigma=None): + """ Extacts a density plane from the simulation + """ + nx, ny, nz = box_shape + xy = positions[..., :2] + d = positions[..., 2] + + # Apply 2d periodic conditions + xy = jnp.mod(xy, nx) + + # Rescaling positions to target grid + xy = xy / nx * plane_resolution + + # Selecting only particles that fall inside the volume of interest + weight = jnp.where((d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.) + # Painting density plane + density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy, weight) + + # Apply density normalization + density_plane = density_plane / ((nx / plane_resolution) * + (ny / plane_resolution) * (width)) + + # Apply Gaussian smoothing if requested + if smoothing_sigma is not None: + density_plane = gaussian_smoothing(density_plane, + smoothing_sigma) + + return density_plane + + +def convergence_Born(cosmo, + density_planes, + dx, dz, + coords, + z_source): + """ + Compute the Born convergence + Args: + cosmo: `Cosmology`, cosmology object. + density_planes: list of tuples (r, a, density_plane), lens planes to use + dx: float, transverse pixel resolution of the density planes [Mpc/h] + dz: float, width of the density planes [Mpc/h] + coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2]. + z_source: 1-D `Tensor` of source redshifts with shape [Nz] . + name: `string`, name of the operation. + Returns: + `Tensor` of shape [batch_size, N, Nz], of convergence values. + """ + # Compute constant prefactor: + constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2 + # Compute comoving distance of source galaxies + r_s = jax_cosmo.background.radial_comoving_distance(cosmo, 1 / (1 + z_source)) + + convergence = 0 + for r, a, p in density_planes: + # Normalize density planes + density_normalization = dz * r / a + p = (p - p.mean()) * constant_factor * density_normalization + + # Interpolate at the density plane coordinates + im = map_coordinates(p, + coords * r / dx - 0.5, + order=1, mode="wrap") + + convergence += im * jnp.clip(1. - (r / r_s), 0, 1000).reshape([-1, 1, 1]) + + return convergence diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 27a1900..4237c23 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -52,6 +52,34 @@ def cic_read(mesh, positions): neighboor_coords[...,1], neighboor_coords[...,3]]*kernel).sum(axis=-1) +def cic_paint_2d(mesh, positions, weight): + """ Paints positions onto a 2d mesh + mesh: [nx, ny] + positions: [npart, 2] + weight: [npart] + """ + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) + connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) + + neighboor_coords = floor + connection + kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] + if weight is not None: + kernel = kernel * weight[...,jnp.newaxis] + + neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape)) + + dnums = jax.lax.ScatterDimensionNumbers( + update_window_dims=(), + inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1)) + mesh = lax.scatter_add(mesh, + neighboor_coords, + kernel.reshape([-1,4]), + dnums) + return mesh + def compensate_cic(field): """ Compensate for CiC painting diff --git a/jaxpm/utils.py b/jaxpm/utils.py index 312f73f..a01e188 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -1,5 +1,6 @@ import numpy as np import jax.numpy as jnp +from jax.scipy.stats import norm __all__ = ['power_spectrum'] @@ -79,3 +80,20 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False): kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2 return kbins, P / norm + +def gaussian_smoothing(im, sigma): + """ + im: 2d image + sigma: smoothing scale in px + """ + # Compute k vector + kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]), + jnp.fft.fftfreq(im.shape[1])), + axis=-1) + k = jnp.linalg.norm(kvec, axis=-1) + # We compute the value of the filter at frequency k + filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma)) + filter /= filter[0,0] + + return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real +