diff --git a/jaxpm/lensing.py b/jaxpm/lensing.py index d6c4a93..b4beeef 100644 --- a/jaxpm/lensing.py +++ b/jaxpm/lensing.py @@ -44,16 +44,13 @@ def density_plane(positions, def convergence_Born(cosmo, density_planes, - dx, dz, coords, z_source): """ Compute the Born convergence Args: cosmo: `Cosmology`, cosmology object. - density_planes: list of tuples (r, a, density_plane), lens planes to use - dx: float, transverse pixel resolution of the density planes [Mpc/h] - dz: float, width of the density planes [Mpc/h] + density_planes: list of dictionaries (r, a, density_plane, dx, dz), lens planes to use coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2]. z_source: 1-D `Tensor` of source redshifts with shape [Nz] . name: `string`, name of the operation. @@ -66,7 +63,9 @@ def convergence_Born(cosmo, r_s = jax_cosmo.background.radial_comoving_distance(cosmo, 1 / (1 + z_source)) convergence = 0 - for r, a, p in density_planes: + for entry in density_planes: + r = entry['r']; a = entry['a']; p = entry['plane'] + dx = entry['dx']; dz = entry['dz'] # Normalize density planes density_normalization = dz * r / a p = (p - p.mean()) * constant_factor * density_normalization