diff --git a/csiborgtools/flow/__init__.py b/csiborgtools/flow/__init__.py
index 3c33d2d..0ee633b 100644
--- a/csiborgtools/flow/__init__.py
+++ b/csiborgtools/flow/__init__.py
@@ -13,7 +13,7 @@
 # with this program; if not, write to the Free Software Foundation, Inc.,
 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 from .flow_model import (DataLoader, PV_LogLikelihood, PV_validation_model,  # noqa
-                         dist2distmodulus, dist2redshift, distmodulus2dist,  # noqa
-                         get_model, Observed2CosmologicalRedshift,  # noqa
-                         predict_zobs, project_Vext,  # noqa
-                         radial_velocity_los, stack_pzosmo_over_realizations)  # noqa
+                         dist2redshift, get_model,  # noqa
+                         Observed2CosmologicalRedshift, predict_zobs,  # noqa
+                         project_Vext, radial_velocity_los,  # noqa
+                         stack_pzosmo_over_realizations)  # noqa
diff --git a/csiborgtools/flow/flow_model.py b/csiborgtools/flow/flow_model.py
index 9d9305b..3c48bd9 100644
--- a/csiborgtools/flow/flow_model.py
+++ b/csiborgtools/flow/flow_model.py
@@ -29,13 +29,10 @@ from astropy.cosmology import FlatLambdaCDM, z_at_value
 from h5py import File
 from jax import jit
 from jax import numpy as jnp
-from jax import vmap
-from jax.scipy.special import logsumexp
-from numpyro import deterministic, factor, sample
-from numpyro.distributions import Normal, Uniform
+from jax.scipy.special import logsumexp, erf
+from numpyro import factor, sample, plate
+from numpyro.distributions import Normal, Uniform, MultivariateNormal
 from quadax import simpson
-from scipy.interpolate import interp1d
-from sklearn.model_selection import KFold
 from tqdm import trange

 from ..params import SPEED_OF_LIGHT, simname2Omega_m
@@ -197,6 +194,8 @@ class DataLoader:

         if "Pantheon+" in catalogue:
             fpath = paths.field_los(simname, "Pantheon+")
+        elif "CF4_TFR" in catalogue:
+            fpath = paths.field_los(simname, "CF4_TFR")
         else:
             fpath = paths.field_los(simname, catalogue)
@@ -261,34 +260,23 @@ class DataLoader:
                         continue

                     arr[key] = f[key][:]
+        elif catalogue in ["CF4_GroupAll"] or "CF4_TFR" in catalogue:
+            with File(catalogue_fpath, 'r') as f:
+                dtype = [(key, np.float32) for key in f.keys()]
+                dtype += [("DEC", np.float32)]
+                arr = np.empty(len(f["RA"]), dtype=dtype)
+
+                for key in f.keys():
+                    arr[key] = f[key][:]
+                arr["DEC"] = arr["DE"]
+
+                if "CF4_TFR" in catalogue:
+                    arr["RA"] *= 360 / 24
         else:
             raise ValueError(f"Unknown catalogue: `{catalogue}`.")

         return arr

-    def make_jackknife_mask(self, i, n_splits, seed=42):
-        """
-        Set the internal jackknife mask to exclude the `i`-th split out of
-        `n_splits`.
-        """
-        cv = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
-        n = len(self._cat)
-        indxs = np.arange(n)
-
-        gen = np.random.default_rng(seed)
-        gen.shuffle(indxs)
-
-        for j, (train_index, __) in enumerate(cv.split(np.arange(n))):
-            if i == j:
-                self._mask = indxs[train_index]
-                return
-
-        raise ValueError("The index `i` must be in the range of `n_splits`.")
-
-    def reset_mask(self):
-        """Reset the jackknife mask."""
-        self._mask = np.ones(len(self._cat), dtype=bool)
-
 ###############################################################################
 #                       Supplementary flow functions                          #
 ###############################################################################
@@ -319,17 +307,6 @@ def radial_velocity_los(los_velocity, ra, dec):
 #                             JAX Flow model                                  #
 ###############################################################################

-
-def lognorm_mean_std_to_loc_scale(mu, std):
-    """
-    Calculate the location and scale parameters for the log-normal distribution
-    from the mean and standard deviation.
- """ - loc = np.log(mu) - 0.5 * np.log(1 + (std / mu) ** 2) - scale = np.sqrt(np.log(1 + (std / mu) ** 2)) - return loc, scale - - def dist2redshift(dist, Omega_m): """ Convert comoving distance to cosmological redshift if the Universe is @@ -357,91 +334,6 @@ def gradient_redshift2dist(z, Omega_m): return SPEED_OF_LIGHT / H0 * (1 - z * (1 + q0)) -def dist2distmodulus(dist, Omega_m): - """Convert comoving distance to distance modulus, assuming z << 1.""" - zcosmo = dist2redshift(dist, Omega_m) - luminosity_distance = dist * (1 + zcosmo) - return 5 * jnp.log10(luminosity_distance) + 25 - - -def distmodulus2dist(mu, Omega_m, ninterp=10000, zmax=0.1, mu2comoving=None, - return_interpolator=False): - """ - Convert distance modulus to comoving distance. This is costly as it builds - up the interpolator every time it is called, unless it is provided. - - Parameters - ---------- - mu : float or 1-dimensional array - Distance modulus. - Omega_m : float - Matter density parameter. - ninterp : int, optional - Number of points to interpolate the mapping from distance modulus to - comoving distance. - zmax : float, optional - Maximum redshift for the interpolation. - mu2comoving : callable, optional - Interpolator from distance modulus to comoving distance. If not - provided, it is built up every time the function is called. - return_interpolator : bool, optional - Whether to return the interpolator as well. - - Returns - ------- - float (or 1-dimensional array) and callable (optional) - """ - if mu2comoving is None: - zrange = np.linspace(1e-15, zmax, ninterp) - cosmo = FlatLambdaCDM(H0=H0, Om0=Omega_m) - mu2comoving = interp1d( - cosmo.distmod(zrange).value, cosmo.comoving_distance(zrange).value, - kind="cubic") - - if return_interpolator: - return mu2comoving(mu), mu2comoving - - return mu2comoving(mu) - - -def distmodulus2redsfhit(mu, Omega_m, ninterp=10000, zmax=0.1, mu2z=None, - return_interpolator=False): - """ - Convert distance modulus to cosmological redshift. This is costly as it - builts up the interpolator every time it is called, unless it is provided. - - Parameters - ---------- - mu : float or 1-dimensional array - Distance modulus. - Omega_m : float - Matter density parameter. - ninterp : int, optional - Number of points to interpolate the mapping from distance modulus to - comoving distance. - zmax : float, optional - Maximum redshift for the interpolation. - mu2z : callable, optional - Interpolator from distance modulus to cosmological redsfhit. If not - provided, it is built up every time the function is called. - return_interpolator : bool, optional - Whether to return the interpolator as well. 
-
-    Returns
-    -------
-    float (or 1-dimensional array) and callable (optional)
-    """
-    if mu2z is None:
-        zrange = np.linspace(1e-15, zmax, ninterp)
-        cosmo = FlatLambdaCDM(H0=H0, Om0=Omega_m)
-        mu2z = interp1d(cosmo.distmod(zrange).value, zrange, kind="cubic")
-
-    if return_interpolator:
-        return mu2z(mu), mu2z
-
-    return mu2z(mu)
-
-
 def project_Vext(Vext_x, Vext_y, Vext_z, RA_radians, dec_radians):
     """Project the external velocity vector onto the line of sight."""
     cos_dec = jnp.cos(dec_radians)
@@ -465,21 +357,37 @@ def predict_zobs(dist, beta, Vext_radial, vpec_radial, Omega_m):
 ###############################################################################

-def calculate_ptilde_wo_bias(xrange, mu, err_squared, r_squared_xrange):
+def ptilde_wo_bias(xrange, mu, err_squared, r_squared_xrange):
     """Calculate `ptilde(r)` without inhomogeneous Malmquist bias."""
     ptilde = jnp.exp(-0.5 * (xrange - mu)**2 / err_squared)
+    ptilde /= jnp.sqrt(2 * np.pi * err_squared)
     ptilde *= r_squared_xrange
     return ptilde


-def calculate_likelihood_zobs(zobs, zobs_pred, e2_cz):
+def likelihood_zobs(zobs, zobs_pred, e2_cz):
     """
     Calculate the likelihood of the observed redshift given the predicted
-    redshift.
+    redshift. Multiplies the redshifts by the speed of light.
     """
-    dcz = SPEED_OF_LIGHT * (zobs[:, None] - zobs_pred)
+    dcz = SPEED_OF_LIGHT * (zobs - zobs_pred)
     return jnp.exp(-0.5 * dcz**2 / e2_cz) / jnp.sqrt(2 * np.pi * e2_cz)

+
+def normal_logpdf(x, loc, scale):
+    """Log of the normal probability density function."""
+    return (-0.5 * ((x - loc) / scale)**2
+            - jnp.log(scale) - 0.5 * jnp.log(2 * jnp.pi))
+
+
+def upper_truncated_normal_logpdf(x, loc, scale, xmax):
+    """Log of the normal probability density function truncated at `xmax`."""
+    # Take the absolute value as a numerical safeguard; in practice `loc`
+    # should never exceed `xmax`.
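+    # The normalisation constant is P(X <= xmax) = Phi((xmax - loc) / scale),
+    # written below via erf, so the truncated density integrates to unity.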
+    norm = 0.5 * (1 + erf((jnp.abs(xmax - loc)) / (jnp.sqrt(2) * scale)))
+    return normal_logpdf(x, loc, scale) - jnp.log(norm)
+
+
 ###############################################################################
 #                         Base flow validation                                #
 ###############################################################################
@@ -495,18 +403,22 @@ class BaseFlowValidationModel(ABC):
         names = []
         values = []
         for key, value in calibration_params.items():
+            names.append(key)
+            values.append(value)
+
+            # Store also the squared uncertainty
             if "e_" in key:
                 key = key.replace("e_", "e2_")
                 value = value**2
-
-            names.append(key)
-            values.append(value)
+                names.append(key)
+                values.append(value)

         self._setattr_as_jax(names, values)

     def _set_radial_spacing(self, r_xrange, Omega_m):
         cosmo = FlatLambdaCDM(H0=H0, Om0=Omega_m)

+        r_xrange = jnp.asarray(r_xrange)
         r2_xrange = r_xrange**2
         r2_xrange /= r2_xrange.mean()
         self.r_xrange = r_xrange
@@ -540,19 +452,30 @@ class BaseFlowValidationModel(ABC):
         pass


+###############################################################################
+#                           Sampling shortcuts                                #
+###############################################################################
+
+def sample_alpha_bias(name, xmin, xmax, to_sample):
+    if to_sample:
+        return sample(f"alpha_{name}", Uniform(xmin, xmax))
+    else:
+        return 1.0
+
+
 ###############################################################################
 #                       SNIa parameters sampling                              #
 ###############################################################################

-def distmod_SN(mB, x1, c, mag_cal, alpha_cal, beta_cal):
+def distmod_SN(mag, x1, c, mag_cal, alpha_cal, beta_cal):
     """Distance modulus of a SALT2 SN Ia."""
-    return mB - mag_cal + alpha_cal * x1 - beta_cal * c
+    return mag - mag_cal + alpha_cal * x1 - beta_cal * c


-def e2_distmod_SN(e2_mB, e2_x1, e2_c, alpha_cal, beta_cal, e_mu_intrinsic):
+def e2_distmod_SN(e2_mag, e2_x1, e2_c, alpha_cal, beta_cal, e_mu_intrinsic):
     """Squared error on the distance modulus of a SALT2 SN Ia."""
-    return (e2_mB + alpha_cal**2 * e2_x1 + beta_cal**2 * e2_c
+    return (e2_mag + alpha_cal**2 * e2_x1 + beta_cal**2 * e2_c
             + e_mu_intrinsic**2)
@@ -565,8 +488,7 @@ def sample_SN(e_mu_min, e_mu_max, mag_cal_mean, mag_cal_std, alpha_cal_mean,
     alpha_cal = sample(
         f"alpha_cal_{name}", Normal(alpha_cal_mean, alpha_cal_std))
     beta_cal = sample(f"beta_cal_{name}", Normal(beta_cal_mean, beta_cal_std))
-    alpha = sample(f"alpha_{name}",
-                   Uniform(alpha_min, alpha_max)) if sample_alpha else 1.0
+    alpha = sample_alpha_bias(name, alpha_min, alpha_max, sample_alpha)

     return {"e_mu": e_mu,
             "mag_cal": mag_cal,
@@ -580,29 +502,74 @@ def sample_SN(e_mu_min, e_mu_max, mag_cal_mean, mag_cal_std, alpha_cal_mean,
 #                    Tully-Fisher parameters sampling                         #
 ###############################################################################

-def distmod_TFR(mag, eta, a, b):
+def distmod_TFR(mag, eta, a, b, c):
     """Distance modulus of a TFR calibration."""
-    return mag - (a + b * eta)
+    return mag - (a + b * eta + c * eta**2)


-def e2_distmod_TFR(e2_mag, e2_eta, b, e_mu_intrinsic):
-    """Squared error on the TFR distance modulus."""
-    return e2_mag + b**2 * e2_eta + e_mu_intrinsic**2
+def e2_distmod_TFR(e2_mag, e2_eta, eta, b, c, e_mu_intrinsic):
+    """
+    Squared error on the TFR distance modulus with linearly propagated
+    magnitude and linewidth uncertainties.
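+
+    The linewidth term follows from propagating the `eta` uncertainty through
+    the quadratic relation, i.e. d(mu)/d(eta) = -(b + 2 * c * eta).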
+ """ + return e2_mag + (b + 2 * c * eta)**2 * e2_eta + e_mu_intrinsic**2 -def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std, alpha_min, - alpha_max, sample_alpha, name): +def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std, + c_mean, c_std, alpha_min, alpha_max, sample_alpha, + sample_curvature, a_dipole_mean, a_dipole_std, sample_a_dipole, + name): """Sample Tully-Fisher calibration parameters.""" e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max)) a = sample(f"a_{name}", Normal(a_mean, a_std)) + + if sample_a_dipole: + ax, ay, az = sample(f"a_dipole_{name}", Normal(0, 5).expand([3])) + else: + ax, ay, az = 0.0, 0.0, 0.0 + b = sample(f"b_{name}", Normal(b_mean, b_std)) - alpha = sample(f"alpha_{name}", - Uniform(alpha_min, alpha_max)) if sample_alpha else 1.0 + if sample_curvature: + c = sample(f"c_{name}", Normal(c_mean, c_std)) + else: + c = 0.0 + + alpha = sample_alpha_bias(name, alpha_min, alpha_max, sample_alpha) return {"e_mu": e_mu, "a": a, + "ax": ax, "ay": ay, "az": az, "b": b, - "alpha": alpha + "c": c, + "alpha": alpha, + "sample_a_dipole": sample_a_dipole, + } + + +############################################################################### +# Simple calibration parameters sampling # +############################################################################### + +def sample_simple(e_mu_min, e_mu_max, dmu_min, dmu_max, alpha_min, alpha_max, + dmu_dipole_mean, dmu_dipole_std, sample_alpha, + sample_dmu_dipole, name): + """Sample simple calibration parameters.""" + e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max)) + dmu = sample(f"dmu_{name}", Uniform(dmu_min, dmu_max)) + alpha = sample_alpha_bias(name, alpha_min, alpha_max, sample_alpha) + + if sample_dmu_dipole: + dmux, dmuy, dmuz = sample( + f"dmu_dipole_{name}", + Normal(dmu_dipole_mean, dmu_dipole_std).expand([3])) + else: + dmux, dmuy, dmuz = 0.0, 0.0, 0.0 + + return {"e_mu": e_mu, + "dmu": dmu, + "dmux": dmux, "dmuy": dmuy, "dmuz": dmuz, + "alpha": alpha, + "sample_dmu_dipole": sample_dmu_dipole, } ############################################################################### @@ -612,19 +579,24 @@ def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std, alpha_min, def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min, beta_max, sigma_v_min, sigma_v_max, sample_Vmono, - sample_beta, sample_sigma_v_ext): + sample_beta): """Sample the flow calibration.""" Vext = sample("Vext", Uniform(Vext_min, Vext_max).expand([3])) sigma_v = sample("sigma_v", Uniform(sigma_v_min, sigma_v_max)) - beta = sample("beta", Uniform(beta_min, beta_max)) if sample_beta else 1.0 # noqa - Vmono = sample("Vmono", Uniform(Vmono_min, Vmono_max)) if sample_Vmono else 0.0 # noqa - sigma_v_ext = sample("sigma_v_ext", Uniform(sigma_v_min, sigma_v_max)) if sample_sigma_v_ext else sigma_v # noqa + if sample_beta: + beta = sample("beta", Uniform(beta_min, beta_max)) + else: + beta = 1.0 + + if sample_Vmono: + Vmono = sample("Vmono", Uniform(Vmono_min, Vmono_max)) + else: + Vmono = 0.0 return {"Vext": Vext, "Vmono": Vmono, "sigma_v": sigma_v, - "sigma_v_ext": sigma_v_ext, "beta": beta} @@ -633,21 +605,11 @@ def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min, ############################################################################### -def find_extrap_mask(rmax, rdist): - """ - Make a mask of shape `(nsim, ngal, nrdist)` of which velocity field values - are extrapolated. 
-    """
-    nsim, ngal = rmax.shape
-    extrap_mask = np.zeros((nsim, ngal, len(rdist)), dtype=bool)
-    extrap_weights = np.ones((nsim, ngal, len(rdist)))
-    for i in range(nsim):
-        for j in range(ngal):
-            k = np.searchsorted(rdist, rmax[i, j])
-            extrap_mask[i, j, k:] = True
-            extrap_weights[i, j, k:] = rmax[i, j] / rdist[k:]
-
-    return extrap_mask, extrap_weights
+def sample_gaussian_hyperprior(param, name, xmin, xmax):
+    """Sample MNR Gaussian hyperprior mean and standard deviation."""
+    mean = sample(f"{param}_mean_{name}", Uniform(xmin, xmax))
+    std = sample(f"{param}_std_{name}", Uniform(0.0, xmax - xmin))
+    return mean, std


 class PV_LogLikelihood(BaseFlowValidationModel):
@@ -671,14 +633,21 @@ class PV_LogLikelihood(BaseFlowValidationModel):
         Errors on the observed redshifts.
     calibration_params: dict
         Calibration parameters of each object.
+    maxmag_selection : float
+        Maximum magnitude selection if strict threshold.
     r_xrange : 1-dimensional array
         Radial distances where the field was interpolated for each object.
     Omega_m : float
         Matter density parameter.
+    kind : str
+        Catalogue kind, either "TFR", "SN", or "simple".
+    name : str
+        Name of the catalogue.
     """

     def __init__(self, los_density, los_velocity, rmax, RA, dec, z_obs,
-                 e_zobs, calibration_params, r_xrange, Omega_m, kind, name):
+                 e_zobs, calibration_params, maxmag_selection, r_xrange,
+                 Omega_m, kind, name):
         if e_zobs is not None:
             e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2)
         else:
@@ -699,68 +668,172 @@ class PV_LogLikelihood(BaseFlowValidationModel):
         self.name = name
         self.Omega_m = Omega_m
         self.norm = - self.ndata * jnp.log(self.num_sims)
+        self.maxmag_selection = maxmag_selection

-        extrap_mask, extrap_weights = find_extrap_mask(rmax, r_xrange)
-        self.extrap_mask = jnp.asarray(extrap_mask)
-        self.extrap_weights = jnp.asarray(extrap_weights)
+        if kind == "TFR":
+            self.mag_min, self.mag_max = jnp.min(self.mag), jnp.max(self.mag)
+            eta_mu = jnp.mean(self.eta)
+            fprint(f"setting the linewidth mean to 0 instead of {eta_mu:.3f}.")
+            self.eta -= eta_mu
+            self.eta_min, self.eta_max = jnp.min(self.eta), jnp.max(self.eta)
+        elif kind == "SN":
+            self.mag_min, self.mag_max = jnp.min(self.mag), jnp.max(self.mag)
+            self.x1_min, self.x1_max = jnp.min(self.x1), jnp.max(self.x1)
+            self.c_min, self.c_max = jnp.min(self.c), jnp.max(self.c)
+        elif kind == "simple":
+            self.mu_min, self.mu_max = jnp.min(self.mu), jnp.max(self.mu)
+        else:
+            raise RuntimeError("Support must be added for other kinds.")
+
+        if maxmag_selection is not None and self.maxmag_selection > self.mag_max:  # noqa
+            raise ValueError("The selection threshold cannot be larger than the maximum observed magnitude.")  # noqa

     def __call__(self, field_calibration_params, distmod_params,
-                 sample_sigma_v_ext):
-        """PV validation model log-likelihood."""
-        # Turn e2_cz to be of shape (nsims, ndata, nxrange) and apply
-        # sigma_v_ext where applicable
-        sigma_v = field_calibration_params["sigma_v"]
-        sigma_v_ext = field_calibration_params["sigma_v_ext"]
-        e2_cz = jnp.full_like(self.extrap_mask, sigma_v**2, dtype=jnp.float32)
-        if sample_sigma_v_ext:
-            e2_cz = e2_cz.at[self.extrap_mask].set(sigma_v_ext**2)
+                 inference_method):
+        if inference_method not in ["mike", "bayes"]:
+            raise ValueError(f"Unknown method: `{inference_method}`.")

-        # Now add the observational errors
-        e2_cz += self.e2_cz_obs[None, :, None]
+        ll0 = 0.0
+        sigma_v = field_calibration_params["sigma_v"]
+        e2_cz = self.e2_cz_obs + sigma_v**2

         Vext = field_calibration_params["Vext"]
         Vmono = field_calibration_params["Vmono"]
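+        # Project the external velocity vector onto each line of sight.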
         Vext_rad = project_Vext(Vext[0], Vext[1], Vext[2], self.RA, self.dec)

         e_mu = distmod_params["e_mu"]
-        alpha = distmod_params["alpha"]

         if self.kind == "SN":
             mag_cal = distmod_params["mag_cal"]
             alpha_cal = distmod_params["alpha_cal"]
             beta_cal = distmod_params["beta_cal"]
+
+            if inference_method == "bayes":
+                mag_mean, mag_std = sample_gaussian_hyperprior(
+                    "mag", self.name, self.mag_min, self.mag_max)
+                x1_mean, x1_std = sample_gaussian_hyperprior(
+                    "x1", self.name, self.x1_min, self.x1_max)
+                c_mean, c_std = sample_gaussian_hyperprior(
+                    "c", self.name, self.c_min, self.c_max)
+
+                # NOTE: the true variables are currently uncorrelated.
+                with plate("true_SN", self.ndata):
+                    mag_true = sample(
+                        f"mag_true_{self.name}", Normal(mag_mean, mag_std))
+                    x1_true = sample(
+                        f"x1_true_{self.name}", Normal(x1_mean, x1_std))
+                    c_true = sample(
+                        f"c_true_{self.name}", Normal(c_mean, c_std))
+
+                # Log-likelihood of the observed magnitudes.
+                if self.maxmag_selection is None:
+                    ll0 += jnp.sum(normal_logpdf(
+                        mag_true, self.mag, self.e_mag))
+                else:
+                    raise NotImplementedError("Maxmag selection not implemented.")  # noqa
+
+                # Log-likelihood of the observed x1 and c.
+                ll0 += jnp.sum(normal_logpdf(x1_true, self.x1, self.e_x1))
+                ll0 += jnp.sum(normal_logpdf(c_true, self.c, self.e_c))
+                e2_mu = jnp.ones_like(mag_true) * e_mu**2
+            else:
+                mag_true = self.mag
+                x1_true = self.x1
+                c_true = self.c
+                e2_mu = e2_distmod_SN(
+                    self.e2_mag, self.e2_x1, self.e2_c, alpha_cal, beta_cal,
+                    e_mu)
+
             mu = distmod_SN(
-                self.mB, self.x1, self.c, mag_cal, alpha_cal, beta_cal)
-            squared_e_mu = e2_distmod_SN(
-                self.e2_mB, self.e2_x1, self.e2_c, alpha_cal, beta_cal, e_mu)
+                mag_true, x1_true, c_true, mag_cal, alpha_cal, beta_cal)
         elif self.kind == "TFR":
             a = distmod_params["a"]
             b = distmod_params["b"]
-            mu = distmod_TFR(self.mag, self.eta, a, b)
-            squared_e_mu = e2_distmod_TFR(self.e2_mag, self.e2_eta, b, e_mu)
+            c = distmod_params["c"]
+
+            if distmod_params["sample_a_dipole"]:
+                ax, ay, az = (distmod_params[k] for k in ["ax", "ay", "az"])
+                a = a + project_Vext(ax, ay, az, self.RA, self.dec)
+
+            if inference_method == "bayes":
+                # Sample the true TFR parameters.
+                mag_mean, mag_std = sample_gaussian_hyperprior(
+                    "mag", self.name, self.mag_min, self.mag_max)
+                eta_mean, eta_std = sample_gaussian_hyperprior(
+                    "eta", self.name, self.eta_min, self.eta_max)
+                corr_mag_eta = sample("corr_mag_eta", Uniform(-1, 1))
+
+                loc = jnp.array([mag_mean, eta_mean])
+                cov = jnp.array(
+                    [[mag_std**2, corr_mag_eta * mag_std * eta_std],
+                     [corr_mag_eta * mag_std * eta_std, eta_std**2]])
+
+                with plate("true_TFR", self.ndata):
+                    x_true = sample("x_TFR", MultivariateNormal(loc, cov))
+
+                mag_true, eta_true = x_true[..., 0], x_true[..., 1]
+                # Log-likelihood of the observed magnitudes.
+                if self.maxmag_selection is None:
+                    ll0 += jnp.sum(normal_logpdf(
+                        self.mag, mag_true, self.e_mag))
+                else:
+                    ll0 += jnp.sum(upper_truncated_normal_logpdf(
+                        self.mag, mag_true, self.e_mag, self.maxmag_selection))
+
+                # Log-likelihood of the observed linewidths.
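+                # (modelled as Gaussian scatter about the latent linewidth).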
+                ll0 += jnp.sum(normal_logpdf(eta_true, self.eta, self.e_eta))
+
+                e2_mu = jnp.ones_like(mag_true) * e_mu**2
+            else:
+                eta_true = self.eta
+                mag_true = self.mag
+                e2_mu = e2_distmod_TFR(
+                    self.e2_mag, self.e2_eta, eta_true, b, c, e_mu)
+
+            mu = distmod_TFR(mag_true, eta_true, a, b, c)
+        elif self.kind == "simple":
+            dmu = distmod_params["dmu"]
+
+            if distmod_params["sample_dmu_dipole"]:
+                dmux, dmuy, dmuz = (
+                    distmod_params[k] for k in ["dmux", "dmuy", "dmuz"])
+                dmu = dmu + project_Vext(dmux, dmuy, dmuz, self.RA, self.dec)
+
+            if inference_method == "bayes":
+                raise NotImplementedError("Bayes for simple not implemented.")
+            else:
+                mu_true = self.mu
+                e2_mu = e_mu**2 + self.e2_mu
+
+            mu = mu_true + dmu
         else:
             raise ValueError(f"Unknown kind: `{self.kind}`.")

         # Calculate p(r) (Malmquist bias). Shape is (ndata, nxrange)
-        ptilde = jnp.transpose(vmap(calculate_ptilde_wo_bias, in_axes=(0, None, None, 0))(self.mu_xrange, mu, squared_e_mu, self.r2_xrange))  # noqa
+        ptilde = ptilde_wo_bias(
+            self.mu_xrange[None, :], mu[:, None], e2_mu[:, None],
+            self.r2_xrange[None, :])

         # Inhomogeneous Malmquist bias. Shape is (n_sims, ndata, nxrange)
-        ptilde = self.los_density**alpha * ptilde
+        alpha = distmod_params["alpha"]
+        ptilde = ptilde[None, ...] * self.los_density**alpha
+
         # Normalization of p(r). Shape is (n_sims, ndata)
         pnorm = simpson(ptilde, dx=self.dr, axis=-1)

         # Calculate z_obs at each distance. Shape is (n_sims, ndata, nxrange)
-        # The weights are related to the extrapolation of the velocity field.
         vrad = field_calibration_params["beta"] * self.los_velocity
-        vrad += (Vext_rad[None, :, None] + Vmono) * self.extrap_weights
+        vrad += (Vext_rad[None, :, None] + Vmono)
-        zobs = (1 + self.z_xrange[None, None, :]) * (1 + vrad / SPEED_OF_LIGHT) - 1  # noqa
+        zobs = (1 + self.z_xrange[None, None, :]) * (1 + vrad / SPEED_OF_LIGHT)
+        zobs -= 1.
+
+        ptilde *= likelihood_zobs(
+            self.z_obs[None, :, None], zobs, e2_cz[None, :, None])
-        ptilde *= calculate_likelihood_zobs(self.z_obs, zobs, e2_cz)

         ll = jnp.log(simpson(ptilde, dx=self.dr, axis=-1)) - jnp.log(pnorm)
-
-        return jnp.sum(logsumexp(ll, axis=0)) + self.norm
+        return ll0 + jnp.sum(logsumexp(ll, axis=0)) + self.norm


 def PV_validation_model(models, distmod_hyperparams_per_model,
-                        field_calibration_hyperparams):
+                        field_calibration_hyperparams, inference_method):
     """
     Peculiar velocity validation NumPyro model.

     Parameters
     ----------
     models : list of `PV_LogLikelihood`
         List of models to be evaluated.
     distmod_hyperparams_per_model : list of dicts
         Distance modulus hyperparameters for each model/catalogue.
     field_calibration_hyperparams : dict
         Field calibration hyperparameters.
+    inference_method : str
+        Either `mike` or `bayes`.
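+        With `mike`, the observed calibration quantities are used directly
+        and their uncertainties are propagated into the distance modulus;
+        with `bayes`, latent true values are sampled under Gaussian
+        hyperpriors.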
""" field_calibration_params = sample_calibration( **field_calibration_hyperparams) - sample_sigma_v_ext = field_calibration_hyperparams["sample_sigma_v_ext"] ll = 0.0 for n in range(len(models)): model = models[n] + name = model.name distmod_hyperparams = distmod_hyperparams_per_model[n] if model.kind == "TFR": - distmod_params = sample_TFR(**distmod_hyperparams, name=model.name) + distmod_params = sample_TFR(**distmod_hyperparams, name=name) elif model.kind == "SN": - distmod_params = sample_SN(**distmod_hyperparams, name=model.name) + distmod_params = sample_SN(**distmod_hyperparams, name=name) + elif model.kind == "simple": + distmod_params = sample_simple(**distmod_hyperparams, name=name) else: raise ValueError(f"Unknown kind: `{model.kind}`.") - ll += model( - field_calibration_params, distmod_params, sample_sigma_v_ext) + ll += model(field_calibration_params, distmod_params, inference_method) - deterministic("ll_values", ll) factor("ll", ll) @@ -801,7 +876,7 @@ def PV_validation_model(models, distmod_hyperparams_per_model, ############################################################################### -def get_model(loader, zcmb_max=None, verbose=True): +def get_model(loader, zcmb_min=0.0, zcmb_max=None, maxmag_selection=None): """ Get a model and extract the relevant data from the loader. @@ -809,10 +884,12 @@ def get_model(loader, zcmb_max=None, verbose=True): ---------- loader : DataLoader DataLoader instance. + zcmb_min : float, optional + Minimum observed redshift in the CMB frame to include. zcmb_max : float, optional Maximum observed redshift in the CMB frame to include. - verbose : bool, optional - Verbosity flag. + maxmag_selection : float, optional + Maximum magnitude selection threshold. Returns ------- @@ -825,20 +902,24 @@ def get_model(loader, zcmb_max=None, verbose=True): rmax = loader.rmax kind = loader._catname + if maxmag_selection is not None and kind != "2MTF": + raise ValueError("Threshold magnitude selection implemented only for 2MTF.") # noqa + if kind in ["LOSS", "Foundation"]: keys = ["RA", "DEC", "z_CMB", "mB", "x1", "c", "e_mB", "e_x1", "e_c"] - RA, dec, zCMB, mB, x1, c, e_mB, e_x1, e_c = (loader.cat[k] for k in keys) # noqa + RA, dec, zCMB, mag, x1, c, e_mag, e_x1, e_c = ( + loader.cat[k] for k in keys) e_zCMB = None - mask = (zCMB < zcmb_max) - calibration_params = {"mB": mB[mask], "x1": x1[mask], "c": c[mask], - "e_mB": e_mB[mask], "e_x1": e_x1[mask], + mask = (zCMB < zcmb_max) & (zCMB > zcmb_min) + calibration_params = {"mag": mag[mask], "x1": x1[mask], "c": c[mask], + "e_mag": e_mag[mask], "e_x1": e_x1[mask], "e_c": e_c[mask]} model = PV_LogLikelihood( los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask], RA[mask], dec[mask], zCMB[mask], e_zCMB, calibration_params, - loader.rdist, loader._Omega_m, "SN", name=kind) + maxmag_selection, loader.rdist, loader._Omega_m, "SN", name=kind) elif "Pantheon+" in kind: keys = ["RA", "DEC", "zCMB", "mB", "x1", "c", "biasCor_m_b", "mBERR", "x1ERR", "cERR", "biasCorErr_m_b", "zCMB_SN", "zCMB_Group", @@ -848,7 +929,7 @@ def get_model(loader, zcmb_max=None, verbose=True): mB -= bias_corr_mB e_mB = np.sqrt(e_mB**2 + e_bias_corr_mB**2) - mask = zCMB < zcmb_max + mask = (zCMB < zcmb_max) & (zCMB > zcmb_min) if kind == "Pantheon+_groups": mask &= np.isfinite(zCMB_Group) @@ -860,29 +941,80 @@ def get_model(loader, zcmb_max=None, verbose=True): if kind == "Pantheon+_zSN": zCMB = zCMB_SN - calibration_params = {"mB": mB[mask], "x1": x1[mask], "c": c[mask], - "e_mB": e_mB[mask], "e_x1": e_x1[mask], + 
calibration_params = {"mag": mB[mask], "x1": x1[mask], "c": c[mask], + "e_mag": e_mB[mask], "e_x1": e_x1[mask], "e_c": e_c[mask]} model = PV_LogLikelihood( los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask], RA[mask], dec[mask], zCMB[mask], e_zCMB[mask], calibration_params, - loader.rdist, loader._Omega_m, "SN", name=kind) + maxmag_selection, loader.rdist, loader._Omega_m, "SN", name=kind) elif kind in ["SFI_gals", "2MTF", "SFI_gals_masked"]: keys = ["RA", "DEC", "z_CMB", "mag", "eta", "e_mag", "e_eta"] RA, dec, zCMB, mag, eta, e_mag, e_eta = (loader.cat[k] for k in keys) - mask = (zCMB < zcmb_max) - if kind == "SFI_gals": - mask &= (eta > -0.15) & (eta < 0.2) - if verbose: - print("Emplyed eta cut for SFI galaxies.", flush=True) - + mask = (zCMB < zcmb_max) & (zCMB > zcmb_min) calibration_params = {"mag": mag[mask], "eta": eta[mask], "e_mag": e_mag[mask], "e_eta": e_eta[mask]} model = PV_LogLikelihood( los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask], RA[mask], dec[mask], zCMB[mask], None, calibration_params, - loader.rdist, loader._Omega_m, "TFR", name=kind) + maxmag_selection, loader.rdist, loader._Omega_m, "TFR", name=kind) + elif "CF4_TFR_" in kind: + # The full name can be e.g. "CF4_TFR_not2MTForSFI_i" or "CF4_TFR_i". + band = kind.split("_")[-1] + if band not in ['g', 'r', 'i', 'z', 'w1', 'w2']: + raise ValueError(f"Band `{band}` not recognized.") + + keys = ["RA", "DEC", "Vcmb", f"{band}", "lgWmxi", "elgWi", + "not_matched_to_2MTF_or_SFI", "Qs", "Qw"] + RA, dec, z_obs, mag, eta, e_eta, not_matched_to_2MTF_or_SFI, Qs, Qw = ( + loader.cat[k] for k in keys) + + not_matched_to_2MTF_or_SFI = not_matched_to_2MTF_or_SFI.astype(bool) + # NOTE: fiducial uncertainty until we can get the actual values. + e_mag = 0.001 * np.ones_like(mag) + + z_obs /= SPEED_OF_LIGHT + eta -= 2.5 + + fprint("selecting only galaxies with mag > 5 and eta > -0.3.") + mask = (mag > 5) & (eta > -0.3) + mask &= (z_obs < zcmb_max) & (z_obs > zcmb_min) + + if "not2MTForSFI" in kind: + mask &= not_matched_to_2MTF_or_SFI + elif "2MTForSFI" in kind: + mask &= ~not_matched_to_2MTF_or_SFI + + fprint("employing a quality cut on the galaxies.") + if "w" in band: + mask &= Qw == 5 + else: + mask &= Qs == 5 + + calibration_params = {"mag": mag[mask], "eta": eta[mask], + "e_mag": e_mag[mask], "e_eta": e_eta[mask]} + model = PV_LogLikelihood( + los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask], + RA[mask], dec[mask], z_obs[mask], None, calibration_params, + maxmag_selection, loader.rdist, loader._Omega_m, "TFR", name=kind) + elif kind in ["CF4_GroupAll"]: + # Note, this for some reason works terribly. 
+ keys = ["RA", "DE", "Vcmb", "DMzp", "eDM"] + RA, dec, zCMB, mu, e_mu = (loader.cat[k] for k in keys) + + zCMB /= SPEED_OF_LIGHT + mask = (zCMB < zcmb_max) & (zCMB > zcmb_min) & np.isfinite(mu) + + # The distance moduli in CF4 are most likely given assuming h = 0.75 + mu += 5 * np.log10(0.75) + + calibration_params = {"mu": mu[mask], "e_mu": e_mu[mask]} + model = PV_LogLikelihood( + los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask], + RA[mask], dec[mask], zCMB[mask], None, calibration_params, + maxmag_selection, loader.rdist, loader._Omega_m, "simple", + name=kind) else: raise ValueError(f"Catalogue `{kind}` not recognized.") diff --git a/csiborgtools/utils.py b/csiborgtools/utils.py index c986457..2d129db 100644 --- a/csiborgtools/utils.py +++ b/csiborgtools/utils.py @@ -469,10 +469,9 @@ def BIC_AIC(samples, log_likelihood, ndata): for val in samples.values(): if val.ndim == 1: nparam += 1 - elif val.ndim == 2: - nparam += val.shape[-1] else: - raise ValueError("Invalid dimensionality of samples to count the number of parameters.") # noqa + # The first dimension is the number of steps. + nparam += np.prod(val.shape[1:]) BIC = nparam * np.log(ndata) - 2 * log_likelihood[kmax] AIC = 2 * nparam - 2 * log_likelihood[kmax] diff --git a/scripts/field_prop/field_los.py b/scripts/field_prop/field_los.py index 2c6122a..9096113 100644 --- a/scripts/field_prop/field_los.py +++ b/scripts/field_prop/field_los.py @@ -15,6 +15,7 @@ MPI script to interpolate the density and velocity fields along the line of sight. """ +import sys from argparse import ArgumentParser from datetime import datetime from gc import collect @@ -32,7 +33,8 @@ from mpi4py import MPI from numba import jit from taskmaster import work_delegation # noqa -from utils import get_nsims +sys.path.append("../") +from utils import get_nsims # noqa ############################################################################### # I/O functions # @@ -84,8 +86,18 @@ def get_los(catalogue_name, simname, comm): with File(fname, 'r') as f: RA = f["RA"][:] dec = f["DEC"][:] + elif catalogue_name == "CF4_GroupAll": + fname = "/mnt/extraspace/rstiskalek/catalogs/PV/CF4/CF4_GroupAll.hdf5" # noqa + with File(fname, 'r') as f: + RA = f["RA"][:] + dec = f["DE"][:] + elif catalogue_name == "CF4_TFR": + fname = "/mnt/extraspace/rstiskalek/catalogs/PV/CF4/CF4_TF-distances.hdf5" # noqa + with File(fname, 'r') as f: + RA = f["RA"][:] * 360 / 24 # Convert to degrees from hours. 
+ dec = f["DE"][:] else: - raise ValueError(f"Unknown field name: `{catalogue_name}`.") + raise ValueError(f"Unknown catalogue name: `{catalogue_name}`.") if comm.Get_rank() == 0: print(f"The dataset contains {len(RA)} objects.") diff --git a/scripts/field_prop/field_los.sh b/scripts/field_prop/field_los.sh index 16cd0b9..2fe3037 100755 --- a/scripts/field_prop/field_los.sh +++ b/scripts/field_prop/field_los.sh @@ -10,8 +10,8 @@ MAS="SPH" grid=1024 -for simname in "CF4"; do - for catalogue in "Foundation"; do +for simname in "Carrick2015"; do + for catalogue in "CF4_TFR"; do pythoncm="$env $file --catalogue $catalogue --nsims $nsims --simname $simname --MAS $MAS --grid $grid" if [ $on_login -eq 1 ]; then echo $pythoncm diff --git a/scripts/flow/flow_validation.py b/scripts/flow/flow_validation.py index d5d262c..c159123 100644 --- a/scripts/flow/flow_validation.py +++ b/scripts/flow/flow_validation.py @@ -100,6 +100,10 @@ def get_models(get_model_kwargs, verbose=True): "Pantheon+_groups", "Pantheon+_groups_zSN", "Pantheon+_zSN"]: fpath = join(folder, "PV_compilation.hdf5") + elif "CF4_TFR" in cat: + fpath = join(folder, "PV/CF4/CF4_TF-distances.hdf5") + elif cat in ["CF4_GroupAll"]: + fpath = join(folder, "PV/CF4/CF4_GroupAll.hdf5") else: raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.") @@ -139,11 +143,7 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta, samples = mcmc.get_samples() log_posterior = -mcmc.get_extra_fields()["potential_energy"] - log_likelihood = samples.pop("ll_values") - if log_likelihood is None: - raise ValueError("The samples must contain the log likelihood values under the key `ll_values`.") # noqa - - BIC, AIC = csiborgtools.BIC_AIC(samples, log_likelihood, ndata) + BIC, AIC = csiborgtools.BIC_AIC(samples, log_posterior, ndata) print(f"{'BIC':<20} {BIC}") print(f"{'AIC':<20} {AIC}") mcmc.print_summary() @@ -174,7 +174,6 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta, grp.create_dataset(key, data=value) # Write log likelihood and posterior - f.create_dataset("log_likelihood", data=log_likelihood) f.create_dataset("log_posterior", data=log_posterior) # Write goodness of fit @@ -207,10 +206,9 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta, # Command line interface # ############################################################################### -def get_distmod_hyperparams(catalogue): +def get_distmod_hyperparams(catalogue, sample_alpha): alpha_min = -1.0 alpha_max = 3.0 - sample_alpha = True if catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]: # noqa return {"e_mu_min": 0.001, "e_mu_max": 1.0, @@ -220,12 +218,24 @@ def get_distmod_hyperparams(catalogue): "alpha_min": alpha_min, "alpha_max": alpha_max, "sample_alpha": sample_alpha } - elif catalogue in ["SFI_gals", "2MTF"]: + elif catalogue in ["SFI_gals", "2MTF"] or "CF4_TFR" in catalogue: return {"e_mu_min": 0.001, "e_mu_max": 1.0, "a_mean": -21., "a_std": 5.0, - "b_mean": -5.95, "b_std": 3.0, + "b_mean": -5.95, "b_std": 4.0, + "c_mean": 0., "c_std": 20.0, + "sample_curvature": False, + "a_dipole_mean": 0., "a_dipole_std": 1.0, + "sample_a_dipole": True, "alpha_min": alpha_min, "alpha_max": alpha_max, - "sample_alpha": sample_alpha + "sample_alpha": sample_alpha, + } + elif catalogue in ["CF4_GroupAll"]: + return {"e_mu_min": 0.001, "e_mu_max": 1.0, + "dmu_min": -3.0, "dmu_max": 3.0, + "dmu_dipole_mean": 0., "dmu_dipole_std": 1.0, + "sample_dmu_dipole": True, + "alpha_min": 
alpha_min, "alpha_max": alpha_max, + "sample_alpha": sample_alpha, } else: raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.") @@ -241,37 +251,46 @@ if __name__ == "__main__": # Fixed user parameters # ########################################################################### - nsteps = 5000 - nburn = 1500 + nsteps = 1500 + nburn = 1000 + zcmb_min = 0 zcmb_max = 0.05 calculate_evidence = False nchains_harmonic = 10 num_epochs = 30 + inference_method = "mike" + maxmag_selection = None + sample_alpha = True + sample_beta = True + sample_Vmono = False if nsteps % nchains_harmonic != 0: raise ValueError( "The number of steps must be divisible by the number of chains.") - main_params = {"nsteps": nsteps, "nburn": nburn, "zcmb_max": zcmb_max, + main_params = {"nsteps": nsteps, "nburn": nburn, + "zcmb_min": zcmb_min, + "zcmb_max": zcmb_max, + "maxmag_selection": maxmag_selection, "calculate_evidence": calculate_evidence, "nchains_harmonic": nchains_harmonic, - "num_epochs": num_epochs} + "num_epochs": num_epochs, + "inference_method": inference_method} print_variables(main_params.keys(), main_params.values()) calibration_hyperparams = {"Vext_min": -1000, "Vext_max": 1000, "Vmono_min": -1000, "Vmono_max": 1000, "beta_min": -1.0, "beta_max": 3.0, "sigma_v_min": 1.0, "sigma_v_max": 750., - "sample_Vmono": False, - "sample_beta": True, - "sample_sigma_v_ext": False, + "sample_Vmono": sample_Vmono, + "sample_beta": sample_beta, } print_variables( calibration_hyperparams.keys(), calibration_hyperparams.values()) distmod_hyperparams_per_catalogue = [] for cat in ARGS.catalogue: - x = get_distmod_hyperparams(cat) + x = get_distmod_hyperparams(cat, sample_alpha) print(f"\n{cat} hyperparameters:") print_variables(x.keys(), x.values()) distmod_hyperparams_per_catalogue.append(x) @@ -280,12 +299,14 @@ if __name__ == "__main__": *distmod_hyperparams_per_catalogue) ########################################################################### - get_model_kwargs = {"zcmb_max": zcmb_max} + get_model_kwargs = {"zcmb_min": zcmb_min, "zcmb_max": zcmb_max, + "maxmag_selection": maxmag_selection} models = get_models(get_model_kwargs, ) model_kwargs = { "models": models, "field_calibration_hyperparams": calibration_hyperparams, "distmod_hyperparams_per_model": distmod_hyperparams_per_catalogue, + "inference_method": inference_method, } model = csiborgtools.flow.PV_validation_model diff --git a/scripts/flow/flow_validation.sh b/scripts/flow/flow_validation.sh index 1b71307..366bd3f 100755 --- a/scripts/flow/flow_validation.sh +++ b/scripts/flow/flow_validation.sh @@ -39,7 +39,8 @@ fi # for simname in "Lilow2024" "CF4" "CF4gp" "csiborg1" "csiborg2_main" "csiborg2X"; do for simname in "Carrick2015"; do - for catalogue in "LOSS,2MTF,SFI_gals"; do + for catalogue in "CF4_GroupAll"; do + # for catalogue in "CF4_TFR_i"; do # for ksim in 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20; do for ksim in "none"; do pythoncm="$env $file --catalogue $catalogue --simname $simname --ksim $ksim --ksmooth $ksmooth --ndevice $ndevice --device $device"