mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 10:00:54 +00:00
Merge pull request #11 from DifferentiableUniverseInitiative/u/EiffL/lensing
Adds basic utilities for Born lensing
This commit is contained in:
commit
0991789553
3 changed files with 127 additions and 0 deletions
81
jaxpm/lensing.py
Normal file
81
jaxpm/lensing.py
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jax_cosmo.constants as constants
|
||||||
|
import jax_cosmo
|
||||||
|
|
||||||
|
from jax.scipy.ndimage import map_coordinates
|
||||||
|
from jaxpm.utils import gaussian_smoothing
|
||||||
|
from jaxpm.painting import cic_paint_2d
|
||||||
|
|
||||||
|
def density_plane(positions,
|
||||||
|
box_shape,
|
||||||
|
center,
|
||||||
|
width,
|
||||||
|
plane_resolution,
|
||||||
|
smoothing_sigma=None):
|
||||||
|
""" Extacts a density plane from the simulation
|
||||||
|
"""
|
||||||
|
nx, ny, nz = box_shape
|
||||||
|
xy = positions[..., :2]
|
||||||
|
d = positions[..., 2]
|
||||||
|
|
||||||
|
# Apply 2d periodic conditions
|
||||||
|
xy = jnp.mod(xy, nx)
|
||||||
|
|
||||||
|
# Rescaling positions to target grid
|
||||||
|
xy = xy / nx * plane_resolution
|
||||||
|
|
||||||
|
# Selecting only particles that fall inside the volume of interest
|
||||||
|
weight = jnp.where((d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
|
||||||
|
# Painting density plane
|
||||||
|
density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
|
||||||
|
|
||||||
|
# Apply density normalization
|
||||||
|
density_plane = density_plane / ((nx / plane_resolution) *
|
||||||
|
(ny / plane_resolution) * (width))
|
||||||
|
|
||||||
|
# Apply Gaussian smoothing if requested
|
||||||
|
if smoothing_sigma is not None:
|
||||||
|
density_plane = gaussian_smoothing(density_plane,
|
||||||
|
smoothing_sigma)
|
||||||
|
|
||||||
|
return density_plane
|
||||||
|
|
||||||
|
|
||||||
|
def convergence_Born(cosmo,
|
||||||
|
density_planes,
|
||||||
|
dx, dz,
|
||||||
|
coords,
|
||||||
|
z_source):
|
||||||
|
"""
|
||||||
|
Compute the Born convergence
|
||||||
|
Args:
|
||||||
|
cosmo: `Cosmology`, cosmology object.
|
||||||
|
density_planes: list of tuples (r, a, density_plane), lens planes to use
|
||||||
|
dx: float, transverse pixel resolution of the density planes [Mpc/h]
|
||||||
|
dz: float, width of the density planes [Mpc/h]
|
||||||
|
coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2].
|
||||||
|
z_source: 1-D `Tensor` of source redshifts with shape [Nz] .
|
||||||
|
name: `string`, name of the operation.
|
||||||
|
Returns:
|
||||||
|
`Tensor` of shape [batch_size, N, Nz], of convergence values.
|
||||||
|
"""
|
||||||
|
# Compute constant prefactor:
|
||||||
|
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
|
||||||
|
# Compute comoving distance of source galaxies
|
||||||
|
r_s = jax_cosmo.background.radial_comoving_distance(cosmo, 1 / (1 + z_source))
|
||||||
|
|
||||||
|
convergence = 0
|
||||||
|
for r, a, p in density_planes:
|
||||||
|
# Normalize density planes
|
||||||
|
density_normalization = dz * r / a
|
||||||
|
p = (p - p.mean()) * constant_factor * density_normalization
|
||||||
|
|
||||||
|
# Interpolate at the density plane coordinates
|
||||||
|
im = map_coordinates(p,
|
||||||
|
coords * r / dx - 0.5,
|
||||||
|
order=1, mode="wrap")
|
||||||
|
|
||||||
|
convergence += im * jnp.clip(1. - (r / r_s), 0, 1000).reshape([-1, 1, 1])
|
||||||
|
|
||||||
|
return convergence
|
|
@ -52,6 +52,34 @@ def cic_read(mesh, positions):
|
||||||
neighboor_coords[...,1],
|
neighboor_coords[...,1],
|
||||||
neighboor_coords[...,3]]*kernel).sum(axis=-1)
|
neighboor_coords[...,3]]*kernel).sum(axis=-1)
|
||||||
|
|
||||||
|
def cic_paint_2d(mesh, positions, weight):
|
||||||
|
""" Paints positions onto a 2d mesh
|
||||||
|
mesh: [nx, ny]
|
||||||
|
positions: [npart, 2]
|
||||||
|
weight: [npart]
|
||||||
|
"""
|
||||||
|
positions = jnp.expand_dims(positions, 1)
|
||||||
|
floor = jnp.floor(positions)
|
||||||
|
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
|
||||||
|
|
||||||
|
neighboor_coords = floor + connection
|
||||||
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
|
kernel = kernel[..., 0] * kernel[..., 1]
|
||||||
|
if weight is not None:
|
||||||
|
kernel = kernel * weight[...,jnp.newaxis]
|
||||||
|
|
||||||
|
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape))
|
||||||
|
|
||||||
|
dnums = jax.lax.ScatterDimensionNumbers(
|
||||||
|
update_window_dims=(),
|
||||||
|
inserted_window_dims=(0, 1),
|
||||||
|
scatter_dims_to_operand_dims=(0, 1))
|
||||||
|
mesh = lax.scatter_add(mesh,
|
||||||
|
neighboor_coords,
|
||||||
|
kernel.reshape([-1,4]),
|
||||||
|
dnums)
|
||||||
|
return mesh
|
||||||
|
|
||||||
def compensate_cic(field):
|
def compensate_cic(field):
|
||||||
"""
|
"""
|
||||||
Compensate for CiC painting
|
Compensate for CiC painting
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
from jax.scipy.stats import norm
|
||||||
|
|
||||||
__all__ = ['power_spectrum']
|
__all__ = ['power_spectrum']
|
||||||
|
|
||||||
|
@ -79,3 +80,20 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
|
||||||
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
|
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
|
||||||
|
|
||||||
return kbins, P / norm
|
return kbins, P / norm
|
||||||
|
|
||||||
|
def gaussian_smoothing(im, sigma):
|
||||||
|
"""
|
||||||
|
im: 2d image
|
||||||
|
sigma: smoothing scale in px
|
||||||
|
"""
|
||||||
|
# Compute k vector
|
||||||
|
kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]),
|
||||||
|
jnp.fft.fftfreq(im.shape[1])),
|
||||||
|
axis=-1)
|
||||||
|
k = jnp.linalg.norm(kvec, axis=-1)
|
||||||
|
# We compute the value of the filter at frequency k
|
||||||
|
filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma))
|
||||||
|
filter /= filter[0,0]
|
||||||
|
|
||||||
|
return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue