From 03d3bc79279b2296162bbf0be32ef3c93f1e315d Mon Sep 17 00:00:00 2001 From: EiffL Date: Tue, 17 May 2022 11:19:56 +0200 Subject: [PATCH 1/9] adding function for doing 2d paintinng --- jaxpm/painting.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 27a1900..95bead9 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -52,6 +52,31 @@ def cic_read(mesh, positions): neighboor_coords[...,1], neighboor_coords[...,3]]*kernel).sum(axis=-1) +def cic_paint_2d(mesh, positions): + """ Paints positions onto a 2d mesh + mesh: [nx, ny] + positions: [npart, 2] + """ + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) + connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) + + neighboor_coords = floor + connection + kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] + + neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape)) + + dnums = jax.lax.ScatterDimensionNumbers( + update_window_dims=(), + inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1)) + mesh = lax.scatter_add(mesh, + neighboor_coords, + kernel.reshape([-1,4]), + dnums) + return mesh + def compensate_cic(field): """ Compensate for CiC painting From da2836f698bd922dcc3645f411b5143c03388137 Mon Sep 17 00:00:00 2001 From: EiffL Date: Tue, 17 May 2022 11:26:38 +0200 Subject: [PATCH 2/9] adding density plane cutting code --- jaxpm/lensing.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 jaxpm/lensing.py diff --git a/jaxpm/lensing.py b/jaxpm/lensing.py new file mode 100644 index 0000000..723b71e --- /dev/null +++ b/jaxpm/lensing.py @@ -0,0 +1,32 @@ +import jax +import jax.numpy as jnp + +from jaxpm.painting import cic_paint_2d + +def density_plane(positions, + box_shape, + center, + width, + plane_resolution): + + nx, ny, nz = box_shape + xy = positions[..., :2] + d = positions[..., 2] + + # Apply 2d periodic conditions + xy = jnp.mod(xy, nx) + + # Rescaling positions to target grid + xy = xy / nx * plane_resolution + + # Selecting only particles that fall inside the volume of interest + mask = (d > (center - width / 2)) & (d <= (center + width / 2)) + + # Painting density plane + density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy[mask]) + + # Apply density normalization + density_plane = density_plane / ((nx / plane_resolution) * + (ny / plane_resolution) * (width)) + + return density_plane From e33504358d379c56ea5d960eead339ec316cbbfc Mon Sep 17 00:00:00 2001 From: EiffL Date: Tue, 17 May 2022 17:55:06 +0200 Subject: [PATCH 3/9] adds utilities for simple lensing --- jaxpm/lensing.py | 57 +++++++++++++++++++++++++++++++++++++++++++++--- jaxpm/utils.py | 17 +++++++++++++++ 2 files changed, 71 insertions(+), 3 deletions(-) diff --git a/jaxpm/lensing.py b/jaxpm/lensing.py index 723b71e..5d3d1e2 100644 --- a/jaxpm/lensing.py +++ b/jaxpm/lensing.py @@ -1,14 +1,20 @@ import jax import jax.numpy as jnp +import jax_cosmo.constants as constants +import jax_cosmo +from jax.scipy.ndimage import map_coordinates +from jaxpm.utils import gaussian_smoothing from jaxpm.painting import cic_paint_2d def density_plane(positions, box_shape, center, width, - plane_resolution): - + plane_resolution, + smoothing_sigma=None): + """ Extacts a density plane from the simulation + """ nx, ny, nz = box_shape xy = positions[..., :2] d = positions[..., 2] @@ -21,12 +27,57 @@ def density_plane(positions, # Selecting only particles that fall inside the volume of interest mask = (d > (center - width / 2)) & (d <= (center + width / 2)) + xy = xy[mask] # Painting density plane - density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy[mask]) + density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy) # Apply density normalization density_plane = density_plane / ((nx / plane_resolution) * (ny / plane_resolution) * (width)) + # Apply Gaussian smoothing if requested + if smoothing_sigma is not None: + density_plane = gaussian_smoothing(density_plane, + smoothing_sigma) + return density_plane + + +def convergence_Born(cosmo, + density_planes, + dx, dz, + coords, + z_source): + """ + Compute the Born convergence + Args: + cosmo: `Cosmology`, cosmology object. + density_planes: list of tuples (r, a, density_plane), lens planes to use + dx: float, transverse pixel resolution of the density planes [Mpc/h] + dz: float, width of the density planes [Mpc/h] + coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2]. + z_source: 1-D `Tensor` of source redshifts with shape [Nz] . + name: `string`, name of the operation. + Returns: + `Tensor` of shape [batch_size, N, Nz], of convergence values. + """ + # Compute constant prefactor: + constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2 + # Compute comoving distance of source galaxies + r_s = jax_cosmo.background.radial_comoving_distance(cosmo, 1 / (1 + z_source)) + + convergence = 0 + for r, a, p in density_planes: + # Normalize density planes + density_normalization = dz * r / a + p = (p - p.mean()) * constant_factor * density_normalization + + # Interpolate at the density plane coordinates + im = map_coordinates(p, + coords * r / dx - 0.5, + order=1, mode="wrap") + + convergence += im * jnp.clip(1. - (r / r_s), 0, 1000).reshape([-1, 1, 1]) + + return convergence diff --git a/jaxpm/utils.py b/jaxpm/utils.py index 312f73f..189ab6c 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -1,5 +1,6 @@ import numpy as np import jax.numpy as jnp +from scipy.stats import norm __all__ = ['power_spectrum'] @@ -79,3 +80,19 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False): kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2 return kbins, P / norm + +def gaussian_smoothing(im, sigma): + """ + im: 2d image + sigma: smoothing scale in px + """ + # Compute k vector + kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]), + jnp.fft.fftfreq(im.shape[1])), + axis=-1) + k = jnp.linalg.norm(kvec, axis=-1) + # We compute the value of the filter at frequency k + filter = norm(0, 1. / (2. * np.pi * sigma)).pdf(k) + + return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real + From ca5f26c93f7b1fc8528e52116cf4c60269ce8cf9 Mon Sep 17 00:00:00 2001 From: EiffL Date: Tue, 17 May 2022 17:56:41 +0200 Subject: [PATCH 4/9] adding notebook demo From 5dc239927f841e96c08f84ea4db814b8c4c3e03a Mon Sep 17 00:00:00 2001 From: EiffL Date: Tue, 17 May 2022 23:02:01 +0200 Subject: [PATCH 5/9] minor correction to gaussian smoothing --- jaxpm/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jaxpm/utils.py b/jaxpm/utils.py index 189ab6c..8bf6e2e 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -93,6 +93,7 @@ def gaussian_smoothing(im, sigma): k = jnp.linalg.norm(kvec, axis=-1) # We compute the value of the filter at frequency k filter = norm(0, 1. / (2. * np.pi * sigma)).pdf(k) + filter /= filter[0,0] return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real From 5108e56ee883b1c3d1871f89279cc3f1092f6690 Mon Sep 17 00:00:00 2001 From: EiffL Date: Tue, 17 May 2022 23:37:55 +0200 Subject: [PATCH 6/9] adds fix to make code jittablel --- jaxpm/lensing.py | 6 ++---- jaxpm/painting.py | 6 +++++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/jaxpm/lensing.py b/jaxpm/lensing.py index 5d3d1e2..d6c4a93 100644 --- a/jaxpm/lensing.py +++ b/jaxpm/lensing.py @@ -26,11 +26,9 @@ def density_plane(positions, xy = xy / nx * plane_resolution # Selecting only particles that fall inside the volume of interest - mask = (d > (center - width / 2)) & (d <= (center + width / 2)) - xy = xy[mask] - + weight = jnp.where((d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.) # Painting density plane - density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy) + density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy, weight) # Apply density normalization density_plane = density_plane / ((nx / plane_resolution) * diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 95bead9..9e323a2 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -1,3 +1,4 @@ +from tkinter import W import jax import jax.numpy as jnp import jax.lax as lax @@ -52,10 +53,11 @@ def cic_read(mesh, positions): neighboor_coords[...,1], neighboor_coords[...,3]]*kernel).sum(axis=-1) -def cic_paint_2d(mesh, positions): +def cic_paint_2d(mesh, positions, weight): """ Paints positions onto a 2d mesh mesh: [nx, ny] positions: [npart, 2] + weight: [npart] """ positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) @@ -64,6 +66,8 @@ def cic_paint_2d(mesh, positions): neighboor_coords = floor + connection kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = kernel[..., 0] * kernel[..., 1] + if weight is not None: + kernel = kernel * weight[...,jnp.newaxis] neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape)) From aff8db56f51c0e062ae2ce0c9136464c030eba6c Mon Sep 17 00:00:00 2001 From: EiffL Date: Tue, 17 May 2022 23:42:57 +0200 Subject: [PATCH 7/9] change impor --- jaxpm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxpm/utils.py b/jaxpm/utils.py index 8bf6e2e..1a19b45 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -1,6 +1,6 @@ import numpy as np import jax.numpy as jnp -from scipy.stats import norm +from jax.scipy.stats import norm __all__ = ['power_spectrum'] From 60309e14907dfa2a325b3437345a122869d9bc4e Mon Sep 17 00:00:00 2001 From: EiffL Date: Tue, 17 May 2022 23:49:12 +0200 Subject: [PATCH 8/9] small fix --- jaxpm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxpm/utils.py b/jaxpm/utils.py index 1a19b45..a01e188 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -92,7 +92,7 @@ def gaussian_smoothing(im, sigma): axis=-1) k = jnp.linalg.norm(kvec, axis=-1) # We compute the value of the filter at frequency k - filter = norm(0, 1. / (2. * np.pi * sigma)).pdf(k) + filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma)) filter /= filter[0,0] return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real From ff5fe8069efea8658727e562527765919b8d7a0b Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Wed, 18 May 2022 10:22:21 +0200 Subject: [PATCH 9/9] Update jaxpm/painting.py --- jaxpm/painting.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 9e323a2..4237c23 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -1,4 +1,3 @@ -from tkinter import W import jax import jax.numpy as jnp import jax.lax as lax