Mirror of https://github.com/Richard-Sti/csiborgtools.git
Synced 2024-12-22 17:38:02 +00:00
Add more flow modelling (#115)
* Add SN calibration model
* Update imports
* Scale Carrick field
* Minor updates to flow validation
* Update script
* Update flow model
* Add CSiBORG2
* Add CSiBORG2 params
* Update imports
* Add regular grid interpolator to LOS
* Add nb
* Remove old code
* Update scripts
* Minor updates
* Minor updates
* Add TF
* Minor update
* Update notebook
* Update imports
* Add scan and loss for numpyro
* Add Pantheon
* Update script
* Update nb
* Add model loader
* Add jackknife
* Add evidence
* Update dr
* Add BIC to the flow model
* Update script
* Update nb
* Update nb
* Update scripts
This commit is contained in:
parent b503a6f003
commit fb4abebeb6
11 changed files with 1949 additions and 342 deletions
@@ -17,7 +17,7 @@ from .density import (DensityField, PotentialField, TidalTensorField,
                       overdensity_field)  # noqa
 from .enclosed_mass import (particles_enclosed_mass,  # noqa
                             particles_enclosed_momentum, field_enclosed_mass)  # noqa
-from .interp import (evaluate_cartesian, evaluate_sky, evaluate_los,  # noqa
+from .interp import (evaluate_cartesian_cic, evaluate_sky, evaluate_los,  # noqa
                      field2rsp, fill_outside, make_sky,  # noqa
                      observer_peculiar_velocity, smoothen_field,  # noqa
                      field_at_distance)  # noqa
@@ -18,6 +18,7 @@ Tools for interpolating 3D fields at arbitrary positions.
 import MAS_library as MASL
 import numpy
 import smoothing_library as SL
+from scipy.interpolate import RegularGridInterpolator
 from numba import jit
 from tqdm import tqdm, trange

@@ -30,9 +31,10 @@ from .utils import divide_nonzero, force_single_precision, nside2radec
 ###############################################################################


-def evaluate_cartesian(*fields, pos, smooth_scales=None, verbose=False):
+def evaluate_cartesian_cic(*fields, pos, smooth_scales=None, verbose=False):
     """
-    Evaluate a scalar field(s) at Cartesian coordinates `pos`.
+    Evaluate a scalar field(s) at Cartesian coordinates `pos` using CIC
+    interpolation.

     Parameters
     ----------
@@ -82,6 +84,75 @@ def evaluate_cartesian(*fields, pos, smooth_scales=None, verbose=False):
     return interp_fields


+def evaluate_cartesian_regular(*fields, pos, smooth_scales=None,
+                               method="linear", verbose=False):
+    """
+    Evaluate a scalar field(s) at Cartesian coordinates `pos` using linear
+    interpolation on a regular grid.
+
+    Parameters
+    ----------
+    *fields : (list of) 3-dimensional array of shape `(grid, grid, grid)`
+        Fields to be interpolated.
+    pos : 2-dimensional array of shape `(n_samples, 3)`
+        Query positions in box units.
+    smooth_scales : (list of) float, optional
+        Smoothing scales in box units. If `None`, no smoothing is performed.
+    method : str, optional
+        Interpolation method, must be one of the methods of
+        `scipy.interpolate.RegularGridInterpolator`.
+    verbose : bool, optional
+        Smoothing verbosity flag.
+
+    Returns
+    -------
+    (list of) 2-dimensional array of shape `(n_samples, len(smooth_scales))`
+    """
+    pos = force_single_precision(pos)
+
+    if isinstance(smooth_scales, (int, float)):
+        smooth_scales = [smooth_scales]
+
+    if smooth_scales is None:
+        shape = (pos.shape[0],)
+    else:
+        shape = (pos.shape[0], len(smooth_scales))
+
+    ngrid = fields[0].shape[0]
+    cellsize = 1. / ngrid
+
+    X = numpy.linspace(0.5 * cellsize, 1 - 0.5 * cellsize, ngrid)
+    Y, Z = numpy.copy(X), numpy.copy(X)
+
+    interp_fields = [numpy.full(shape, numpy.nan, dtype=numpy.float32)
+                     for __ in range(len(fields))]
+    for i, field in enumerate(fields):
+        if smooth_scales is None:
+            field_interp = RegularGridInterpolator(
+                (X, Y, Z), field, fill_value=None, bounds_error=False,
+                method=method)
+            interp_fields[i] = field_interp(pos)
+        else:
+            desc = f"Smoothing and interpolating field {i + 1}/{len(fields)}"
+            iterator = tqdm(smooth_scales, desc=desc, disable=not verbose)
+
+            for j, scale in enumerate(iterator):
+                if not scale > 0:
+                    fsmooth = numpy.copy(field)
+                else:
+                    fsmooth = smoothen_field(field, scale, 1., make_copy=True)
+
+                field_interp = RegularGridInterpolator(
+                    (X, Y, Z), fsmooth, fill_value=None, bounds_error=False,
+                    method=method)
+                interp_fields[i][:, j] = field_interp(pos)
+
+    if len(fields) == 1:
+        return interp_fields[0]
+
+    return interp_fields
+
+
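A minimal usage sketch of the new regular-grid interpolator (illustrative only, not part of the commit; it assumes `evaluate_cartesian_regular` is imported from the interpolation module above, and that positions are given in box units):

    import numpy

    # Toy field on a 64^3 grid and random query positions in box units.
    field = numpy.random.rand(64, 64, 64).astype(numpy.float32)
    pos = numpy.random.uniform(0, 1, (100, 3))

    # Trilinear interpolation without smoothing; returns shape (100,).
    vals = evaluate_cartesian_regular(field, pos=pos, method="linear")

    # With two smoothing scales (box units); returns shape (100, 2).
    vals_smoothed = evaluate_cartesian_regular(
        field, pos=pos, smooth_scales=[0.0, 0.01], method="linear")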
 def observer_peculiar_velocity(velocity_field, smooth_scales=None,
                                observer=None, verbose=True):
     """
@@ -108,7 +179,7 @@ def observer_peculiar_velocity(velocity_field, smooth_scales=None,
     else:
         pos = numpy.asanyarray(observer).reshape(1, 3)

-    vx, vy, vz = evaluate_cartesian(
+    vx, vy, vz = evaluate_cartesian_cic(
         *velocity_field, pos=pos, smooth_scales=smooth_scales, verbose=verbose)

     # Reshape since we evaluated only one point
@@ -127,7 +198,7 @@ def observer_peculiar_velocity(velocity_field, smooth_scales=None,


 def evaluate_los(*fields, sky_pos, boxsize, rmax, dr, smooth_scales=None,
-                 verbose=False):
+                 interpolation_method="cic", verbose=False):
     """
     Interpolate the fields for a set of lines of sights from the observer
     in the centre of the box.
@@ -146,6 +217,9 @@ def evaluate_los(*fields, sky_pos, boxsize, rmax, dr, smooth_scales=None,
         Radial distance step in `Mpc / h`.
     smooth_scales : (list of) float, optional
         Smoothing scales in `Mpc / h`.
+    interpolation_method : str, optional
+        Interpolation method. Must be one of `cic` or one of the methods of
+        `scipy.interpolate.RegularGridInterpolator`.
     verbose : bool, optional
         Smoothing verbosity flag.
@@ -191,9 +265,15 @@ def evaluate_los(*fields, sky_pos, boxsize, rmax, dr, smooth_scales=None,

         smooth_scales *= mpc2box

-    field_interp = evaluate_cartesian(*fields, pos=pos,
-                                      smooth_scales=smooth_scales,
-                                      verbose=verbose)
+    if interpolation_method == "cic":
+        field_interp = evaluate_cartesian_cic(
+            *fields, pos=pos, smooth_scales=smooth_scales,
+            verbose=verbose)
+    else:
+        field_interp = evaluate_cartesian_regular(
+            *fields, pos=pos, smooth_scales=smooth_scales,
+            method=interpolation_method, verbose=verbose)

     if len(fields) == 1:
         field_interp = [field_interp]
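A hedged sketch of the extended line-of-sight call (illustrative; the box size value is just an example, and the return format is whatever `evaluate_los` already returns, which this diff does not show):

    import numpy

    density = numpy.random.rand(64, 64, 64).astype(numpy.float32)
    sky_pos = numpy.array([[120.5, -10.2], [45.0, 30.0]])  # (RA, dec) in degrees

    # CIC behaviour is unchanged; any RegularGridInterpolator method can
    # now be requested instead, e.g. "linear".
    result = evaluate_los(
        density, sky_pos=sky_pos, boxsize=677.7, rmax=200, dr=0.5,
        smooth_scales=[0, 2], interpolation_method="linear")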
@@ -228,7 +308,7 @@ def evaluate_los(*fields, sky_pos, boxsize, rmax, dr, smooth_scales=None,
 def evaluate_sky(*fields, pos, mpc2box, smooth_scales=None, verbose=False):
     """
     Evaluate a scalar field(s) at radial distance `Mpc / h`, right ascensions
-    [0, 360) deg and declinations [-90, 90] deg.
+    [0, 360) deg and declinations [-90, 90] deg. Uses CIC interpolation.

     Parameters
     ----------
@@ -264,8 +344,9 @@ def evaluate_sky(*fields, pos, mpc2box, smooth_scales=None, verbose=False):

         smooth_scales *= mpc2box

-    return evaluate_cartesian(*fields, pos=cart_pos,
-                              smooth_scales=smooth_scales, verbose=verbose)
+    return evaluate_cartesian_cic(*fields, pos=cart_pos,
+                                  smooth_scales=smooth_scales,
+                                  verbose=verbose)


 def make_sky(field, angpos, dist, boxsize, verbose=True):
@@ -324,7 +405,7 @@ def make_sky(field, angpos, dist, boxsize, verbose=True):
 def field_at_distance(field, distance, boxsize, smooth_scales=None, nside=128,
                       verbose=True):
     """
     Evaluate a scalar field at uniformly spaced angular coordinates at a
     given distance from the observer

     Parameters
@@ -355,8 +436,8 @@ def field_at_distance(field, distance, boxsize, smooth_scales=None, nside=128,
                   angpos])
     X = radec_to_cartesian(X) / boxsize + 0.5

-    return evaluate_cartesian(field, pos=X, smooth_scales=smooth_scales,
-                              verbose=verbose)
+    return evaluate_cartesian_cic(field, pos=X, smooth_scales=smooth_scales,
+                                  verbose=verbose)


 ###############################################################################
@@ -14,4 +14,7 @@
 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 from .flow_model import (DataLoader, radial_velocity_los, dist2redshift,  # noqa
                          dist2distmodulus, predict_zobs, project_Vext,  # noqa
-                         SD_PV_validation_model)  # noqa
+                         SD_PV_validation_model, SN_PV_validation_model,  # noqa
+                         TF_PV_validation_model, radec_to_galactic,  # noqa
+                         sample_prior, make_loss, get_model,  # noqa
+                         optimize_model_with_jackknife)  # noqa
@@ -20,7 +20,7 @@ References
 [1] https://arxiv.org/abs/1912.09383.
 """
 from datetime import datetime
-from warnings import warn
+from warnings import catch_warnings, simplefilter, warn

 import numpy as np
 import numpyro
@@ -29,12 +29,18 @@ from astropy import units as u
 from astropy.coordinates import SkyCoord
 from astropy.cosmology import FlatLambdaCDM
 from h5py import File
+from jax import jit
 from jax import numpy as jnp
 from jax import vmap
+from jax.lax import cond, scan
+from jax.random import PRNGKey
+from numpyro.infer import Predictive, util
+from scipy.optimize import fmin_powell
+from sklearn.model_selection import KFold
 from tqdm import tqdm, trange
+from numdifftools import Hessian

 from ..params import simname2Omega_m
-from ..read import CSiBORG1Catalogue

 SPEED_OF_LIGHT = 299792.458  # km / s
@@ -130,19 +136,22 @@ class DataLoader:
         if not store_full_velocity:
             self._los_velocity = None

-        Omega_m = simname2Omega_m(simname)
+        self._Omega_m = simname2Omega_m(simname)

         # Normalize the CSiBORG density by the mean matter density
         if "csiborg" in simname:
-            cosmo = FlatLambdaCDM(H0=100, Om0=Omega_m)
+            cosmo = FlatLambdaCDM(H0=100, Om0=self._Omega_m)
             mean_rho_matter = cosmo.critical_density0.to("Msun/kpc^3").value
-            mean_rho_matter *= Omega_m
+            mean_rho_matter *= self._Omega_m
             self._los_density /= mean_rho_matter

         # Since Carrick+2015 provide `rho / <rho> - 1`
         if simname == "Carrick2015":
             self._los_density += 1

+        self._mask = np.ones(len(self._cat), dtype=bool)
+        self._catname = catalogue
+
     @property
     def cat(self):
         """
@@ -152,7 +161,7 @@ class DataLoader:
         -------
         structured array
         """
-        return self._cat
+        return self._cat[self._mask]

     @property
     def catname(self):
@@ -185,7 +194,7 @@ class DataLoader:
         ----------
         3-dimensional array of shape (n_objects, n_simulations, n_steps)
         """
-        return self._los_density
+        return self._los_density[self._mask]

     @property
     def los_velocity(self):
@@ -198,7 +207,7 @@ class DataLoader:
         """
         if self._los_velocity is None:
             raise ValueError("The 3D velocities were not stored.")
-        return self._los_velocity
+        return self._los_velocity[self._mask]

     @property
     def los_radial_velocity(self):
@@ -209,7 +218,7 @@ class DataLoader:
         -------
         3-dimensional array of shape (n_objects, n_simulations, n_steps)
         """
-        return self._los_radial_velocity
+        return self._los_radial_velocity[self._mask]

     def _read_field(self, simname, catalogue, k, paths):
         """Read in the interpolated field."""
@@ -250,7 +259,8 @@ class DataLoader:
             arr = np.empty(len(f["RA"]), dtype=dtype)
             for key in f.keys():
                 arr[key] = f[key][:]
-        elif catalogue == "LOSS" or catalogue == "Foundation":
+        elif catalogue in ["LOSS", "Foundation", "SFI_gals", "2MTF",
+                           "Pantheon+"]:
             with File(catalogue_fpath, 'r') as f:
                 grp = f[catalogue]
@@ -258,28 +268,46 @@ class DataLoader:
                 arr = np.empty(len(grp["RA"]), dtype=dtype)
                 for key in grp.keys():
                     arr[key] = grp[key][:]
-        elif "csiborg1" in catalogue:
-            nsim = int(catalogue.split("_")[-1])
-            cat = CSiBORG1Catalogue(nsim, bounds={"totmass": (1e13, None)})
-
-            seed = 42
-            gen = np.random.default_rng(seed)
-            mask = gen.choice(len(cat), size=100, replace=False)
-
-            keys = ["r_hMpc", "RA", "DEC"]
-            dtype = [(key, np.float32) for key in keys]
-            arr = np.empty(len(mask), dtype=dtype)
-
-            sph_pos = cat["spherical_pos"]
-            arr["r_hMpc"] = sph_pos[mask, 0]
-            arr["RA"] = sph_pos[mask, 1]
-            arr["DEC"] = sph_pos[mask, 2]
-            # TODO: add peculiar velocity
         else:
             raise ValueError(f"Unknown catalogue: `{catalogue}`.")

         return arr

+    def make_jackknife_mask(self, i, n_splits, seed=42):
+        """
+        Set the jackknife mask to exclude the `i`-th split.
+
+        Parameters
+        ----------
+        i : int
+            Index of the split to exclude.
+        n_splits : int
+            Number of splits.
+        seed : int, optional
+            Random seed.
+
+        Returns
+        -------
+        None, sets `mask` internally.
+        """
+        cv = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
+        n = len(self._cat)
+        indxs = np.arange(n)
+
+        gen = np.random.default_rng(seed)
+        gen.shuffle(indxs)
+
+        for j, (train_index, __) in enumerate(cv.split(np.arange(n))):
+            if i == j:
+                self._mask = indxs[train_index]
+                return
+
+        raise ValueError("The index `i` must be in the range of `n_splits`.")
+
+    def reset_mask(self):
+        """Reset the jackknife mask."""
+        self._mask = np.ones(len(self._cat), dtype=bool)
+
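A sketch of how the new jackknife mask is meant to be driven (assumes an already-constructed `DataLoader` named `loader`; the fitting step is elided):

    n_splits = 5
    for i in range(n_splits):
        # Retain all but the i-th KFold split; `cat`, `los_density`,
        # `los_velocity` and `los_radial_velocity` are masked accordingly.
        loader.make_jackknife_mask(i, n_splits, seed=42)
        ...  # fit the flow model on the retained objects
    loader.reset_mask()  # back to the full catalogue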

 ###############################################################################
 # Supplementary flow functions                                               #
@@ -405,6 +433,19 @@ def dist2distmodulus(dist, Omega_m):
     return 5 * jnp.log10(luminosity_distance) + 25


+# def distmodulus2dist(distmodulus, Omega_m):
+#     """
+#     Copied from Supranta. Make sure this actually works.
+#
+#
+#     """
+#     dL = 10 ** ((distmodulus - 25.) / 5.)
+#     r_hMpc = dL
+#     for i in range(4):
+#         r_hMpc = dL / (1.0 + dist2redshift(r_hMpc, Omega_m))
+#     return r_hMpc
+
+
 def project_Vext(Vext_x, Vext_y, Vext_z, RA, dec):
     """
     Project the external velocity onto the line of sight along direction
@@ -459,8 +500,8 @@ def predict_zobs(dist, beta, Vext_radial, vpec_radial, Omega_m):
 # Flow validation models                                                      #
 ###############################################################################


-def calculate_ptilde_wo_bias(xrange, mu, err, r_squared_xrange=None):
+def calculate_ptilde_wo_bias(xrange, mu, err, r_squared_xrange=None,
+                             is_err_squared=False):
     """
     Calculate `ptilde(r)` without any bias.
@@ -475,12 +516,17 @@ def calculate_ptilde_wo_bias(xrange, mu, err, r_squared_xrange=None):
     r_squared_xrange : 1-dimensional array, optional
         Radial distances squared where the field was interpolated for each
         object. If not provided, the `r^2` correction is not applied.
+    is_err_squared : bool, optional
+        Whether the error is already squared.

     Returns
     -------
     1-dimensional array
     """
-    ptilde = jnp.exp(-0.5 * ((xrange - mu) / err)**2)
+    if is_err_squared:
+        ptilde = jnp.exp(-0.5 * (xrange - mu)**2 / err)
+    else:
+        ptilde = jnp.exp(-0.5 * ((xrange - mu) / err)**2)

     if r_squared_xrange is not None:
         ptilde *= r_squared_xrange
@@ -538,15 +584,17 @@ class SD_PV_validation_model:

     def __init__(self, los_density, los_velocity, RA, dec, z_obs,
                  r_hMpc, e_r_hMpc, r_xrange, Omega_m):
-        # Convert everything to JAX arrays.
         dt = jnp.float32
+        # Convert everything to JAX arrays.
         self._los_density = jnp.asarray(los_density, dtype=dt)
         self._los_velocity = jnp.asarray(los_velocity, dtype=dt)

         self._RA = jnp.asarray(np.deg2rad(RA), dtype=dt)
         self._dec = jnp.asarray(np.deg2rad(dec), dtype=dt)
         self._z_obs = jnp.asarray(z_obs, dtype=dt)

         self._r_hMpc = jnp.asarray(r_hMpc, dtype=dt)
-        self._e_rhMpc = jnp.asarray(e_r_hMpc, dtype=dt)
+        self._e2_rhMpc = jnp.asarray(e_r_hMpc**2, dtype=dt)

         # Get radius squared
         r2_xrange = r_xrange**2
@@ -558,43 +606,44 @@ class SD_PV_validation_model:
             raise ValueError("The radial step size must be constant.")
         dr = dr[0]

+        self._r_xrange = r_xrange
+
         # Get the various vmapped functions
-        self._vmap_ptilde_wo_bias = vmap(lambda mu, err: calculate_ptilde_wo_bias(r_xrange, mu, err, r2_xrange))  # noqa
+        self._vmap_ptilde_wo_bias = vmap(lambda mu, err: calculate_ptilde_wo_bias(r_xrange, mu, err, r2_xrange, True))  # noqa
         self._vmap_simps = vmap(lambda y: simps(y, dr))
         self._vmap_zobs = vmap(lambda beta, Vr, vpec_rad: predict_zobs(r_xrange, beta, Vr, vpec_rad, Omega_m), in_axes=(None, 0, 0))  # noqa
         self._vmap_ll_zobs = vmap(lambda zobs, zobs_pred, sigma_v: calculate_ll_zobs(zobs, zobs_pred, sigma_v), in_axes=(0, 0, None))  # noqa

-        # Vext_x, Vext_y, Vext_z: external velocity components
-        self._dist_Vext = dist.Uniform(-1000, 1000)
-        # We want sigma_v to be 150 +- 100 km / s (lognormal)
-        self._dist_sigma_v = dist.LogNormal(
-            *lognorm_mean_std_to_loc_scale(150, 100))
-        # Density power-law bias
-        self._dist_alpha = dist.LogNormal(
-            *lognorm_mean_std_to_loc_scale(1.0, 0.5))
-        # Velocity bias
-        self._dist_beta = dist.Normal(1., 0.5)
+        # Distribution of external velocity components
+        self._Vext = dist.Uniform(-500, 500)
+        # Distribution of density, velocity and location bias parameters
+        self._alpha = dist.LogNormal(*lognorm_mean_std_to_loc_scale(1.0, 0.5))  # noqa
+        self._beta = dist.Normal(1., 0.5)
+        # Distribution of velocity uncertainty sigma_v
+        self._sv = dist.LogNormal(*lognorm_mean_std_to_loc_scale(150, 100))

-    def __call__(self):
+    def __call__(self, sample_alpha=False):
         """
-        The simple distance NumPyro PV validation model. Samples the following
-        parameters:
-            - `Vext_x`, `Vext_y`, `Vext_z`: external velocity components
-            - `alpha`: density bias parameter
-            - `beta`: velocity bias parameter
-            - `sigma_v`: velocity uncertainty
+        The simple distance NumPyro PV validation model.
+
+        Parameters
+        ----------
+        sample_alpha : bool, optional
+            Whether to sample the density bias parameter `alpha`, otherwise
+            it is fixed to 1.
         """
-        Vx = numpyro.sample("Vext_x", self._dist_Vext)
-        Vy = numpyro.sample("Vext_y", self._dist_Vext)
-        Vz = numpyro.sample("Vext_z", self._dist_Vext)
-        alpha = numpyro.sample("alpha", self._dist_alpha)
-        beta = numpyro.sample("beta", self._dist_beta)
-        sigma_v = numpyro.sample("sigma_v", self._dist_sigma_v)
+        Vx = numpyro.sample("Vext_x", self._Vext)
+        Vy = numpyro.sample("Vext_y", self._Vext)
+        Vz = numpyro.sample("Vext_z", self._Vext)
+
+        alpha = numpyro.sample("alpha", self._alpha) if sample_alpha else 1.0
+        beta = numpyro.sample("beta", self._beta)
+        sigma_v = numpyro.sample("sigma_v", self._sv)

         Vext_rad = project_Vext(Vx, Vy, Vz, self._RA, self._dec)

         # Calculate p(r) and multiply it by the galaxy bias
-        ptilde = self._vmap_ptilde_wo_bias(self._r_hMpc, self._e_rhMpc)
+        ptilde = self._vmap_ptilde_wo_bias(self._r_hMpc, self._e2_rhMpc)
         ptilde *= self._los_density**alpha

         # Normalization of p(r)
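The priors above rely on a helper `lognorm_mean_std_to_loc_scale` that is not part of this diff. A standard moment-matching version (an assumption about its behaviour, not necessarily the repository's exact code) maps a target mean and standard deviation of a lognormal variable to the `loc` and `scale` of the underlying normal:

    import numpy as np

    def lognorm_mean_std_to_loc_scale(mean, std):
        # sigma^2 of the underlying normal, from the target mean and std ...
        var = np.log(1 + std**2 / mean**2)
        # ... and the matching location parameter.
        loc = np.log(mean) - var / 2
        return loc, np.sqrt(var)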
@@ -608,50 +657,507 @@ class SD_PV_validation_model:
         numpyro.factor("ll", ll)


-# def SN_PV_wcal_validation_model(los_overdensity=None, los_velocity=None,
-#                                 RA=None, dec=None, z_CMB=None,
-#                                 mB=None, x1=None, c=None,
-#                                 e_mB=None, e_x1=None, e_c=None,
-#                                 mu_xrange=None, r_xrange=None,
-#                                 norm_r2_xrange=None, Omega_m=None, dr=None):
-#     """
-#     Pass
-#     """
-#     Vx = numpyro.sample("Vext_x", dist.Uniform(-1000, 1000))
-#     Vy = numpyro.sample("Vext_y", dist.Uniform(-1000, 1000))
-#     Vz = numpyro.sample("Vext_z", dist.Uniform(-1000, 1000))
-#     beta = numpyro.sample("beta", dist.Uniform(-10, 10))
-#
-#     # TODO: Later sample these as well.
-#     e_mu_intrinsic = 0.064
-#     alpha_cal = 0.135
-#     beta_cal = 2.9
-#     mag_cal = -18.555
-#     sigma_v = 112
-#
-#     # TODO: Check these for fiducial values.
-#     mu = mB - mag_cal + alpha_cal * x1 - beta_cal * c
-#     squared_e_mu = e_mB**2 + alpha_cal**2 * e_x1**2 + beta_cal**2 * e_c**2
-#
-#     squared_e_mu += e_mu_intrinsic**2
-#     ll = 0.
-#     for i in range(len(los_overdensity)):
-#         # Project the external velocity for this galaxy.
-#         Vext_rad = project_Vext(Vx, Vy, Vz, RA[i], dec[i])
-#
-#         dmu = mu_xrange - mu[i]
-#         ptilde = norm_r2_xrange * jnp.exp(-0.5 * dmu**2 / squared_e_mu[i])
-#         # TODO: Add some bias
-#         ptilde *= (1 + los_overdensity[i])
-#
-#         zobs_pred = predict_zobs(r_xrange, beta, Vext_rad, los_velocity[i],
-#                                  Omega_m)
-#
-#         dczobs = SPEED_OF_LIGHT * (z_CMB[i] - zobs_pred)
-#
-#         ll_zobs = jnp.exp(-0.5 * (dczobs / sigma_v)**2) / sigma_v
-#
-#         ll += jnp.log(simps(ptilde * ll_zobs, dr))
-#         ll -= jnp.log(simps(ptilde, dr))
-#
-#         numpyro.factor("ll", ll)
+class SN_PV_validation_model:
+    """
+    Supernova peculiar velocity (PV) validation model that includes the
+    calibration of the SALT2 light curve parameters.
+
+    Parameters
+    ----------
+    los_density : 2-dimensional array of shape (n_objects, n_steps)
+        LOS density field.
+    los_velocity : 2-dimensional array of shape (n_objects, n_steps)
+        LOS radial velocity field.
+    RA, dec : 1-dimensional arrays of shape (n_objects)
+        Right ascension and declination in degrees.
+    z_obs : 1-dimensional array of shape (n_objects)
+        Observed redshifts.
+    mB, x1, c : 1-dimensional arrays of shape (n_objects)
+        SALT2 light curve parameters.
+    e_mB, e_x1, e_c : 1-dimensional arrays of shape (n_objects)
+        Errors on the SALT2 light curve parameters.
+    r_xrange : 1-dimensional array
+        Radial distances where the field was interpolated for each object.
+    Omega_m : float
+        Matter density parameter.
+    """
+
+    def __init__(self, los_density, los_velocity, RA, dec, z_obs,
+                 mB, x1, c, e_mB, e_x1, e_c, r_xrange, Omega_m):
+        dt = jnp.float32
+        # Convert everything to JAX arrays.
+        self._los_density = jnp.asarray(los_density, dtype=dt)
+        self._los_velocity = jnp.asarray(los_velocity, dtype=dt)
+
+        self._RA = jnp.asarray(np.deg2rad(RA), dtype=dt)
+        self._dec = jnp.asarray(np.deg2rad(dec), dtype=dt)
+        self._z_obs = jnp.asarray(z_obs, dtype=dt)
+
+        self._mB = jnp.asarray(mB, dtype=dt)
+        self._x1 = jnp.asarray(x1, dtype=dt)
+        self._c = jnp.asarray(c, dtype=dt)
+        self._e2_mB = jnp.asarray(e_mB**2, dtype=dt)
+        self._e2_x1 = jnp.asarray(e_x1**2, dtype=dt)
+        self._e2_c = jnp.asarray(e_c**2, dtype=dt)
+
+        # Get radius squared
+        r2_xrange = r_xrange**2
+        r2_xrange /= r2_xrange.mean()
+        mu_xrange = dist2distmodulus(r_xrange, Omega_m)
+
+        # Get the stepsize, we need it to be constant for Simpson's rule.
+        dr = np.diff(r_xrange)
+        if not np.all(np.isclose(dr, dr[0], atol=1e-5)):
+            raise ValueError("The radial step size must be constant.")
+        dr = dr[0]
+
+        # Get the various functions, to be looped over with `scan` below.
+        self._f_ptilde_wo_bias = lambda mu, err: calculate_ptilde_wo_bias(mu_xrange, mu, err, r2_xrange, True)  # noqa
+        self._f_simps = lambda y: simps(y, dr)  # noqa
+        self._f_zobs = lambda beta, Vr, vpec_rad: predict_zobs(r_xrange, beta, Vr, vpec_rad, Omega_m)  # noqa
+
+        # Distribution of external velocity components
+        self._Vext = dist.Uniform(-500, 500)
+        # Distribution of velocity and density bias parameters
+        self._alpha = dist.LogNormal(*lognorm_mean_std_to_loc_scale(1.0, 0.5))
+        self._beta = dist.Normal(1., 0.5)
+        # Distribution of velocity uncertainty
+        self._sigma_v = dist.LogNormal(*lognorm_mean_std_to_loc_scale(150, 100))  # noqa
+
+        # Distribution of light curve calibration parameters
+        self._mag_cal = dist.Normal(-18.25, 0.5)
+        self._alpha_cal = dist.Normal(0.148, 0.05)
+        self._beta_cal = dist.Normal(3.112, 1.0)
+        self._e_mu = dist.LogNormal(*lognorm_mean_std_to_loc_scale(0.1, 0.05))
+
+    def __call__(self, sample_alpha=True, fix_calibration=False):
+        """
+        The supernova NumPyro PV validation model with SALT2 calibration.
+
+        Parameters
+        ----------
+        sample_alpha : bool, optional
+            Whether to sample the density bias parameter `alpha`, otherwise
+            it is fixed to 1.
+        fix_calibration : str, optional
+            Whether to fix the calibration parameters. If not provided, they
+            are sampled. If "Foundation" or "LOSS" is provided, the
+            parameters are fixed to the best inverse parameters for the
+            Foundation or LOSS catalogues.
+        """
+        Vx = numpyro.sample("Vext_x", self._Vext)
+        Vy = numpyro.sample("Vext_y", self._Vext)
+        Vz = numpyro.sample("Vext_z", self._Vext)
+        alpha = numpyro.sample("alpha", self._alpha) if sample_alpha else 1.0
+        beta = numpyro.sample("beta", self._beta)
+        sigma_v = numpyro.sample("sigma_v", self._sigma_v)
+
+        if fix_calibration == "Foundation":
+            # Foundation inverse best parameters
+            e_mu_intrinsic = 0.064
+            alpha_cal = 0.135
+            beta_cal = 2.9
+            sigma_v = 149
+            mag_cal = -18.555
+        elif fix_calibration == "LOSS":
+            # LOSS inverse best parameters
+            e_mu_intrinsic = 0.123
+            alpha_cal = 0.123
+            beta_cal = 3.52
+            mag_cal = -18.195
+            sigma_v = 149
+        else:
+            e_mu_intrinsic = numpyro.sample("e_mu_intrinsic", self._e_mu)
+            mag_cal = numpyro.sample("mag_cal", self._mag_cal)
+            alpha_cal = numpyro.sample("alpha_cal", self._alpha_cal)
+            beta_cal = numpyro.sample("beta_cal", self._beta_cal)
+
+        Vext_rad = project_Vext(Vx, Vy, Vz, self._RA, self._dec)
+
+        mu = self._mB - mag_cal + alpha_cal * self._x1 - beta_cal * self._c
+        squared_e_mu = (self._e2_mB + alpha_cal**2 * self._e2_x1
+                        + beta_cal**2 * self._e2_c + e_mu_intrinsic**2)
+
+        def scan_body(ll, i):
+            # Calculate p(r) and multiply it by the galaxy bias
+            ptilde = self._f_ptilde_wo_bias(mu[i], squared_e_mu[i])
+            ptilde *= self._los_density[i]**alpha
+
+            # Normalization of p(r)
+            pnorm = self._f_simps(ptilde)
+
+            # Calculate p(z_obs) and multiply it by p(r)
+            zobs_pred = self._f_zobs(beta, Vext_rad[i], self._los_velocity[i])
+            ptilde *= calculate_ll_zobs(self._z_obs[i], zobs_pred, sigma_v)
+
+            return ll + jnp.log(self._f_simps(ptilde) / pnorm), None
+
+        ll = 0.
+        ll, __ = scan(scan_body, ll, jnp.arange(len(self._RA)))
+        numpyro.factor("ll", ll)
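The per-object Python loop of the old commented-out model is replaced here by `jax.lax.scan`, which carries the running log-likelihood through a single compiled loop instead of unrolling one trace per object. A minimal standalone sketch of the same pattern:

    from jax import numpy as jnp
    from jax.lax import scan

    values = jnp.array([0.1, 0.2, 0.3])

    def body(carry, i):
        # carry is the running sum of logs; the second output (per-step
        # results) is unused, hence None.
        return carry + jnp.log(values[i]), None

    total, _ = scan(body, 0., jnp.arange(len(values)))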
+
+
+class TF_PV_validation_model:
+    """
+    Tully-Fisher peculiar velocity (PV) validation model that includes the
+    calibration of the Tully-Fisher distance `mu = m - (a + b * eta)`.
+
+    Parameters
+    ----------
+    los_density : 2-dimensional array of shape (n_objects, n_steps)
+        LOS density field.
+    los_velocity : 2-dimensional array of shape (n_objects, n_steps)
+        LOS radial velocity field.
+    RA, dec : 1-dimensional arrays of shape (n_objects)
+        Right ascension and declination in degrees.
+    z_obs : 1-dimensional array of shape (n_objects)
+        Observed redshifts.
+    mag, eta : 1-dimensional arrays of shape (n_objects)
+        Apparent magnitude and `eta` parameter.
+    e_mag, e_eta : 1-dimensional arrays of shape (n_objects)
+        Errors on the apparent magnitude and `eta` parameter.
+    r_xrange : 1-dimensional array
+        Radial distances where the field was interpolated for each object.
+    Omega_m : float
+        Matter density parameter.
+    """
+
+    def __init__(self, los_density, los_velocity, RA, dec, z_obs,
+                 mag, eta, e_mag, e_eta, r_xrange, Omega_m):
+        dt = jnp.float32
+        # Convert everything to JAX arrays.
+        self._los_density = jnp.asarray(los_density, dtype=dt)
+        self._los_velocity = jnp.asarray(los_velocity, dtype=dt)
+
+        self._RA = jnp.asarray(np.deg2rad(RA), dtype=dt)
+        self._dec = jnp.asarray(np.deg2rad(dec), dtype=dt)
+        self._z_obs = jnp.asarray(z_obs, dtype=dt)
+
+        self._mag = jnp.asarray(mag, dtype=dt)
+        self._eta = jnp.asarray(eta, dtype=dt)
+        self._e2_mag = jnp.asarray(e_mag**2, dtype=dt)
+        self._e2_eta = jnp.asarray(e_eta**2, dtype=dt)
+
+        # Get radius squared
+        r2_xrange = r_xrange**2
+        r2_xrange /= r2_xrange.mean()
+        mu_xrange = dist2distmodulus(r_xrange, Omega_m)
+
+        # Get the stepsize, we need it to be constant for Simpson's rule.
+        dr = np.diff(r_xrange)
+        if not np.all(np.isclose(dr, dr[0], atol=1e-5)):
+            raise ValueError("The radial step size must be constant.")
+        dr = dr[0]
+
+        # Get the various functions, to be looped over with `scan` below.
+        self._f_ptilde_wo_bias = lambda mu, err: calculate_ptilde_wo_bias(mu_xrange, mu, err, r2_xrange, True)  # noqa
+        self._f_simps = lambda y: simps(y, dr)  # noqa
+        self._f_zobs = lambda beta, Vr, vpec_rad: predict_zobs(r_xrange, beta, Vr, vpec_rad, Omega_m)  # noqa
+
+        # Distribution of external velocity components
+        self._Vext = dist.Uniform(-1000, 1000)
+        # Distribution of velocity and density bias parameters
+        self._alpha = dist.LogNormal(*lognorm_mean_std_to_loc_scale(1.0, 0.5))  # noqa
+        self._beta = dist.Normal(1., 0.5)
+        # Distribution of velocity uncertainty
+        self._sigma_v = dist.LogNormal(*lognorm_mean_std_to_loc_scale(150, 100))  # noqa
+
+        # Distribution of Tully-Fisher calibration parameters
+        self._a = dist.Normal(-21., 0.5)
+        self._b = dist.Normal(-5.95, 0.1)
+        self._e_mu = dist.LogNormal(*lognorm_mean_std_to_loc_scale(0.3, 0.1))  # noqa
+
+    def __call__(self, sample_alpha=True):
+        """
+        The Tully-Fisher NumPyro PV validation model.
+
+        Parameters
+        ----------
+        sample_alpha : bool, optional
+            Whether to sample the density bias parameter `alpha`, otherwise
+            it is fixed to 1.
+        """
+        Vx = numpyro.sample("Vext_x", self._Vext)
+        Vy = numpyro.sample("Vext_y", self._Vext)
+        Vz = numpyro.sample("Vext_z", self._Vext)
+        alpha = numpyro.sample("alpha", self._alpha) if sample_alpha else 1.0
+        beta = numpyro.sample("beta", self._beta)
+        sigma_v = numpyro.sample("sigma_v", self._sigma_v)
+
+        e_mu_intrinsic = numpyro.sample("e_mu_intrinsic", self._e_mu)
+        a = numpyro.sample("a", self._a)
+        b = numpyro.sample("b", self._b)
+
+        Vext_rad = project_Vext(Vx, Vy, Vz, self._RA, self._dec)
+
+        mu = self._mag - (a + b * self._eta)
+        squared_e_mu = (self._e2_mag + b**2 * self._e2_eta
+                        + e_mu_intrinsic**2)
+
+        def scan_body(ll, i):
+            # Calculate p(r) and multiply it by the galaxy bias
+            ptilde = self._f_ptilde_wo_bias(mu[i], squared_e_mu[i])
+            ptilde *= self._los_density[i]**alpha
+
+            # Normalization of p(r)
+            pnorm = self._f_simps(ptilde)
+
+            # Calculate p(z_obs) and multiply it by p(r)
+            zobs_pred = self._f_zobs(beta, Vext_rad[i], self._los_velocity[i])
+            ptilde *= calculate_ll_zobs(self._z_obs[i], zobs_pred, sigma_v)
+
+            return ll + jnp.log(self._f_simps(ptilde) / pnorm), None
+
+        ll = 0.
+        ll, __ = scan(scan_body, ll, jnp.arange(len(self._RA)))
+
+        numpyro.factor("ll", ll)
+
+
+###############################################################################
+#                        Shortcut to create a model                          #
+###############################################################################
+
+
+def get_model(loader, k, zcmb_max=None, verbose=True):
+    """
+    Get a model and extract the relevant data from the loader.
+
+    Parameters
+    ----------
+    loader : DataLoader
+        DataLoader instance.
+    k : int
+        Simulation index.
+    zcmb_max : float, optional
+        Maximum observed redshift in the CMB frame to include.
+    verbose : bool, optional
+        Verbosity flag.
+
+    Returns
+    -------
+    model : NumPyro model
+    """
+    zcmb_max = np.infty if zcmb_max is None else zcmb_max
+
+    if k > loader.los_density.shape[1]:
+        raise ValueError(f"Simulation index `{k}` out of range.")
+
+    los_overdensity = loader.los_density[:, k, :]
+    los_velocity = loader.los_radial_velocity[:, k, :]
+    kind = loader._catname
+
+    if kind in ["LOSS", "Foundation"]:
+        keys = ["RA", "DEC", "z_CMB", "mB", "x1", "c", "e_mB", "e_x1", "e_c"]
+        RA, dec, zCMB, mB, x1, c, e_mB, e_x1, e_c = (loader.cat[k] for k in keys)  # noqa
+
+        mask = (zCMB < zcmb_max)
+        model = SN_PV_validation_model(
+            los_overdensity[mask], los_velocity[mask], RA[mask], dec[mask],
+            zCMB[mask], mB[mask], x1[mask], c[mask], e_mB[mask], e_x1[mask],
+            e_c[mask], loader.rdist, loader._Omega_m)
+    elif kind == "Pantheon+":
+        keys = ["RA", "DEC", "zCMB", "mB", "x1", "c", "biasCor_m_b", "mBERR",
+                "x1ERR", "cERR", "biasCorErr_m_b"]
+
+        RA, dec, zCMB, mB, x1, c, bias_corr_mB, e_mB, e_x1, e_c, e_bias_corr_mB = (loader.cat[k] for k in keys)  # noqa
+        mB -= bias_corr_mB
+        e_mB = np.sqrt(e_mB**2 + e_bias_corr_mB**2)
+
+        mask = (zCMB < zcmb_max)
+        model = SN_PV_validation_model(
+            los_overdensity[mask], los_velocity[mask], RA[mask], dec[mask],
+            zCMB[mask], mB[mask], x1[mask], c[mask], e_mB[mask], e_x1[mask],
+            e_c[mask], loader.rdist, loader._Omega_m)
+    elif kind in ["SFI_gals", "2MTF"]:
+        keys = ["RA", "DEC", "z_CMB", "mag", "eta", "e_mag", "e_eta"]
+        RA, dec, zCMB, mag, eta, e_mag, e_eta = (loader.cat[k] for k in keys)
+
+        mask = (zCMB < zcmb_max)
+        if kind == "SFI_gals":
+            mask &= (eta > -0.15) & (eta < 0.2)
+            if verbose:
+                print("Employed eta cut for SFI galaxies.", flush=True)
+        model = TF_PV_validation_model(
+            los_overdensity[mask], los_velocity[mask], RA[mask], dec[mask],
+            zCMB[mask], mag[mask], eta[mask], e_mag[mask], e_eta[mask],
+            loader.rdist, loader._Omega_m)
+    else:
+        raise ValueError(f"Catalogue `{kind}` not recognized.")
+
+    if verbose:
+        print(f"Selected {np.sum(mask)}/{len(mask)} galaxies.", flush=True)
+
+    return model
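Illustrative call of the new shortcut (the `DataLoader` construction is elided, since its constructor signature is not part of this diff):

    loader = ...  # DataLoader for, e.g., the "Foundation" catalogue
    model = get_model(loader, k=0, zcmb_max=0.05)  # keep objects with z_CMB < 0.05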
+
+
+###############################################################################
+#                  Maximizing likelihood of a NumPyro model                  #
+###############################################################################
+
+
+def sample_prior(model, seed, sample_alpha, as_dict=False):
+    """
+    Sample a single set of parameters from the prior of the model.
+
+    Parameters
+    ----------
+    model : NumPyro model
+        NumPyro model.
+    seed : int
+        Random seed.
+    sample_alpha : bool
+        Whether to sample the density bias parameter `alpha`.
+    as_dict : bool, optional
+        Whether to return the parameters as a dictionary or a list of
+        parameters.
+
+    Returns
+    -------
+    x, keys : tuple
+        Tuple of parameters and their names. If `as_dict` is True, returns
+        only a dictionary.
+    """
+    predictive = Predictive(model, num_samples=1)
+    samples = predictive(PRNGKey(seed), sample_alpha=sample_alpha)
+
+    if as_dict:
+        return samples
+
+    keys = list(samples.keys())
+    if "ll" in keys:
+        keys.remove("ll")
+
+    x = np.asarray([samples[key][0] for key in keys])
+    return x, keys
+
+
+def make_loss(model, keys, sample_alpha=True, to_jit=True):
+    """
+    Generate a loss function for the NumPyro model, that is the negative
+    log-likelihood. Note that this loss function cannot be automatically
+    differentiated.
+
+    Parameters
+    ----------
+    model : NumPyro model
+        NumPyro model.
+    keys : list
+        List of parameter names.
+    sample_alpha : bool, optional
+        Whether to sample the density bias parameter `alpha`.
+    to_jit : bool, optional
+        Whether to JIT the loss function.
+
+    Returns
+    -------
+    loss : function
+        Loss function `f(x)` where `x` is a list of parameters ordered
+        according to `keys`.
+    """
+    def f(x):
+        samples = {key: x[i] for i, key in enumerate(keys)}
+
+        loss = -util.log_likelihood(
+            model, samples, sample_alpha=sample_alpha)["ll"]
+
+        loss += cond(samples["sigma_v"] > 0, lambda: 0., lambda: jnp.inf)
+        loss += cond(samples["e_mu_intrinsic"] > 0, lambda: 0., lambda: jnp.inf)  # noqa
+
+        return cond(jnp.isfinite(loss), lambda: loss, lambda: jnp.inf)
+
+    if to_jit:
+        return jit(f)
+
+    return f
+
+
+def optimize_model_with_jackknife(loader, k, n_splits=5, sample_alpha=True,
+                                  get_model_kwargs={}, seed=42):
+    """
+    Optimize the log-likelihood of a model for `n_splits` jackknifes.
+
+    Parameters
+    ----------
+    loader : DataLoader
+        DataLoader instance.
+    k : int
+        Simulation index.
+    n_splits : int, optional
+        Number of jackknife splits.
+    sample_alpha : bool, optional
+        Whether to sample the density bias parameter `alpha`.
+    get_model_kwargs : dict, optional
+        Additional keyword arguments to pass to the `get_model` function.
+    seed : int, optional
+        Random seed.
+
+    Returns
+    -------
+    samples : dict
+        Dictionary of optimized parameters for each jackknife split.
+    stats : dict
+        Dictionary of mean and standard deviation for each parameter.
+    fmin : 1-dimensional array
+        Minimum negative log-likelihood for each jackknife split.
+    logz : 1-dimensional array
+        Log-evidence for each jackknife split.
+    bic : 1-dimensional array
+        Bayesian information criterion for each jackknife split.
+    """
+    mask = np.zeros(n_splits, dtype=bool)
+    x0 = None
+
+    # Loop over the CV splits.
+    for i in trange(n_splits):
+        loader.make_jackknife_mask(i, n_splits, seed=seed)
+        model = get_model(loader, k, verbose=False, **get_model_kwargs)
+
+        if x0 is None:
+            x0, keys = sample_prior(model, seed, sample_alpha)
+            x = np.full((n_splits, len(x0)), np.nan)
+            fmin = np.full(n_splits, np.nan)
+            logz = np.full(n_splits, np.nan)
+            bic = np.full(n_splits, np.nan)
+
+            loss = make_loss(model, keys, sample_alpha=sample_alpha,
+                             to_jit=True)
+            for j in range(100):
+                if np.isfinite(loss(x0)):
+                    break
+                x0, __ = sample_prior(model, seed + 1, sample_alpha)
+            else:
+                raise ValueError("Failed to find finite initial loss.")
+        else:
+            loss = make_loss(model, keys, sample_alpha=sample_alpha,
+                             to_jit=True)
+
+        with catch_warnings():
+            simplefilter("ignore")
+            res = fmin_powell(loss, x0, disp=False)
+
+        if np.all(np.isfinite(res)):
+            x[i] = res
+            mask[i] = True
+            x0 = res
+            fmin[i] = loss(res)
+
+            f_hess = Hessian(loss, method="forward", richardson_terms=1)
+            hess = f_hess(res)
+            D = len(keys)
+            logz[i] = (
+                - fmin[i]
+                + 0.5 * np.log(np.abs(np.linalg.det(np.linalg.inv(hess))))
+                + D / 2 * np.log(2 * np.pi))
+
+            bic[i] = len(keys) * np.log(len(loader.cat["RA"])) + 2 * fmin[i]
+
+    samples = {key: x[:, i][mask] for i, key in enumerate(keys)}
+
+    mean = [np.mean(samples[key]) for key in keys]
+    # Jackknife standard deviation: (n - 1) times the biased variance.
+    std = [((len(samples[key]) - 1) * np.var(samples[key], ddof=0))**0.5
+           for key in keys]
+    stats = {key: (mean[i], std[i]) for i, key in enumerate(keys)}
+
+    return samples, stats, fmin, logz, bic
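For reference, the per-split log-evidence computed above is the Laplace approximation around the optimum, and the BIC follows its textbook definition. With L_min the minimum negative log-likelihood, H the Hessian of the loss at the optimum, D the number of free parameters and n the number of objects:

    log Z ~= -L_min + 0.5 * log|det H^{-1}| + (D / 2) * log(2 * pi)
    BIC = D * ln(n) + 2 * L_min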
@@ -63,6 +63,8 @@ def simname2Omega_m(simname):
     Omega_m: float
     """
     d = {"csiborg1": 0.307,
+         "csiborg2_main": 0.3111,
+         "csiborg2_random": 0.3111,
          "borg1": 0.307,
          "Carrick2015": 0.3,
          }
notebooks/field_velocity_fof_sph.ipynb (new file, 223 lines added)
File diff suppressed because one or more lines are too long.
@@ -58,7 +58,8 @@ def get_los(catalogue_name, simname, comm):
     if comm.Get_rank() == 0:
         folder = "/mnt/extraspace/rstiskalek/catalogs"

-        if catalogue_name == "LOSS" or catalogue_name == "Foundation":
+        if catalogue_name in ["LOSS", "Foundation", "SFI_gals", "2MTF",
+                              "Pantheon+"]:
             fpath = join(folder, "PV_compilation_Supranta2019.hdf5")
             with File(fpath, 'r') as f:
                 grp = f[catalogue_name]
@@ -69,18 +70,6 @@ def get_los(catalogue_name, simname, comm):
             with File(fpath, 'r') as f:
                 RA = f["RA"][:]
                 dec = f["DEC"][:]
-        elif "csiborg1" in catalogue_name:
-            nsim = int(catalogue_name.split("_")[-1])
-            cat = csiborgtools.read.CSiBORG1Catalogue(
-                nsim, bounds={"totmass": (1e13, None)})
-
-            seed = 42
-            gen = np.random.default_rng(seed)
-            mask = gen.choice(len(cat), size=100, replace=False)
-
-            sph_pos = cat["spherical_pos"]
-            RA = sph_pos[mask, 1]
-            dec = sph_pos[mask, 2]
         else:
             raise ValueError(f"Unknown field name: `{catalogue_name}`.")
@@ -122,6 +111,9 @@ def get_field(simname, nsim, kind, MAS, grid):
     # Open the field reader.
     if simname == "csiborg1":
         field_reader = csiborgtools.read.CSiBORG1Field(nsim)
+    elif "csiborg2" in simname:
+        simkind = simname.split("_")[-1]
+        field_reader = csiborgtools.read.CSiBORG2Field(nsim, simkind)
     elif simname == "Carrick2015":
         folder = "/mnt/extraspace/rstiskalek/catalogs"
         warn(f"Using local paths from `{folder}`.", RuntimeWarning)
@@ -130,7 +122,20 @@ def get_field(simname, nsim, kind, MAS, grid):
             return np.load(fpath).astype(np.float32)
         elif kind == "velocity":
             fpath = join(folder, "twompp_velocity_carrick2015.npy")
-            return np.load(fpath).astype(np.float32)
+            field = np.load(fpath).astype(np.float32)
+
+            # Because the Carrick+2015 data is in the following form:
+            # "The velocities are predicted peculiar velocities in the CMB
+            # frame in Galactic Cartesian coordinates, generated from the
+            # \(\delta_g^*\) field with \(\beta^* = 0.43\) and an external
+            # dipole \(V_\mathrm{ext} = [89, -131, 17]\) (Carrick et al
+            # Table 3) has already been added."
+            field[0] -= 89
+            field[1] -= -131
+            field[2] -= 17
+            field /= 0.43
+
+            return field
         else:
             raise ValueError(f"Unknown field kind: `{kind}`.")
     else:
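The subtraction above undoes the beta* scaling and the external dipole that are baked into the distributed Carrick+2015 velocity field, so that the flow model can refit both. An equivalent vectorised form (illustrative only):

    import numpy as np

    field = np.load("twompp_velocity_carrick2015.npy").astype(np.float32)

    V_ext = np.array([89.0, -131.0, 17.0], dtype=np.float32)  # km/s, Carrick+2015 Table 3
    beta_star = 0.43

    # field has shape (3, ngrid, ngrid, ngrid); broadcast over the grid axes.
    field = (field - V_ext[:, None, None, None]) / beta_star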
@@ -274,7 +279,7 @@ if __name__ == "__main__":

     rmax = 200
     dr = 0.5
-    smooth_scales = [0, 2, 4, 6]
+    smooth_scales = [0, 2]

     comm = MPI.COMM_WORLD
     paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
@@ -1,14 +1,13 @@
-nthreads=11
-memory=64
+nthreads=4
+memory=32
 on_login=${1}
 queue="berg"
 env="/mnt/users/rstiskalek/csiborgtools/venv_csiborg/bin/python"
 file="field_los.py"

-catalogue="A2"
-# catalogue="csiborg1_9844"
+catalogue=${2}
 nsims="-1"
-simname="csiborg1"
+simname="csiborg2_main"
 MAS="SPH"
 grid=1024
@@ -26,11 +26,11 @@ import jax
 import numpy as np
 from h5py import File
 from mpi4py import MPI
-from numpyro.infer import MCMC, NUTS
+from numpyro.infer import MCMC, NUTS, init_to_sample
 from taskmaster import work_delegation  # noqa


-def get_model(args, nsim):
+def get_model(args, nsim_iterator):
     """
     Load the data and create the NumPyro model.
@@ -38,8 +38,8 @@ def get_model(args, nsim):
     ----------
     args : argparse.Namespace
         Command line arguments.
-    nsim : int
-        Simulation index.
+    nsim_iterator : int
+        Simulation index, not the IC index. Ranges from 0, ... .

     Returns
     -------
@@ -49,7 +49,7 @@ def get_model(args, nsim):
     if args.catalogue == "A2":
         fpath = join(folder, "A2.h5")
     elif args.catalogue == "LOSS" or args.catalogue == "Foundation":
-        raise NotImplementedError("To be implemented..")
+        fpath = join(folder, "PV_compilation_Supranta2019.hdf5")
     else:
         raise ValueError(f"Unknown catalogue: `{args.catalogue}`.")
@@ -58,19 +58,51 @@ def get_model(args, nsim):
     Omega_m = csiborgtools.simname2Omega_m(args.simname)

     # Read in the data from the loader.
-    los_overdensity = loader.los_density[:, nsim, :]
-    los_velocity = loader.los_radial_velocity[:, nsim, :]
+    los_overdensity = loader.los_density[:, nsim_iterator, :]
+    los_velocity = loader.los_radial_velocity[:, nsim_iterator, :]

-    RA = loader.cat["RA"]
-    dec = loader.cat["DEC"]
-    z_obs = loader.cat["z_obs"]
+    if args.catalogue == "A2":
+        RA = loader.cat["RA"]
+        dec = loader.cat["DEC"]
+        z_obs = loader.cat["z_obs"]

-    r_hMpc = loader.cat["r_hMpc"]
-    e_r_hMpc = loader.cat["e_rhMpc"]
+        r_hMpc = loader.cat["r_hMpc"]
+        e_r_hMpc = loader.cat["e_rhMpc"]

-    return csiborgtools.flow.SD_PV_validation_model(
-        los_overdensity, los_velocity, RA, dec, z_obs, r_hMpc, e_r_hMpc,
-        loader.rdist, Omega_m)
+        return csiborgtools.flow.SD_PV_validation_model(
+            los_overdensity, los_velocity, RA, dec, z_obs, r_hMpc, e_r_hMpc,
+            loader.rdist, Omega_m)
+    elif args.catalogue == "LOSS" or args.catalogue == "Foundation":
+        RA = loader.cat["RA"]
+        dec = loader.cat["DEC"]
+        zCMB = loader.cat["z_CMB"]
+
+        mB = loader.cat["mB"]
+        x1 = loader.cat["x1"]
+        c = loader.cat["c"]
+
+        e_mB = loader.cat["e_mB"]
+        e_x1 = loader.cat["e_x1"]
+        e_c = loader.cat["e_c"]
+
+        return csiborgtools.flow.SN_PV_validation_model(
+            los_overdensity, los_velocity, RA, dec, zCMB, mB, x1, c,
+            e_mB, e_x1, e_c, loader.rdist, Omega_m)
+    elif args.catalogue in ["SFI_gals", "2MTF"]:
+        RA = loader.cat["RA"]
+        dec = loader.cat["DEC"]
+        zCMB = loader.cat["z_CMB"]
+
+        mag = loader.cat["mag"]
+        eta = loader.cat["eta"]
+        e_mag = loader.cat["e_mag"]
+        e_eta = loader.cat["e_eta"]
+
+        return csiborgtools.flow.TF_PV_validation_model(
+            los_overdensity, los_velocity, RA, dec, zCMB, mag, eta,
+            e_mag, e_eta, loader.rdist, Omega_m)
+    else:
+        raise ValueError(f"Unknown catalogue: `{args.catalogue}`.")


 def run_model(model, nsteps, nchains, nsim, dump_folder, show_progress=True):
@@ -96,8 +128,8 @@ def run_model(model, nsteps, nchains, nsim, dump_folder, show_progress=True):
     -------
     None
     """
-    nuts_kernel = NUTS(model)
-    mcmc = MCMC(nuts_kernel, num_warmup=nsteps // 2, num_samples=nsteps // 2,
+    nuts_kernel = NUTS(model, init_strategy=init_to_sample)
+    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=nsteps,
                 chain_method="sequential", num_chains=nchains,
                 progress_bar=show_progress)
     rng_key = jax.random.PRNGKey(42)
@@ -185,8 +217,8 @@ if __name__ == "__main__":
     paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
     nsims = paths.get_ics(args.simname)

-    nsteps = 5000
-    nchains = 1
+    nsteps = 2000
+    nchains = 2

     # Create the dumping folder.
     if comm.Get_rank() == 0:
@@ -198,12 +230,13 @@ if __name__ == "__main__":
         dump_folder = None
     dump_folder = comm.bcast(dump_folder, root=0)

-    def main(nsim):
-        model = get_model(args, nsim)
-        run_model(model, nsteps, nchains, nsim, dump_folder,
+    def main(i):
+        model = get_model(args, i)
+        run_model(model, nsteps, nchains, nsims[i], dump_folder,
                   show_progress=size == 1)

-    work_delegation(main, nsims, comm, master_verbose=True)
+    work_delegation(main, [i for i in range(len(nsims))], comm,
+                    master_verbose=True)
     comm.Barrier()

     if rank == 0:
@@ -1,13 +1,14 @@
 memory=4
 on_login=${1}
 nthreads=${2}
+ksmooth=${3}

 queue="berg"
 env="/mnt/users/rstiskalek/csiborgtools/venv_csiborg/bin/python"
 file="flow_validation.py"

-catalogue="A2"
-simname="Carrick2015"
-ksmooth=2
+catalogue="Foundation"
+simname="csiborg2_random"


 pythoncm="$env $file --catalogue $catalogue --simname $simname --ksmooth $ksmooth"