mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
Add Spherical lensing example
This commit is contained in:
parent
2d21985279
commit
f6d547e31f
5 changed files with 1048 additions and 381 deletions
221
jaxpm/lensing.py
221
jaxpm/lensing.py
|
@ -1,82 +1,181 @@
|
|||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo
|
||||
import jax_cosmo as jc
|
||||
import jax_cosmo.constants as constants
|
||||
from jax.scipy.ndimage import map_coordinates
|
||||
|
||||
from jaxpm.painting import cic_paint_2d
|
||||
from jaxpm.distributed import uniform_particles
|
||||
from jaxpm.painting import cic_paint, cic_paint_2d, cic_paint_dx
|
||||
from jaxpm.spherical import paint_spherical
|
||||
from jaxpm.utils import gaussian_smoothing
|
||||
|
||||
|
||||
def density_plane(positions,
|
||||
box_shape,
|
||||
center,
|
||||
width,
|
||||
plane_resolution,
|
||||
smoothing_sigma=None):
|
||||
""" Extacts a density plane from the simulation
|
||||
def density_plane_fn(box_shape,
|
||||
box_size,
|
||||
density_plane_width,
|
||||
density_plane_npix,
|
||||
sharding=None):
|
||||
|
||||
def f(t, y, args):
|
||||
positions = y[0]
|
||||
cosmo = args
|
||||
nx, ny, nz = box_shape
|
||||
|
||||
# Converts time t to comoving distance in voxel coordinates
|
||||
w = density_plane_width / box_size[2] * box_shape[2]
|
||||
center = jc.background.radial_comoving_distance(
|
||||
cosmo, t) / box_size[2] * box_shape[2]
|
||||
positions = uniform_particles(box_shape) + positions
|
||||
xy = positions[..., :2]
|
||||
d = positions[..., 2]
|
||||
|
||||
# Apply 2d periodic conditions
|
||||
xy = jnp.mod(xy, nx)
|
||||
|
||||
# Rescaling positions to target grid
|
||||
xy = xy / nx * density_plane_npix
|
||||
# Selecting only particles that fall inside the volume of interest
|
||||
weight = jnp.where((d > (center - w / 2)) & (d <= (center + w / 2)),
|
||||
1.0, 0.0)
|
||||
# Painting density plane
|
||||
zero_mesh = jnp.zeros([density_plane_npix, density_plane_npix])
|
||||
# Apply sharding in order to recover sharding when taking gradients
|
||||
if sharding is not None:
|
||||
xy = jax.lax.with_sharding_constraint(xy, sharding)
|
||||
# Apply CIC painting
|
||||
density_plane = cic_paint_2d(zero_mesh, xy, weight)
|
||||
|
||||
# Apply density normalization
|
||||
density_plane = density_plane / ((nx / density_plane_npix) *
|
||||
(ny / density_plane_npix) * w)
|
||||
return density_plane
|
||||
|
||||
return f
|
||||
|
||||
|
||||
def spherical_density_fn(box_shape,
|
||||
box_size,
|
||||
nside,
|
||||
fov,
|
||||
center_radec,
|
||||
observer_position,
|
||||
d_R,
|
||||
sharding=None):
|
||||
|
||||
def f(t, y, args):
|
||||
positions = y[0]
|
||||
nx, ny, nz = box_shape
|
||||
bx, by, bz = box_size
|
||||
cosmo = args
|
||||
# Converts time t to comoving distance in voxel coordinates
|
||||
w = d_R / box_size[2] * box_shape[2]
|
||||
center = ((jc.background.radial_comoving_distance(cosmo, t)) / bz) * nz
|
||||
|
||||
# Apply sharding in order to recover sharding when taking gradients
|
||||
if sharding is not None:
|
||||
positions = jax.lax.with_sharding_constraint(positions, sharding)
|
||||
|
||||
density_mesh = cic_paint_dx(positions)
|
||||
# Project to spherical map
|
||||
spherical_map = paint_spherical(density_mesh, nside, fov, center_radec,
|
||||
observer_position, box_size, center,
|
||||
d_R)
|
||||
return spherical_map
|
||||
|
||||
return f
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# Weak Lensing Born Approximation
|
||||
# ==========================================================
|
||||
def convergence_Born(cosmo, density_planes, r, a, dx, dz, coords, z_source):
|
||||
"""
|
||||
nx, ny, nz = box_shape
|
||||
xy = positions[..., :2]
|
||||
d = positions[..., 2]
|
||||
Compute Born-approximation lensing convergence maps.
|
||||
|
||||
# Apply 2d periodic conditions
|
||||
xy = jnp.mod(xy, nx)
|
||||
Parameters
|
||||
----------
|
||||
cosmo : jc.Cosmology
|
||||
Cosmology object.
|
||||
density_planes : ndarray
|
||||
3D array of lensing density planes [nx, ny, n_planes].
|
||||
r, a : ndarray
|
||||
Comoving distances and scale factors per plane.
|
||||
dx : float
|
||||
Pixel scale.
|
||||
dz : float
|
||||
Redshift bin width.
|
||||
coords : ndarray
|
||||
Angular coordinates grid [2, N, 2] in radians.
|
||||
z_source : ndarray
|
||||
Source redshifts.
|
||||
|
||||
# Rescaling positions to target grid
|
||||
xy = xy / nx * plane_resolution
|
||||
|
||||
# Selecting only particles that fall inside the volume of interest
|
||||
weight = jnp.where(
|
||||
(d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
|
||||
# Painting density plane
|
||||
density_plane = cic_paint_2d(
|
||||
jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
|
||||
|
||||
# Apply density normalization
|
||||
density_plane = density_plane / ((nx / plane_resolution) *
|
||||
(ny / plane_resolution) * (width))
|
||||
|
||||
# Apply Gaussian smoothing if requested
|
||||
if smoothing_sigma is not None:
|
||||
density_plane = gaussian_smoothing(density_plane, smoothing_sigma)
|
||||
|
||||
return density_plane
|
||||
|
||||
|
||||
def convergence_Born(cosmo, density_planes, coords, z_source):
|
||||
Returns
|
||||
-------
|
||||
convergence : ndarray
|
||||
2D convergence map for each source redshift.
|
||||
"""
|
||||
Compute the Born convergence
|
||||
Args:
|
||||
cosmo: `Cosmology`, cosmology object.
|
||||
density_planes: list of dictionaries (r, a, density_plane, dx, dz), lens planes to use
|
||||
coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2].
|
||||
z_source: 1-D `Tensor` of source redshifts with shape [Nz] .
|
||||
name: `string`, name of the operation.
|
||||
Returns:
|
||||
`Tensor` of shape [batch_size, N, Nz], of convergence values.
|
||||
"""
|
||||
# Compute constant prefactor:
|
||||
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
|
||||
# Compute comoving distance of source galaxies
|
||||
r_s = jax_cosmo.background.radial_comoving_distance(
|
||||
cosmo, 1 / (1 + z_source))
|
||||
r_s = jc.background.radial_comoving_distance(cosmo, 1 / (1 + z_source))
|
||||
n_planes = len(r)
|
||||
|
||||
convergence = 0
|
||||
for entry in density_planes:
|
||||
r = entry['r']
|
||||
a = entry['a']
|
||||
p = entry['plane']
|
||||
dx = entry['dx']
|
||||
dz = entry['dz']
|
||||
# Normalize density planes
|
||||
density_normalization = dz * r / a
|
||||
def scan_fn(carry, i):
|
||||
density_planes, a, r = carry
|
||||
|
||||
p = density_planes[:, :, i]
|
||||
density_normalization = dz * r[i] / a[i]
|
||||
p = (p - p.mean()) * constant_factor * density_normalization
|
||||
|
||||
# Interpolate at the density plane coordinates
|
||||
im = map_coordinates(p, coords * r / dx - 0.5, order=1, mode="wrap")
|
||||
im = map_coordinates(p, coords * r[i] / dx - 0.5, order=1, mode="wrap")
|
||||
|
||||
convergence += im * jnp.clip(1. -
|
||||
(r / r_s), 0, 1000).reshape([-1, 1, 1])
|
||||
return carry, im * jnp.clip(1.0 -
|
||||
(r[i] / r_s), 0, 1000).reshape([-1, 1, 1])
|
||||
|
||||
return convergence
|
||||
_, convergence = jax.lax.scan(scan_fn, (density_planes, a, r),
|
||||
jnp.arange(n_planes))
|
||||
return convergence.sum(axis=0)
|
||||
|
||||
|
||||
def spherical_convergence_Born(cosmo, density_planes, r, a, nside, z_source):
|
||||
"""
|
||||
Compute Born-approximation lensing convergence maps on a sphere.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cosmo : jc.Cosmology
|
||||
Cosmology object.
|
||||
density_planes : ndarray
|
||||
3D array of lensing density planes [n_planes, npix].
|
||||
r, a : ndarray
|
||||
Comoving distances and scale factors per plane.
|
||||
nside : int
|
||||
Healpix nside parameter.
|
||||
z_source : ndarray
|
||||
Source redshifts.
|
||||
|
||||
Returns
|
||||
-------
|
||||
convergence : ndarray
|
||||
2D convergence map for each source redshift.
|
||||
"""
|
||||
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
|
||||
# Compute comoving distance of source galaxies
|
||||
r_s = jc.background.radial_comoving_distance(cosmo, 1 / (1 + z_source))
|
||||
n_planes = len(r)
|
||||
|
||||
def scan_fn(carry, i):
|
||||
density_planes, a, r = carry
|
||||
|
||||
p = density_planes[i, :]
|
||||
density_normalization = r[i] / a[
|
||||
i] # This normalization needs to be checked
|
||||
p = (p - p.mean()) * constant_factor * density_normalization
|
||||
|
||||
return carry, p * jnp.clip(1.0 -
|
||||
(r[i] / r_s), 0, 1000).reshape([-1, 1])
|
||||
|
||||
_, convergence = jax.lax.scan(scan_fn, (density_planes, a, r),
|
||||
jnp.arange(n_planes))
|
||||
return convergence.sum(axis=0)
|
||||
|
|
133
jaxpm/ode.py
Normal file
133
jaxpm/ode.py
Normal file
|
@ -0,0 +1,133 @@
|
|||
from jaxpm.growth import E, Gf, dGfa, gp
|
||||
from jaxpm.growth import growth_factor as Gp
|
||||
from jaxpm.pm import pm_forces
|
||||
|
||||
|
||||
def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sharding=None):
|
||||
def drift(a, vel, args):
|
||||
"""
|
||||
state is a tuple (position, velocities)
|
||||
"""
|
||||
cosmo = args[0]
|
||||
# Get the time steps
|
||||
t0 = a
|
||||
t1 = a + dt0
|
||||
# Set the scale factors
|
||||
ai = t0
|
||||
ac = (t0 * t1) ** 0.5 # Geometric mean of t0 and t1
|
||||
af = t1
|
||||
|
||||
#drift_contr = (Gp(cosmo, af) - Gp(cosmo, ai)) / gp(cosmo, ac)
|
||||
drift_contr = (af - ai )/ ac
|
||||
# Computes the update of position (drift)
|
||||
dpos = 1 / (ac**3 * E(cosmo, ac)) * vel
|
||||
|
||||
return dpos * (drift_contr / dt0)
|
||||
|
||||
def kick(a, pos, args):
|
||||
"""
|
||||
state is a tuple (position, velocities)
|
||||
"""
|
||||
# Computes the update of velocity (kick)
|
||||
cosmo = args
|
||||
# Get the time steps
|
||||
t0 = a
|
||||
t1 = t0 + dt0
|
||||
t2 = t1 + dt0
|
||||
t0t1 = (t0 * t1) ** 0.5 # Geometric mean of t0 and t1
|
||||
t1t2 = (t1 * t2) ** 0.5 # Geometric mean of t1 and t2
|
||||
# Set the scale factors
|
||||
ac = t1
|
||||
|
||||
forces = (
|
||||
pm_forces(
|
||||
pos,
|
||||
mesh_shape=mesh_shape,
|
||||
paint_absolute_pos=paint_absolute_pos,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding,
|
||||
)
|
||||
* 1.5
|
||||
* cosmo.Omega_m
|
||||
)
|
||||
|
||||
# Computes the update of velocity (kick)
|
||||
dvel = 1.0 / (ac**2 * E(cosmo, ac)) * forces
|
||||
# First kick control factor
|
||||
kick_factor_1 = (t1 - t0t1) / t1
|
||||
#kick_factor_1 = (Gf(cosmo, t1) - Gf(cosmo, t0t1)) / dGfa(cosmo, t1)
|
||||
# Second kick control factor
|
||||
kick_factor_2 = (t2 - t1t2) / t2
|
||||
#kick_factor_2 = (Gf(cosmo, t1t2) - Gf(cosmo, t1)) / dGfa(cosmo, t1)
|
||||
|
||||
return dvel * ((kick_factor_1 + kick_factor_2) / dt0)
|
||||
|
||||
def first_kick(a, pos, args):
|
||||
"""
|
||||
state is a tuple (position, velocities)
|
||||
"""
|
||||
# Computes the update of velocity (kick)
|
||||
cosmo = args
|
||||
# Get the time steps
|
||||
t0 = a
|
||||
t1 = t0 + dt0
|
||||
t0t1 = (t0 * t1) ** 0.5 # Geometric mean of t0 and t1
|
||||
|
||||
forces = (
|
||||
pm_forces(
|
||||
pos,
|
||||
mesh_shape=mesh_shape,
|
||||
paint_absolute_pos=paint_absolute_pos,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding,
|
||||
)
|
||||
* 1.5
|
||||
* cosmo.Omega_m
|
||||
)
|
||||
|
||||
# Computes the update of velocity (kick)
|
||||
dvel = 1.0 / (a**2 * E(cosmo, a)) * forces
|
||||
# First kick control factor
|
||||
kick_factor = (Gf(cosmo, t0t1) - Gf(cosmo, t0)) / dGfa(cosmo, t0)
|
||||
|
||||
return dvel * (kick_factor / dt0)
|
||||
|
||||
return drift, kick, first_kick
|
||||
|
||||
def symplectic_ode(mesh_shape, paint_absolute_pos=True, halo_size=0, sharding=None):
|
||||
def drift(a, vel, args):
|
||||
"""
|
||||
state is a tuple (position, velocities)
|
||||
"""
|
||||
cosmo = args
|
||||
# Computes the update of position (drift)
|
||||
dpos = 1 / (a**3 * E(cosmo, a)) * vel
|
||||
|
||||
return dpos
|
||||
|
||||
def kick(a, pos, args):
|
||||
"""
|
||||
state is a tuple (position, velocities)
|
||||
"""
|
||||
# Computes the update of velocity (kick)
|
||||
|
||||
cosmo = args
|
||||
|
||||
forces = (
|
||||
pm_forces(
|
||||
pos,
|
||||
mesh_shape=mesh_shape,
|
||||
paint_absolute_pos=paint_absolute_pos,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding,
|
||||
)
|
||||
* 1.5
|
||||
* cosmo.Omega_m
|
||||
)
|
||||
|
||||
# Computes the update of velocity (kick)
|
||||
dvel = 1.0 / (a**2 * E(cosmo, a)) * forces
|
||||
|
||||
return dvel
|
||||
|
||||
return drift, kick
|
50
jaxpm/spherical.py
Normal file
50
jaxpm/spherical.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
import jax.numpy as jnp
|
||||
import jax_healpy as jhp
|
||||
import matplotlib.pyplot as plt
|
||||
import jax
|
||||
from functools import partial
|
||||
import healpy as hp
|
||||
|
||||
@partial(jax.jit, static_argnames=('nside', 'fov', 'center_radec' , 'd_R' , 'box_size'))
|
||||
def paint_spherical(volume, nside, fov, center_radec, observer_position, box_size, R, d_R):
|
||||
width, height, depth = volume.shape
|
||||
ra0, dec0 = center_radec
|
||||
fov_width, fov_height = fov
|
||||
|
||||
pixel_scale_x = fov_width / width
|
||||
pixel_scale_y = fov_height / height
|
||||
|
||||
res_deg = jhp.nside2resol(nside, arcmin=True) / 60
|
||||
if pixel_scale_x > res_deg or pixel_scale_y > res_deg:
|
||||
print(f"WARNING Pixel scale ({pixel_scale_x:.4f} deg, {pixel_scale_y:.4f} deg) is larger than the Healpy resolution ({res_deg:.4f} deg). Increase the field of view or decrease the nside.")
|
||||
|
||||
y_idx, x_idx = jnp.indices((height, width))
|
||||
ra_grid = ra0 + x_idx * pixel_scale_x
|
||||
dec_grid = dec0 + y_idx * pixel_scale_y
|
||||
|
||||
ra_flat = ra_grid.flatten() * jnp.pi / 180.0
|
||||
dec_flat = dec_grid.flatten() * jnp.pi / 180.0
|
||||
R_s = jnp.arange(0 , d_R, 1.0) + R
|
||||
|
||||
XYZ = R_s.reshape(-1, 1, 1) * jhp.ang2vec(ra_flat, dec_flat, lonlat=False)
|
||||
observer_position = jnp.array(observer_position)
|
||||
# Convert observer position from box units to grid units
|
||||
observer_position = observer_position / jnp.array(box_size) * jnp.array(volume.shape)
|
||||
|
||||
coords = XYZ + jnp.asarray(observer_position)[jnp.newaxis, jnp.newaxis, :]
|
||||
|
||||
pixels = jhp.ang2pix(nside, ra_flat, dec_flat, lonlat=False)
|
||||
|
||||
npix = jhp.nside2npix(nside)
|
||||
|
||||
@partial(jax.vmap, in_axes=(0, None, None))
|
||||
def interpolate_volume(coords, volume, pixels):
|
||||
voxels = jax.scipy.ndimage.map_coordinates(volume, coords.T, order=1)
|
||||
sums = jnp.bincount(pixels, weights=voxels, length=npix)
|
||||
return sums
|
||||
|
||||
sum_map = interpolate_volume(coords, volume, pixels).sum(axis=0)
|
||||
counts = jnp.bincount(pixels, length=npix)
|
||||
sum_map = jnp.where(counts > 0, sum_map / counts, jhp.UNSEEN)
|
||||
|
||||
return sum_map
|
Loading…
Add table
Add a link
Reference in a new issue