mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-22 17:47:11 +00:00
adds utilities for simple lensing
This commit is contained in:
parent
da2836f698
commit
e33504358d
2 changed files with 71 additions and 3 deletions
|
@ -1,14 +1,20 @@
|
|||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo.constants as constants
|
||||
import jax_cosmo
|
||||
|
||||
from jax.scipy.ndimage import map_coordinates
|
||||
from jaxpm.utils import gaussian_smoothing
|
||||
from jaxpm.painting import cic_paint_2d
|
||||
|
||||
def density_plane(positions,
|
||||
box_shape,
|
||||
center,
|
||||
width,
|
||||
plane_resolution):
|
||||
|
||||
plane_resolution,
|
||||
smoothing_sigma=None):
|
||||
""" Extacts a density plane from the simulation
|
||||
"""
|
||||
nx, ny, nz = box_shape
|
||||
xy = positions[..., :2]
|
||||
d = positions[..., 2]
|
||||
|
@ -21,12 +27,57 @@ def density_plane(positions,
|
|||
|
||||
# Selecting only particles that fall inside the volume of interest
|
||||
mask = (d > (center - width / 2)) & (d <= (center + width / 2))
|
||||
xy = xy[mask]
|
||||
|
||||
# Painting density plane
|
||||
density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy[mask])
|
||||
density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy)
|
||||
|
||||
# Apply density normalization
|
||||
density_plane = density_plane / ((nx / plane_resolution) *
|
||||
(ny / plane_resolution) * (width))
|
||||
|
||||
# Apply Gaussian smoothing if requested
|
||||
if smoothing_sigma is not None:
|
||||
density_plane = gaussian_smoothing(density_plane,
|
||||
smoothing_sigma)
|
||||
|
||||
return density_plane
|
||||
|
||||
|
||||
def convergence_Born(cosmo,
|
||||
density_planes,
|
||||
dx, dz,
|
||||
coords,
|
||||
z_source):
|
||||
"""
|
||||
Compute the Born convergence
|
||||
Args:
|
||||
cosmo: `Cosmology`, cosmology object.
|
||||
density_planes: list of tuples (r, a, density_plane), lens planes to use
|
||||
dx: float, transverse pixel resolution of the density planes [Mpc/h]
|
||||
dz: float, width of the density planes [Mpc/h]
|
||||
coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2].
|
||||
z_source: 1-D `Tensor` of source redshifts with shape [Nz] .
|
||||
name: `string`, name of the operation.
|
||||
Returns:
|
||||
`Tensor` of shape [batch_size, N, Nz], of convergence values.
|
||||
"""
|
||||
# Compute constant prefactor:
|
||||
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
|
||||
# Compute comoving distance of source galaxies
|
||||
r_s = jax_cosmo.background.radial_comoving_distance(cosmo, 1 / (1 + z_source))
|
||||
|
||||
convergence = 0
|
||||
for r, a, p in density_planes:
|
||||
# Normalize density planes
|
||||
density_normalization = dz * r / a
|
||||
p = (p - p.mean()) * constant_factor * density_normalization
|
||||
|
||||
# Interpolate at the density plane coordinates
|
||||
im = map_coordinates(p,
|
||||
coords * r / dx - 0.5,
|
||||
order=1, mode="wrap")
|
||||
|
||||
convergence += im * jnp.clip(1. - (r / r_s), 0, 1000).reshape([-1, 1, 1])
|
||||
|
||||
return convergence
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
from scipy.stats import norm
|
||||
|
||||
__all__ = ['power_spectrum']
|
||||
|
||||
|
@ -79,3 +80,19 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
|
|||
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
|
||||
|
||||
return kbins, P / norm
|
||||
|
||||
def gaussian_smoothing(im, sigma):
|
||||
"""
|
||||
im: 2d image
|
||||
sigma: smoothing scale in px
|
||||
"""
|
||||
# Compute k vector
|
||||
kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]),
|
||||
jnp.fft.fftfreq(im.shape[1])),
|
||||
axis=-1)
|
||||
k = jnp.linalg.norm(kvec, axis=-1)
|
||||
# We compute the value of the filter at frequency k
|
||||
filter = norm(0, 1. / (2. * np.pi * sigma)).pdf(k)
|
||||
|
||||
return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue