import jax
import jax.numpy as jnp
import jax_cosmo
import jax_cosmo.constants as constants
from jax.scipy.ndimage import map_coordinates

from jaxpm.painting import cic_paint_2d
from jaxpm.utils import gaussian_smoothing


def density_plane(positions,
                  box_shape,
                  center,
                  width,
                  plane_resolution,
                  smoothing_sigma=None):
    """Extract a 2-D density plane from a slab of the simulation volume.

    Particles whose line-of-sight coordinate ``positions[..., 2]`` lies in the
    interval ``(center - width / 2, center + width / 2]`` are CIC-painted onto
    a square ``plane_resolution x plane_resolution`` grid; every other
    particle is painted with zero weight.

    Args:
      positions: particle positions. ``positions[..., :2]`` are the two
        in-plane coordinates, ``positions[..., 2]`` the line-of-sight one.
        Presumably expressed in the same (mesh-cell) units as ``box_shape`` —
        TODO confirm against callers.
      box_shape: ``(nx, ny, nz)`` extent of the box. ``nx`` and ``ny`` give the
        periodicity of the two in-plane axes; ``nz`` is not used here.
      center: line-of-sight position of the slab centre (same units as
        ``positions[..., 2]``).
      width: thickness of the slab along the line of sight (same units).
      plane_resolution: number of pixels along each side of the output plane.
      smoothing_sigma: if not ``None``, passed to ``gaussian_smoothing`` to
        smooth the painted plane.

    Returns:
      Array of shape ``[plane_resolution, plane_resolution]``: the painted
      particle weight divided by the volume of one pixel,
      ``(nx / plane_resolution) * (ny / plane_resolution) * width``.
    """
    nx, ny, nz = box_shape
    xy = positions[..., :2]
    d = positions[..., 2]

    # Per-axis extent of the two in-plane axes. Wrapping/rescaling both axes
    # by ``nx`` mis-mapped the y coordinate whenever ``nx != ny``; for square
    # boxes this is identical to using ``nx`` alone.
    box_xy = jnp.asarray([nx, ny], dtype=xy.dtype)

    # Apply 2d periodic conditions, each axis with its own period
    xy = jnp.mod(xy, box_xy)

    # Rescale each axis from [0, n) to [0, plane_resolution) pixel units
    xy = xy / box_xy * plane_resolution

    # Selecting only particles that fall inside the slab of interest
    weight = jnp.where(
        (d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)

    # Painting density plane
    density_plane = cic_paint_2d(
        jnp.zeros([plane_resolution, plane_resolution]), xy, weight)

    # Divide by the volume of one pixel (box units) to turn counts into density
    density_plane = density_plane / ((nx / plane_resolution) *
                                     (ny / plane_resolution) * (width))

    # Apply Gaussian smoothing if requested
    if smoothing_sigma is not None:
        density_plane = gaussian_smoothing(density_plane, smoothing_sigma)

    return density_plane


def convergence_Born(cosmo, density_planes, coords, z_source):
    """Compute the convergence in the Born approximation from lens planes.

    Each plane is mean-subtracted, weighted by the lensing prefactor
    ``3/2 Omega_m (H0/c)^2`` and the factor ``dz * r / a``, interpolated at
    the requested angular positions, and accumulated with the lensing
    efficiency ``max(1 - r / r_s, 0)`` for every source redshift.

    Args:
      cosmo: `Cosmology`, cosmology object.
      density_planes: iterable of dicts, one per lens plane, with keys
        ``'r'`` (comoving distance of the plane), ``'a'`` (scale factor),
        ``'plane'`` (2-D density map), ``'dx'`` (transverse pixel size, same
        length units as ``r``) and ``'dz'`` (plane thickness along the line
        of sight).
      coords: angular coordinates in radians. They are scaled to pixel units
        and passed to ``jax.scipy.ndimage.map_coordinates``, so the leading
        axis must have size 2 (one entry per plane axis); presumably shape
        ``[2, Ny, Nx]`` — TODO confirm against callers.
      z_source: scalar or 1-D array of source redshifts with shape ``[Nz]``
        (a scalar behaves as ``Nz = 1``).

    Returns:
      For ``coords`` of shape ``[2, Ny, Nx]``, an array of shape
      ``[Nz, Ny, Nx]`` of convergence values (the int ``0`` if
      ``density_planes`` is empty).
    """
    # Constant prefactor of the lensing kernel: 3/2 Omega_m (H0 / c)^2
    constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
    # Compute comoving distance of source galaxies
    r_s = jax_cosmo.background.radial_comoving_distance(
        cosmo, 1 / (1 + z_source))

    convergence = 0
    for entry in density_planes:
        r = entry['r']
        a = entry['a']
        p = entry['plane']
        dx = entry['dx']
        dz = entry['dz']

        # Mean-subtract the plane and apply the lensing prefactor together
        # with the dz * r / a projection weight.
        density_normalization = dz * r / a
        p = (p - p.mean()) * constant_factor * density_normalization

        # Convert angles to pixel coordinates on this plane (angle * r / dx)
        # and bilinearly interpolate with periodic wrapping.
        # NOTE(review): the -0.5 looks like a half-pixel shift between the
        # angular grid and pixel indexing — confirm against how `coords` and
        # the planes are built.
        im = map_coordinates(p, coords * r / dx - 0.5, order=1, mode="wrap")

        # Lensing efficiency (1 - r / r_s), zeroed for planes behind the
        # source; reshaped to [Nz, 1, 1] to broadcast over the 2-D image.
        convergence += im * jnp.clip(1. - (r / r_s), 0, 1000).reshape(
            [-1, 1, 1])

    return convergence