Add CF4 and more improvements (#141)

* Update params counting

* Update imports

* Add CF4 group

* Update submit

* Update submit

* Many updates

* Many more updates

* Add CF4 TFR

* Add CF4 TF

* Fix RA bug in CF4 TF

* Add CF4 quality cut

* Start sampling alpha

* Update scripts

* Some comments

* Update script

* Add option to have magnitude selection.

* Add calibration dipoles
Richard Stiskalek 2024-08-25 17:03:51 +02:00 committed by GitHub
parent d578c71b83
commit d13246a394
7 changed files with 420 additions and 255 deletions

View file

@@ -13,7 +13,7 @@
 # with this program; if not, write to the Free Software Foundation, Inc.,
 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 from .flow_model import (DataLoader, PV_LogLikelihood, PV_validation_model,  # noqa
-                         dist2distmodulus, dist2redshift, distmodulus2dist,  # noqa
-                         get_model, Observed2CosmologicalRedshift,  # noqa
-                         predict_zobs, project_Vext,  # noqa
-                         radial_velocity_los, stack_pzosmo_over_realizations)  # noqa
+                         dist2redshift, get_model,  # noqa
+                         Observed2CosmologicalRedshift, predict_zobs,  # noqa
+                         project_Vext, radial_velocity_los,  # noqa
+                         stack_pzosmo_over_realizations)  # noqa

View file

@@ -29,13 +29,10 @@ from astropy.cosmology import FlatLambdaCDM, z_at_value
 from h5py import File
 from jax import jit
 from jax import numpy as jnp
-from jax import vmap
-from jax.scipy.special import logsumexp
-from numpyro import deterministic, factor, sample
-from numpyro.distributions import Normal, Uniform
+from jax.scipy.special import logsumexp, erf
+from numpyro import factor, sample, plate
+from numpyro.distributions import Normal, Uniform, MultivariateNormal
 from quadax import simpson
-from scipy.interpolate import interp1d
-from sklearn.model_selection import KFold
 from tqdm import trange

 from ..params import SPEED_OF_LIGHT, simname2Omega_m
@@ -197,6 +194,8 @@ class DataLoader:
         if "Pantheon+" in catalogue:
             fpath = paths.field_los(simname, "Pantheon+")
+        elif "CF4_TFR" in catalogue:
+            fpath = paths.field_los(simname, "CF4_TFR")
         else:
             fpath = paths.field_los(simname, catalogue)
@@ -261,34 +260,23 @@
                     continue
                 arr[key] = f[key][:]
+        elif catalogue in ["CF4_GroupAll"] or "CF4_TFR" in catalogue:
+            with File(catalogue_fpath, 'r') as f:
+                dtype = [(key, np.float32) for key in f.keys()]
+                dtype += [("DEC", np.float32)]
+                arr = np.empty(len(f["RA"]), dtype=dtype)
+
+                for key in f.keys():
+                    arr[key] = f[key][:]
+                arr["DEC"] = arr["DE"]
+
+                if "CF4_TFR" in catalogue:
+                    arr["RA"] *= 360 / 24
         else:
             raise ValueError(f"Unknown catalogue: `{catalogue}`.")

         return arr
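The CF4 TFR table stores right ascension in hours, hence the `* 360 / 24` rescaling above; a quick standalone check of that conversion (hypothetical values):

```python
import numpy as np

ra_hours = np.array([0.0, 6.0, 12.0, 23.9])
ra_deg = ra_hours * 360 / 24  # 15 degrees per hour
print(ra_deg)  # [  0.   90.  180.  358.5]
```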
-    def make_jackknife_mask(self, i, n_splits, seed=42):
-        """
-        Set the internal jackknife mask to exclude the `i`-th split out of
-        `n_splits`.
-        """
-        cv = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
-        n = len(self._cat)
-        indxs = np.arange(n)
-
-        gen = np.random.default_rng(seed)
-        gen.shuffle(indxs)
-
-        for j, (train_index, __) in enumerate(cv.split(np.arange(n))):
-            if i == j:
-                self._mask = indxs[train_index]
-                return
-
-        raise ValueError("The index `i` must be in the range of `n_splits`.")
-
-    def reset_mask(self):
-        """Reset the jackknife mask."""
-        self._mask = np.ones(len(self._cat), dtype=bool)
 ###############################################################################
 # Supplementary flow functions #
@@ -319,17 +307,6 @@ def radial_velocity_los(los_velocity, ra, dec):
 # JAX Flow model #
 ###############################################################################
-def lognorm_mean_std_to_loc_scale(mu, std):
-    """
-    Calculate the location and scale parameters for the log-normal distribution
-    from the mean and standard deviation.
-    """
-    loc = np.log(mu) - 0.5 * np.log(1 + (std / mu) ** 2)
-    scale = np.sqrt(np.log(1 + (std / mu) ** 2))
-    return loc, scale
 def dist2redshift(dist, Omega_m):
     """
     Convert comoving distance to cosmological redshift if the Universe is
@@ -357,91 +334,6 @@ def gradient_redshift2dist(z, Omega_m):
     return SPEED_OF_LIGHT / H0 * (1 - z * (1 + q0))
-def dist2distmodulus(dist, Omega_m):
-    """Convert comoving distance to distance modulus, assuming z << 1."""
-    zcosmo = dist2redshift(dist, Omega_m)
-    luminosity_distance = dist * (1 + zcosmo)
-    return 5 * jnp.log10(luminosity_distance) + 25
-
-
-def distmodulus2dist(mu, Omega_m, ninterp=10000, zmax=0.1, mu2comoving=None,
-                     return_interpolator=False):
-    """
-    Convert distance modulus to comoving distance. This is costly as it builds
-    up the interpolator every time it is called, unless it is provided.
-
-    Parameters
-    ----------
-    mu : float or 1-dimensional array
-        Distance modulus.
-    Omega_m : float
-        Matter density parameter.
-    ninterp : int, optional
-        Number of points to interpolate the mapping from distance modulus to
-        comoving distance.
-    zmax : float, optional
-        Maximum redshift for the interpolation.
-    mu2comoving : callable, optional
-        Interpolator from distance modulus to comoving distance. If not
-        provided, it is built up every time the function is called.
-    return_interpolator : bool, optional
-        Whether to return the interpolator as well.
-
-    Returns
-    -------
-    float (or 1-dimensional array) and callable (optional)
-    """
-    if mu2comoving is None:
-        zrange = np.linspace(1e-15, zmax, ninterp)
-        cosmo = FlatLambdaCDM(H0=H0, Om0=Omega_m)
-        mu2comoving = interp1d(
-            cosmo.distmod(zrange).value, cosmo.comoving_distance(zrange).value,
-            kind="cubic")
-
-    if return_interpolator:
-        return mu2comoving(mu), mu2comoving
-
-    return mu2comoving(mu)
-
-
-def distmodulus2redsfhit(mu, Omega_m, ninterp=10000, zmax=0.1, mu2z=None,
-                         return_interpolator=False):
-    """
-    Convert distance modulus to cosmological redshift. This is costly as it
-    builts up the interpolator every time it is called, unless it is provided.
-
-    Parameters
-    ----------
-    mu : float or 1-dimensional array
-        Distance modulus.
-    Omega_m : float
-        Matter density parameter.
-    ninterp : int, optional
-        Number of points to interpolate the mapping from distance modulus to
-        comoving distance.
-    zmax : float, optional
-        Maximum redshift for the interpolation.
-    mu2z : callable, optional
-        Interpolator from distance modulus to cosmological redsfhit. If not
-        provided, it is built up every time the function is called.
-    return_interpolator : bool, optional
-        Whether to return the interpolator as well.
-
-    Returns
-    -------
-    float (or 1-dimensional array) and callable (optional)
-    """
-    if mu2z is None:
-        zrange = np.linspace(1e-15, zmax, ninterp)
-        cosmo = FlatLambdaCDM(H0=H0, Om0=Omega_m)
-        mu2z = interp1d(cosmo.distmod(zrange).value, zrange, kind="cubic")
-
-    if return_interpolator:
-        return mu2z(mu), mu2z
-
-    return mu2z(mu)
 def project_Vext(Vext_x, Vext_y, Vext_z, RA_radians, dec_radians):
     """Project the external velocity vector onto the line of sight."""
     cos_dec = jnp.cos(dec_radians)
@@ -465,21 +357,37 @@ def predict_zobs(dist, beta, Vext_radial, vpec_radial, Omega_m):
 ###############################################################################


-def calculate_ptilde_wo_bias(xrange, mu, err_squared, r_squared_xrange):
+def ptilde_wo_bias(xrange, mu, err_squared, r_squared_xrange):
     """Calculate `ptilde(r)` without imhomogeneous Malmquist bias."""
     ptilde = jnp.exp(-0.5 * (xrange - mu)**2 / err_squared)
-    ptilde /= jnp.sqrt(2 * np.pi * err_squared)
     ptilde *= r_squared_xrange
     return ptilde


-def calculate_likelihood_zobs(zobs, zobs_pred, e2_cz):
+def likelihood_zobs(zobs, zobs_pred, e2_cz):
     """
     Calculate the likelihood of the observed redshift given the predicted
-    redshift.
+    redshift. Multiplies the redshifts by the speed of light.
     """
-    dcz = SPEED_OF_LIGHT * (zobs[:, None] - zobs_pred)
+    dcz = SPEED_OF_LIGHT * (zobs - zobs_pred)
     return jnp.exp(-0.5 * dcz**2 / e2_cz) / jnp.sqrt(2 * np.pi * e2_cz)


+def normal_logpdf(x, loc, scale):
+    """Log of the normal probability density function."""
+    return (-0.5 * ((x - loc) / scale)**2
+            - jnp.log(scale) - 0.5 * jnp.log(2 * jnp.pi))
+
+
+def upper_truncated_normal_logpdf(x, loc, scale, xmax):
+    """Log of the normal probability density function truncated at `xmax`."""
+    # Need the absolute value just to avoid sometimes things going wrong,
+    # but it should never occur that loc > xmax.
+    norm = 0.5 * (1 + erf((jnp.abs(xmax - loc)) / (jnp.sqrt(2) * scale)))
+    return normal_logpdf(x, loc, scale) - jnp.log(norm)
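A quick cross-check of this truncated-normal normalization against `scipy.stats.truncnorm` (hypothetical numbers; scipy takes standardized bounds, and `loc < xmax` is assumed so the absolute value is a no-op):

```python
import numpy as np
from scipy.special import erf
from scipy.stats import truncnorm

loc, scale, xmax, x = 12.0, 0.5, 13.0, 11.8
# Upper truncation at xmax: divide the normal pdf by Phi((xmax - loc) / scale).
norm = 0.5 * (1 + erf((xmax - loc) / (np.sqrt(2) * scale)))
logpdf = (-0.5 * ((x - loc) / scale)**2 - np.log(scale)
          - 0.5 * np.log(2 * np.pi) - np.log(norm))
ref = truncnorm(-np.inf, (xmax - loc) / scale, loc=loc, scale=scale).logpdf(x)
assert np.isclose(logpdf, ref)
```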
 ###############################################################################
 # Base flow validation #
 ###############################################################################
@@ -495,18 +403,22 @@ class BaseFlowValidationModel(ABC):
         names = []
         values = []
         for key, value in calibration_params.items():
+            names.append(key)
+            values.append(value)
+
+            # Store also the squared uncertainty
             if "e_" in key:
                 key = key.replace("e_", "e2_")
                 value = value**2
-            names.append(key)
-            values.append(value)
+                names.append(key)
+                values.append(value)

         self._setattr_as_jax(names, values)

     def _set_radial_spacing(self, r_xrange, Omega_m):
         cosmo = FlatLambdaCDM(H0=H0, Om0=Omega_m)

+        r_xrange = jnp.asarray(r_xrange)
         r2_xrange = r_xrange**2
         r2_xrange /= r2_xrange.mean()
         self.r_xrange = r_xrange
@@ -540,19 +452,30 @@
         pass


+###############################################################################
+# Sampling shortcuts #
+###############################################################################
+
+
+def sample_alpha_bias(name, xmin, xmax, to_sample):
+    if to_sample:
+        return sample(f"alpha_{name}", Uniform(xmin, xmax))
+    else:
+        return 1.0
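A minimal sketch of how this shortcut behaves under NumPyro's effect handlers — the `alpha_*` site is created only when sampling is requested (the definition is repeated so the snippet is self-contained; the catalogue name is hypothetical):

```python
from numpyro import sample
from numpyro.distributions import Uniform
from numpyro.handlers import seed, trace

def sample_alpha_bias(name, xmin, xmax, to_sample):
    return sample(f"alpha_{name}", Uniform(xmin, xmax)) if to_sample else 1.0

tr = trace(seed(lambda: sample_alpha_bias("SFI_gals", -1.0, 3.0, True),
                rng_seed=0)).get_trace()
assert "alpha_SFI_gals" in tr  # no site is created when to_sample=False
```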
 ###############################################################################
 # SNIa parameters sampling #
 ###############################################################################


-def distmod_SN(mB, x1, c, mag_cal, alpha_cal, beta_cal):
+def distmod_SN(mag, x1, c, mag_cal, alpha_cal, beta_cal):
     """Distance modulus of a SALT2 SN Ia."""
-    return mB - mag_cal + alpha_cal * x1 - beta_cal * c
+    return mag - mag_cal + alpha_cal * x1 - beta_cal * c


-def e2_distmod_SN(e2_mB, e2_x1, e2_c, alpha_cal, beta_cal, e_mu_intrinsic):
+def e2_distmod_SN(e2_mag, e2_x1, e2_c, alpha_cal, beta_cal, e_mu_intrinsic):
     """Squared error on the distance modulus of a SALT2 SN Ia."""
-    return (e2_mB + alpha_cal**2 * e2_x1 + beta_cal**2 * e2_c
+    return (e2_mag + alpha_cal**2 * e2_x1 + beta_cal**2 * e2_c
             + e_mu_intrinsic**2)
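For intuition, the Tripp-style standardization `mu = mag - mag_cal + alpha_cal * x1 - beta_cal * c` with hypothetical numbers:

```python
# All values hypothetical; mag_cal plays the role of the absolute magnitude.
mag, x1, c = 14.1, 0.5, -0.05
mag_cal, alpha_cal, beta_cal = -19.3, 0.14, 3.1
mu = mag - mag_cal + alpha_cal * x1 - beta_cal * c
print(mu)  # 33.625
```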
@@ -565,8 +488,7 @@ def sample_SN(e_mu_min, e_mu_max, mag_cal_mean, mag_cal_std, alpha_cal_mean,
     alpha_cal = sample(
         f"alpha_cal_{name}", Normal(alpha_cal_mean, alpha_cal_std))
     beta_cal = sample(f"beta_cal_{name}", Normal(beta_cal_mean, beta_cal_std))
-    alpha = sample(f"alpha_{name}",
-                   Uniform(alpha_min, alpha_max)) if sample_alpha else 1.0
+    alpha = sample_alpha_bias(name, alpha_min, alpha_max, sample_alpha)

     return {"e_mu": e_mu,
             "mag_cal": mag_cal,
@@ -580,29 +502,74 @@ def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std, alpha_min,
 # Tully-Fisher parameters sampling #
 ###############################################################################


-def distmod_TFR(mag, eta, a, b):
+def distmod_TFR(mag, eta, a, b, c):
     """Distance modulus of a TFR calibration."""
-    return mag - (a + b * eta)
+    return mag - (a + b * eta + c * eta**2)


-def e2_distmod_TFR(e2_mag, e2_eta, b, e_mu_intrinsic):
-    """Squared error on the TFR distance modulus."""
-    return e2_mag + b**2 * e2_eta + e_mu_intrinsic**2
+def e2_distmod_TFR(e2_mag, e2_eta, eta, b, c, e_mu_intrinsic):
+    """
+    Squared error on the TFR distance modulus with linearly propagated
+    magnitude and linewidth uncertainties.
+    """
+    return e2_mag + (b + 2 * c * eta)**2 * e2_eta + e_mu_intrinsic**2


-def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std, alpha_min,
-               alpha_max, sample_alpha, name):
+def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std,
+               c_mean, c_std, alpha_min, alpha_max, sample_alpha,
+               sample_curvature, a_dipole_mean, a_dipole_std, sample_a_dipole,
+               name):
     """Sample Tully-Fisher calibration parameters."""
     e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
     a = sample(f"a_{name}", Normal(a_mean, a_std))
+
+    if sample_a_dipole:
+        ax, ay, az = sample(f"a_dipole_{name}", Normal(0, 5).expand([3]))
+    else:
+        ax, ay, az = 0.0, 0.0, 0.0
+
     b = sample(f"b_{name}", Normal(b_mean, b_std))
-    alpha = sample(f"alpha_{name}",
-                   Uniform(alpha_min, alpha_max)) if sample_alpha else 1.0
+
+    if sample_curvature:
+        c = sample(f"c_{name}", Normal(c_mean, c_std))
+    else:
+        c = 0.0
+
+    alpha = sample_alpha_bias(name, alpha_min, alpha_max, sample_alpha)

     return {"e_mu": e_mu,
             "a": a,
+            "ax": ax, "ay": ay, "az": az,
             "b": b,
-            "alpha": alpha
+            "c": c,
+            "alpha": alpha,
+            "sample_a_dipole": sample_a_dipole,
             }
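The curvature term changes the linearly propagated linewidth error from `b**2 * e2_eta` to `(b + 2 * c * eta)**2 * e2_eta`; a Monte Carlo sanity check of that propagation (all values hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
a, b, c = -21.0, -6.0, 0.5
eta0, e_eta, e_mag = 0.1, 0.02, 0.05
mag = rng.normal(12.0, e_mag, 10**6)
eta = rng.normal(eta0, e_eta, 10**6)
mu = mag - (a + b * eta + c * eta**2)
# Linear propagation: d(mu)/d(eta) = -(b + 2 * c * eta) at eta = eta0.
var_linear = e_mag**2 + (b + 2 * c * eta0)**2 * e_eta**2
print(mu.var(), var_linear)  # agree to first order in e_eta
```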
+###############################################################################
+# Simple calibration parameters sampling #
+###############################################################################
+
+
+def sample_simple(e_mu_min, e_mu_max, dmu_min, dmu_max, alpha_min, alpha_max,
+                  dmu_dipole_mean, dmu_dipole_std, sample_alpha,
+                  sample_dmu_dipole, name):
+    """Sample simple calibration parameters."""
+    e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
+    dmu = sample(f"dmu_{name}", Uniform(dmu_min, dmu_max))
+    alpha = sample_alpha_bias(name, alpha_min, alpha_max, sample_alpha)
+
+    if sample_dmu_dipole:
+        dmux, dmuy, dmuz = sample(
+            f"dmu_dipole_{name}",
+            Normal(dmu_dipole_mean, dmu_dipole_std).expand([3]))
+    else:
+        dmux, dmuy, dmuz = 0.0, 0.0, 0.0
+
+    return {"e_mu": e_mu,
+            "dmu": dmu,
+            "dmux": dmux, "dmuy": dmuy, "dmuz": dmuz,
+            "alpha": alpha,
+            "sample_dmu_dipole": sample_dmu_dipole,
+            }
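Both the TFR zero-point dipole and the `dmu` dipole above are projected onto each line of sight in the same way as `Vext`; a self-contained numpy sketch of that projection (the helper name and inputs are hypothetical, mirroring the jax `project_Vext`):

```python
import numpy as np

def project_dipole(vx, vy, vz, ra, dec):
    """Project a Cartesian dipole onto the unit vector at (ra, dec), radians."""
    cos_dec = np.cos(dec)
    return (vx * np.cos(ra) * cos_dec + vy * np.sin(ra) * cos_dec
            + vz * np.sin(dec))

# A pure z-dipole shifts the zero-point by its full amplitude at the pole.
print(project_dipole(0.0, 0.0, 1.0, ra=0.0, dec=np.pi / 2))  # 1.0
```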
 ###############################################################################
@@ -612,19 +579,24 @@ def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std, alpha_min,
 def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min,
                        beta_max, sigma_v_min, sigma_v_max, sample_Vmono,
-                       sample_beta, sample_sigma_v_ext):
+                       sample_beta):
     """Sample the flow calibration."""
     Vext = sample("Vext", Uniform(Vext_min, Vext_max).expand([3]))
     sigma_v = sample("sigma_v", Uniform(sigma_v_min, sigma_v_max))

-    beta = sample("beta", Uniform(beta_min, beta_max)) if sample_beta else 1.0  # noqa
-    Vmono = sample("Vmono", Uniform(Vmono_min, Vmono_max)) if sample_Vmono else 0.0  # noqa
-    sigma_v_ext = sample("sigma_v_ext", Uniform(sigma_v_min, sigma_v_max)) if sample_sigma_v_ext else sigma_v  # noqa
+    if sample_beta:
+        beta = sample("beta", Uniform(beta_min, beta_max))
+    else:
+        beta = 1.0
+
+    if sample_Vmono:
+        Vmono = sample("Vmono", Uniform(Vmono_min, Vmono_max))
+    else:
+        Vmono = 0.0

     return {"Vext": Vext,
             "Vmono": Vmono,
             "sigma_v": sigma_v,
-            "sigma_v_ext": sigma_v_ext,
             "beta": beta}
@@ -633,21 +605,11 @@ def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min,
 ###############################################################################


-def find_extrap_mask(rmax, rdist):
-    """
-    Make a mask of shape `(nsim, ngal, nrdist)` of which velocity field values
-    are extrapolated.
-    """
-    nsim, ngal = rmax.shape
-
-    extrap_mask = np.zeros((nsim, ngal, len(rdist)), dtype=bool)
-    extrap_weights = np.ones((nsim, ngal, len(rdist)))
-    for i in range(nsim):
-        for j in range(ngal):
-            k = np.searchsorted(rdist, rmax[i, j])
-            extrap_mask[i, j, k:] = True
-            extrap_weights[i, j, k:] = rmax[i, j] / rdist[k:]
-
-    return extrap_mask, extrap_weights
+def sample_gaussian_hyperprior(param, name, xmin, xmax):
+    """Sample MNR Gaussian hyperprior mean and standard deviation."""
+    mean = sample(f"{param}_mean_{name}", Uniform(xmin, xmax))
+    std = sample(f"{param}_std_{name}", Uniform(0.0, xmax - xmin))
+    return mean, std


 class PV_LogLikelihood(BaseFlowValidationModel):
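A toy version of the MNR-style hierarchy this helper feeds — population hyperpriors, per-object latent truths inside a `plate`, and a Gaussian measurement layer (model name, bounds, and data are hypothetical):

```python
import jax.random as jr
import numpy as np
from numpyro import plate, sample
from numpyro.distributions import Normal, Uniform
from numpyro.infer import MCMC, NUTS

def toy_mnr(mag_obs, e_mag):
    # Hyperpriors on the latent-magnitude population.
    mag_mean = sample("mag_mean", Uniform(8.0, 16.0))
    mag_std = sample("mag_std", Uniform(0.0, 8.0))
    # Per-object truths, then the Gaussian measurement of each.
    with plate("objects", len(mag_obs)):
        mag_true = sample("mag_true", Normal(mag_mean, mag_std))
        sample("obs", Normal(mag_true, e_mag), obs=mag_obs)

mag_obs = np.random.default_rng(0).normal(12.0, 1.0, 100)
MCMC(NUTS(toy_mnr), num_warmup=200, num_samples=200).run(
    jr.PRNGKey(0), mag_obs, 0.05)
```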
@@ -671,14 +633,21 @@ class PV_LogLikelihood(BaseFlowValidationModel):
         Errors on the observed redshifts.
     calibration_params: dict
         Calibration parameters of each object.
+    maxmag_selection : float
+        Maximum magnitude selection if strict threshold.
     r_xrange : 1-dimensional array
         Radial distances where the field was interpolated for each object.
     Omega_m : float
         Matter density parameter.
+    kind : str
+        Catalogue kind, either "TFR", "SN", or "simple".
+    name : str
+        Name of the catalogue.
     """

     def __init__(self, los_density, los_velocity, rmax, RA, dec, z_obs,
-                 e_zobs, calibration_params, r_xrange, Omega_m, kind, name):
+                 e_zobs, calibration_params, maxmag_selection, r_xrange,
+                 Omega_m, kind, name):
         if e_zobs is not None:
             e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2)
         else:
@@ -699,68 +668,172 @@ class PV_LogLikelihood(BaseFlowValidationModel):
         self.name = name
         self.Omega_m = Omega_m
         self.norm = - self.ndata * jnp.log(self.num_sims)
+        self.maxmag_selection = maxmag_selection

-        extrap_mask, extrap_weights = find_extrap_mask(rmax, r_xrange)
-        self.extrap_mask = jnp.asarray(extrap_mask)
-        self.extrap_weights = jnp.asarray(extrap_weights)
+        if kind == "TFR":
+            self.mag_min, self.mag_max = jnp.min(self.mag), jnp.max(self.mag)
+            eta_mu = jnp.mean(self.eta)
+            fprint(f"setting the linewidth mean to 0 instead of {eta_mu:.3f}.")
+            self.eta -= eta_mu
+            self.eta_min, self.eta_max = jnp.min(self.eta), jnp.max(self.eta)
+        elif kind == "SN":
+            self.mag_min, self.mag_max = jnp.min(self.mag), jnp.max(self.mag)
+            self.x1_min, self.x1_max = jnp.min(self.x1), jnp.max(self.x1)
+            self.c_min, self.c_max = jnp.min(self.c), jnp.max(self.c)
+        elif kind == "simple":
+            self.mu_min, self.mu_max = jnp.min(self.mu), jnp.max(self.mu)
+        else:
+            raise RuntimeError("Support must be added for other kinds.")
+
+        if maxmag_selection is not None and self.maxmag_selection > self.mag_max:  # noqa
+            raise ValueError("The maximum magnitude cannot be larger than the selection threshold.")  # noqa

     def __call__(self, field_calibration_params, distmod_params,
-                 sample_sigma_v_ext):
-        """PV validation model log-likelihood."""
-        # Turn e2_cz to be of shape (nsims, ndata, nxrange) and apply
-        # sigma_v_ext where applicable
-        sigma_v = field_calibration_params["sigma_v"]
-        sigma_v_ext = field_calibration_params["sigma_v_ext"]
-        e2_cz = jnp.full_like(self.extrap_mask, sigma_v**2, dtype=jnp.float32)
-        if sample_sigma_v_ext:
-            e2_cz = e2_cz.at[self.extrap_mask].set(sigma_v_ext**2)
-
-        # Now add the observational errors
-        e2_cz += self.e2_cz_obs[None, :, None]
+                 inference_method):
+        if inference_method not in ["mike", "bayes"]:
+            raise ValueError(f"Unknown method: `{inference_method}`.")
+
+        ll0 = 0.0
+        sigma_v = field_calibration_params["sigma_v"]
+        e2_cz = self.e2_cz_obs + sigma_v**2

         Vext = field_calibration_params["Vext"]
         Vmono = field_calibration_params["Vmono"]
         Vext_rad = project_Vext(Vext[0], Vext[1], Vext[2], self.RA, self.dec)

         e_mu = distmod_params["e_mu"]
-        alpha = distmod_params["alpha"]
         if self.kind == "SN":
             mag_cal = distmod_params["mag_cal"]
             alpha_cal = distmod_params["alpha_cal"]
             beta_cal = distmod_params["beta_cal"]
+
+            if inference_method == "bayes":
+                mag_mean, mag_std = sample_gaussian_hyperprior(
+                    "mag", self.name, self.mag_min, self.mag_max)
+                x1_mean, x1_std = sample_gaussian_hyperprior(
+                    "x1", self.name, self.x1_min, self.x1_max)
+                c_mean, c_std = sample_gaussian_hyperprior(
+                    "c", self.name, self.c_min, self.c_max)
+
+                # NOTE: the true variables are currently uncorrelated.
+                with plate("true_SN", self.ndata):
+                    mag_true = sample(
+                        f"mag_true_{self.name}", Normal(mag_mean, mag_std))
+                    x1_true = sample(
+                        f"x1_true_{self.name}", Normal(x1_mean, x1_std))
+                    c_true = sample(
+                        f"c_true_{self.name}", Normal(c_mean, c_std))
+
+                # Log-likelihood of the observed magnitudes.
+                if self.maxmag_selection is None:
+                    ll0 += jnp.sum(normal_logpdf(
+                        mag_true, self.mag, self.e_mag))
+                else:
+                    raise NotImplementedError("Maxmag selection not implemented.")  # noqa
+
+                # Log-likelihood of the observed x1 and c.
+                ll0 += jnp.sum(normal_logpdf(x1_true, self.x1, self.e_x1))
+                ll0 += jnp.sum(normal_logpdf(c_true, self.c, self.e_c))
+
+                e2_mu = jnp.ones_like(mag_true) * e_mu**2
+            else:
+                mag_true = self.mag
+                x1_true = self.x1
+                c_true = self.c
+                e2_mu = e2_distmod_SN(
+                    self.e2_mag, self.e2_x1, self.e2_c, alpha_cal, beta_cal,
+                    e_mu)
+
             mu = distmod_SN(
-                self.mB, self.x1, self.c, mag_cal, alpha_cal, beta_cal)
-            squared_e_mu = e2_distmod_SN(
-                self.e2_mB, self.e2_x1, self.e2_c, alpha_cal, beta_cal, e_mu)
+                mag_true, x1_true, c_true, mag_cal, alpha_cal, beta_cal)
         elif self.kind == "TFR":
             a = distmod_params["a"]
             b = distmod_params["b"]
-            mu = distmod_TFR(self.mag, self.eta, a, b)
-            squared_e_mu = e2_distmod_TFR(self.e2_mag, self.e2_eta, b, e_mu)
+            c = distmod_params["c"]
+
+            if distmod_params["sample_a_dipole"]:
+                ax, ay, az = (distmod_params[k] for k in ["ax", "ay", "az"])
+                a = a + project_Vext(ax, ay, az, self.RA, self.dec)
+
+            if inference_method == "bayes":
+                # Sample the true TFR parameters.
+                mag_mean, mag_std = sample_gaussian_hyperprior(
+                    "mag", self.name, self.mag_min, self.mag_max)
+                eta_mean, eta_std = sample_gaussian_hyperprior(
+                    "eta", self.name, self.eta_min, self.eta_max)
+                corr_mag_eta = sample("corr_mag_eta", Uniform(-1, 1))
+
+                loc = jnp.array([mag_mean, eta_mean])
+                cov = jnp.array(
+                    [[mag_std**2, corr_mag_eta * mag_std * eta_std],
+                     [corr_mag_eta * mag_std * eta_std, eta_std**2]])
+
+                with plate("true_TFR", self.ndata):
+                    x_true = sample("x_TFR", MultivariateNormal(loc, cov))
+
+                mag_true, eta_true = x_true[..., 0], x_true[..., 1]
+
+                # Log-likelihood of the observed magnitudes.
+                if self.maxmag_selection is None:
+                    ll0 += jnp.sum(normal_logpdf(
+                        self.mag, mag_true, self.e_mag))
+                else:
+                    ll0 += jnp.sum(upper_truncated_normal_logpdf(
+                        self.mag, mag_true, self.e_mag, self.maxmag_selection))
+
+                # Log-likelihood of the observed linewidths.
+                ll0 += jnp.sum(normal_logpdf(eta_true, self.eta, self.e_eta))
+
+                e2_mu = jnp.ones_like(mag_true) * e_mu**2
+            else:
+                eta_true = self.eta
+                mag_true = self.mag
+                e2_mu = e2_distmod_TFR(
+                    self.e2_mag, self.e2_eta, eta_true, b, c, e_mu)
+
+            mu = distmod_TFR(mag_true, eta_true, a, b, c)
+        elif self.kind == "simple":
+            dmu = distmod_params["dmu"]
+
+            if distmod_params["sample_dmu_dipole"]:
+                dmux, dmuy, dmuz = (
+                    distmod_params[k] for k in ["dmux", "dmuy", "dmuz"])
+                dmu = dmu + project_Vext(dmux, dmuy, dmuz, self.RA, self.dec)
+
+            if inference_method == "bayes":
+                raise NotImplementedError("Bayes for simple not implemented.")
+            else:
+                mu_true = self.mu
+                e2_mu = e_mu**2 + self.e2_mu
+
+            mu = mu_true + dmu
         else:
             raise ValueError(f"Unknown kind: `{self.kind}`.")

         # Calculate p(r) (Malmquist bias). Shape is (ndata, nxrange)
-        ptilde = jnp.transpose(vmap(calculate_ptilde_wo_bias, in_axes=(0, None, None, 0))(self.mu_xrange, mu, squared_e_mu, self.r2_xrange))  # noqa
+        ptilde = ptilde_wo_bias(
+            self.mu_xrange[None, :], mu[:, None], e2_mu[:, None],
+            self.r2_xrange[None, :])

         # Inhomogeneous Malmquist bias. Shape is (n_sims, ndata, nxrange)
-        ptilde = self.los_density**alpha * ptilde
+        alpha = distmod_params["alpha"]
+        ptilde = ptilde[None, ...] * self.los_density**alpha

         # Normalization of p(r). Shape is (n_sims, ndata)
         pnorm = simpson(ptilde, dx=self.dr, axis=-1)

         # Calculate z_obs at each distance. Shape is (n_sims, ndata, nxrange)
-        # The weights are related to the extrapolation of the velocity field.
         vrad = field_calibration_params["beta"] * self.los_velocity
-        vrad += (Vext_rad[None, :, None] + Vmono) * self.extrap_weights
-        zobs = (1 + self.z_xrange[None, None, :]) * (1 + vrad / SPEED_OF_LIGHT) - 1  # noqa
+        vrad += (Vext_rad[None, :, None] + Vmono)
+        zobs = (1 + self.z_xrange[None, None, :]) * (1 + vrad / SPEED_OF_LIGHT)
+        zobs -= 1.

-        ptilde *= calculate_likelihood_zobs(self.z_obs, zobs, e2_cz)
+        ptilde *= likelihood_zobs(
+            self.z_obs[None, :, None], zobs, e2_cz[None, :, None])

         ll = jnp.log(simpson(ptilde, dx=self.dr, axis=-1)) - jnp.log(pnorm)

-        return jnp.sum(logsumexp(ll, axis=0)) + self.norm
+        return ll0 + jnp.sum(logsumexp(ll, axis=0)) + self.norm
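The homogeneous part of the radial integral above is an `r**2`-weighted Gaussian in distance modulus, normalized with Simpson's rule; its structure in plain numpy/scipy (hypothetical grid and measurement; the density weighting by `delta**alpha` enters multiplicatively as in the model):

```python
import numpy as np
from scipy.integrate import simpson

r = np.linspace(1.0, 300.0, 1000)   # LOS comoving grid, Mpc / h
mu_r = 5 * np.log10(r) + 25         # distance modulus along the LOS (z << 1)
mu_obs, e_mu = 34.0, 0.3
ptilde = r**2 * np.exp(-0.5 * (mu_r - mu_obs)**2 / e_mu**2)
p_r = ptilde / simpson(ptilde, x=r)  # normalized p(r), homogeneous case
```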
 def PV_validation_model(models, distmod_hyperparams_per_model,
-                        field_calibration_hyperparams):
+                        field_calibration_hyperparams, inference_method):
     """
     Peculiar velocity validation NumPyro model.
@@ -772,27 +845,29 @@ def PV_validation_model(models, distmod_hyperparams_per_model,
         Distance modulus hyperparameters for each model/catalogue.
     field_calibration_hyperparams : dict
         Field calibration hyperparameters.
+    inference_method : str
+        Either `mike` or `bayes`.
     """
     field_calibration_params = sample_calibration(
         **field_calibration_hyperparams)
-    sample_sigma_v_ext = field_calibration_hyperparams["sample_sigma_v_ext"]

     ll = 0.0
     for n in range(len(models)):
         model = models[n]
+        name = model.name
         distmod_hyperparams = distmod_hyperparams_per_model[n]

         if model.kind == "TFR":
-            distmod_params = sample_TFR(**distmod_hyperparams, name=model.name)
+            distmod_params = sample_TFR(**distmod_hyperparams, name=name)
         elif model.kind == "SN":
-            distmod_params = sample_SN(**distmod_hyperparams, name=model.name)
+            distmod_params = sample_SN(**distmod_hyperparams, name=name)
+        elif model.kind == "simple":
+            distmod_params = sample_simple(**distmod_hyperparams, name=name)
         else:
             raise ValueError(f"Unknown kind: `{model.kind}`.")

-        ll += model(
-            field_calibration_params, distmod_params, sample_sigma_v_ext)
+        ll += model(field_calibration_params, distmod_params, inference_method)

-    deterministic("ll_values", ll)
     factor("ll", ll)
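In miniature, `factor` is how the custom per-catalogue log-likelihood enters the NumPyro joint density (toy model, hypothetical data):

```python
import jax.numpy as jnp
from numpyro import factor, sample
from numpyro.distributions import Uniform

def toy_model(data):
    theta = sample("theta", Uniform(-5.0, 5.0))
    ll = -0.5 * jnp.sum((data - theta)**2)  # custom log-likelihood
    factor("ll", ll)                        # added to the model log-density
```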
@@ -801,7 +876,7 @@ def PV_validation_model(models, distmod_hyperparams_per_model,
 ###############################################################################


-def get_model(loader, zcmb_max=None, verbose=True):
+def get_model(loader, zcmb_min=0.0, zcmb_max=None, maxmag_selection=None):
     """
     Get a model and extract the relevant data from the loader.
@@ -809,10 +884,12 @@ def get_model(loader, zcmb_max=None, verbose=True):
     ----------
     loader : DataLoader
         DataLoader instance.
+    zcmb_min : float, optional
+        Minimum observed redshift in the CMB frame to include.
     zcmb_max : float, optional
         Maximum observed redshift in the CMB frame to include.
-    verbose : bool, optional
-        Verbosity flag.
+    maxmag_selection : float, optional
+        Maximum magnitude selection threshold.

     Returns
     -------
@@ -825,20 +902,24 @@ def get_model(loader, zcmb_max=None, verbose=True):
     rmax = loader.rmax
     kind = loader._catname

+    if maxmag_selection is not None and kind != "2MTF":
+        raise ValueError("Threshold magnitude selection implemented only for 2MTF.")  # noqa
+
     if kind in ["LOSS", "Foundation"]:
         keys = ["RA", "DEC", "z_CMB", "mB", "x1", "c", "e_mB", "e_x1", "e_c"]
-        RA, dec, zCMB, mB, x1, c, e_mB, e_x1, e_c = (loader.cat[k] for k in keys)  # noqa
+        RA, dec, zCMB, mag, x1, c, e_mag, e_x1, e_c = (
+            loader.cat[k] for k in keys)
         e_zCMB = None

-        mask = (zCMB < zcmb_max)
+        mask = (zCMB < zcmb_max) & (zCMB > zcmb_min)

-        calibration_params = {"mB": mB[mask], "x1": x1[mask], "c": c[mask],
-                              "e_mB": e_mB[mask], "e_x1": e_x1[mask],
+        calibration_params = {"mag": mag[mask], "x1": x1[mask], "c": c[mask],
+                              "e_mag": e_mag[mask], "e_x1": e_x1[mask],
                               "e_c": e_c[mask]}
         model = PV_LogLikelihood(
             los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
             RA[mask], dec[mask], zCMB[mask], e_zCMB, calibration_params,
-            loader.rdist, loader._Omega_m, "SN", name=kind)
+            maxmag_selection, loader.rdist, loader._Omega_m, "SN", name=kind)
     elif "Pantheon+" in kind:
         keys = ["RA", "DEC", "zCMB", "mB", "x1", "c", "biasCor_m_b", "mBERR",
                 "x1ERR", "cERR", "biasCorErr_m_b", "zCMB_SN", "zCMB_Group",
@@ -848,7 +929,7 @@ def get_model(loader, zcmb_max=None, verbose=True):
         mB -= bias_corr_mB
         e_mB = np.sqrt(e_mB**2 + e_bias_corr_mB**2)

-        mask = zCMB < zcmb_max
+        mask = (zCMB < zcmb_max) & (zCMB > zcmb_min)
         if kind == "Pantheon+_groups":
             mask &= np.isfinite(zCMB_Group)
@@ -860,29 +941,80 @@ def get_model(loader, zcmb_max=None, verbose=True):
         if kind == "Pantheon+_zSN":
             zCMB = zCMB_SN

-        calibration_params = {"mB": mB[mask], "x1": x1[mask], "c": c[mask],
-                              "e_mB": e_mB[mask], "e_x1": e_x1[mask],
+        calibration_params = {"mag": mB[mask], "x1": x1[mask], "c": c[mask],
+                              "e_mag": e_mB[mask], "e_x1": e_x1[mask],
                               "e_c": e_c[mask]}
         model = PV_LogLikelihood(
             los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
             RA[mask], dec[mask], zCMB[mask], e_zCMB[mask], calibration_params,
-            loader.rdist, loader._Omega_m, "SN", name=kind)
+            maxmag_selection, loader.rdist, loader._Omega_m, "SN", name=kind)
     elif kind in ["SFI_gals", "2MTF", "SFI_gals_masked"]:
         keys = ["RA", "DEC", "z_CMB", "mag", "eta", "e_mag", "e_eta"]
         RA, dec, zCMB, mag, eta, e_mag, e_eta = (loader.cat[k] for k in keys)

-        mask = (zCMB < zcmb_max)
-        if kind == "SFI_gals":
-            mask &= (eta > -0.15) & (eta < 0.2)
-            if verbose:
-                print("Emplyed eta cut for SFI galaxies.", flush=True)
+        mask = (zCMB < zcmb_max) & (zCMB > zcmb_min)

         calibration_params = {"mag": mag[mask], "eta": eta[mask],
                               "e_mag": e_mag[mask], "e_eta": e_eta[mask]}
         model = PV_LogLikelihood(
             los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
             RA[mask], dec[mask], zCMB[mask], None, calibration_params,
-            loader.rdist, loader._Omega_m, "TFR", name=kind)
+            maxmag_selection, loader.rdist, loader._Omega_m, "TFR", name=kind)
+    elif "CF4_TFR_" in kind:
+        # The full name can be e.g. "CF4_TFR_not2MTForSFI_i" or "CF4_TFR_i".
+        band = kind.split("_")[-1]
+        if band not in ['g', 'r', 'i', 'z', 'w1', 'w2']:
+            raise ValueError(f"Band `{band}` not recognized.")
+
+        keys = ["RA", "DEC", "Vcmb", f"{band}", "lgWmxi", "elgWi",
+                "not_matched_to_2MTF_or_SFI", "Qs", "Qw"]
+        RA, dec, z_obs, mag, eta, e_eta, not_matched_to_2MTF_or_SFI, Qs, Qw = (
+            loader.cat[k] for k in keys)
+
+        not_matched_to_2MTF_or_SFI = not_matched_to_2MTF_or_SFI.astype(bool)
+        # NOTE: fiducial uncertainty until we can get the actual values.
+        e_mag = 0.001 * np.ones_like(mag)
+
+        z_obs /= SPEED_OF_LIGHT
+        eta -= 2.5
+
+        fprint("selecting only galaxies with mag > 5 and eta > -0.3.")
+        mask = (mag > 5) & (eta > -0.3)
+        mask &= (z_obs < zcmb_max) & (z_obs > zcmb_min)
+
+        if "not2MTForSFI" in kind:
+            mask &= not_matched_to_2MTF_or_SFI
+        elif "2MTForSFI" in kind:
+            mask &= ~not_matched_to_2MTF_or_SFI
+
+        fprint("employing a quality cut on the galaxies.")
+        if "w" in band:
+            mask &= Qw == 5
+        else:
+            mask &= Qs == 5
+
+        calibration_params = {"mag": mag[mask], "eta": eta[mask],
+                              "e_mag": e_mag[mask], "e_eta": e_eta[mask]}
+        model = PV_LogLikelihood(
+            los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
+            RA[mask], dec[mask], z_obs[mask], None, calibration_params,
+            maxmag_selection, loader.rdist, loader._Omega_m, "TFR", name=kind)
+    elif kind in ["CF4_GroupAll"]:
+        # Note, this for some reason works terribly.
+        keys = ["RA", "DE", "Vcmb", "DMzp", "eDM"]
+        RA, dec, zCMB, mu, e_mu = (loader.cat[k] for k in keys)
+
+        zCMB /= SPEED_OF_LIGHT
+        mask = (zCMB < zcmb_max) & (zCMB > zcmb_min) & np.isfinite(mu)
+
+        # The distance moduli in CF4 are most likely given assuming h = 0.75
+        mu += 5 * np.log10(0.75)
+
+        calibration_params = {"mu": mu[mask], "e_mu": e_mu[mask]}
+        model = PV_LogLikelihood(
+            los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
+            RA[mask], dec[mask], zCMB[mask], None, calibration_params,
+            maxmag_selection, loader.rdist, loader._Omega_m, "simple",
+            name=kind)
     else:
         raise ValueError(f"Catalogue `{kind}` not recognized.")

View file

@@ -469,10 +469,9 @@ def BIC_AIC(samples, log_likelihood, ndata):
     for val in samples.values():
         if val.ndim == 1:
             nparam += 1
-        elif val.ndim == 2:
-            nparam += val.shape[-1]
         else:
-            raise ValueError("Invalid dimensionality of samples to count the number of parameters.")  # noqa
+            # The first dimension is the number of steps.
+            nparam += np.prod(val.shape[1:])

     BIC = nparam * np.log(ndata) - 2 * log_likelihood[kmax]
     AIC = 2 * nparam - 2 * log_likelihood[kmax]
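The new counting rule treats the first axis of each posterior array as the chain/step dimension and counts everything else as parameters; a quick check with hypothetical shapes:

```python
import numpy as np

samples = {"beta": np.zeros(1500),            # scalar: 1 parameter
           "Vext": np.zeros((1500, 3)),       # vector: 3 parameters
           "mag_true": np.zeros((1500, 50))}  # 50 latent magnitudes
nparam = sum(1 if v.ndim == 1 else int(np.prod(v.shape[1:]))
             for v in samples.values())
assert nparam == 1 + 3 + 50
```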

View file

@@ -15,6 +15,7 @@
 MPI script to interpolate the density and velocity fields along the line of
 sight.
 """
+import sys
 from argparse import ArgumentParser
 from datetime import datetime
 from gc import collect
@@ -32,7 +33,8 @@ from mpi4py import MPI
 from numba import jit
 from taskmaster import work_delegation  # noqa

-from utils import get_nsims
+sys.path.append("../")
+from utils import get_nsims  # noqa
 ###############################################################################
 # I/O functions #
@@ -84,8 +86,18 @@ def get_los(catalogue_name, simname, comm):
         with File(fname, 'r') as f:
             RA = f["RA"][:]
             dec = f["DEC"][:]
+    elif catalogue_name == "CF4_GroupAll":
+        fname = "/mnt/extraspace/rstiskalek/catalogs/PV/CF4/CF4_GroupAll.hdf5"  # noqa
+        with File(fname, 'r') as f:
+            RA = f["RA"][:]
+            dec = f["DE"][:]
+    elif catalogue_name == "CF4_TFR":
+        fname = "/mnt/extraspace/rstiskalek/catalogs/PV/CF4/CF4_TF-distances.hdf5"  # noqa
+        with File(fname, 'r') as f:
+            RA = f["RA"][:] * 360 / 24  # Convert to degrees from hours.
+            dec = f["DE"][:]
     else:
-        raise ValueError(f"Unknown field name: `{catalogue_name}`.")
+        raise ValueError(f"Unknown catalogue name: `{catalogue_name}`.")

     if comm.Get_rank() == 0:
         print(f"The dataset contains {len(RA)} objects.")

View file

@@ -10,8 +10,8 @@ MAS="SPH"
 grid=1024

-for simname in "CF4"; do
-    for catalogue in "Foundation"; do
+for simname in "Carrick2015"; do
+    for catalogue in "CF4_TFR"; do
         pythoncm="$env $file --catalogue $catalogue --nsims $nsims --simname $simname --MAS $MAS --grid $grid"
         if [ $on_login -eq 1 ]; then
             echo $pythoncm

View file

@@ -100,6 +100,10 @@ def get_models(get_model_kwargs, verbose=True):
                    "Pantheon+_groups", "Pantheon+_groups_zSN",
                    "Pantheon+_zSN"]:
         fpath = join(folder, "PV_compilation.hdf5")
+    elif "CF4_TFR" in cat:
+        fpath = join(folder, "PV/CF4/CF4_TF-distances.hdf5")
+    elif cat in ["CF4_GroupAll"]:
+        fpath = join(folder, "PV/CF4/CF4_GroupAll.hdf5")
     else:
         raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
@@ -139,11 +143,7 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
     samples = mcmc.get_samples()
     log_posterior = -mcmc.get_extra_fields()["potential_energy"]

-    log_likelihood = samples.pop("ll_values")
-    if log_likelihood is None:
-        raise ValueError("The samples must contain the log likelihood values under the key `ll_values`.")  # noqa
-
-    BIC, AIC = csiborgtools.BIC_AIC(samples, log_likelihood, ndata)
+    BIC, AIC = csiborgtools.BIC_AIC(samples, log_posterior, ndata)
     print(f"{'BIC':<20} {BIC}")
     print(f"{'AIC':<20} {AIC}")
     mcmc.print_summary()
@@ -174,7 +174,6 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
             grp.create_dataset(key, data=value)

         # Write log likelihood and posterior
-        f.create_dataset("log_likelihood", data=log_likelihood)
         f.create_dataset("log_posterior", data=log_posterior)

         # Write goodness of fit
@@ -207,10 +206,9 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
 # Command line interface #
 ###############################################################################


-def get_distmod_hyperparams(catalogue):
+def get_distmod_hyperparams(catalogue, sample_alpha):
     alpha_min = -1.0
     alpha_max = 3.0
-    sample_alpha = True

     if catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]:  # noqa
         return {"e_mu_min": 0.001, "e_mu_max": 1.0,
@@ -220,12 +218,24 @@ def get_distmod_hyperparams(catalogue):
                 "alpha_min": alpha_min, "alpha_max": alpha_max,
                 "sample_alpha": sample_alpha
                 }
-    elif catalogue in ["SFI_gals", "2MTF"]:
+    elif catalogue in ["SFI_gals", "2MTF"] or "CF4_TFR" in catalogue:
         return {"e_mu_min": 0.001, "e_mu_max": 1.0,
                 "a_mean": -21., "a_std": 5.0,
-                "b_mean": -5.95, "b_std": 3.0,
+                "b_mean": -5.95, "b_std": 4.0,
+                "c_mean": 0., "c_std": 20.0,
+                "sample_curvature": False,
+                "a_dipole_mean": 0., "a_dipole_std": 1.0,
+                "sample_a_dipole": True,
                 "alpha_min": alpha_min, "alpha_max": alpha_max,
-                "sample_alpha": sample_alpha
+                "sample_alpha": sample_alpha,
+                }
+    elif catalogue in ["CF4_GroupAll"]:
+        return {"e_mu_min": 0.001, "e_mu_max": 1.0,
+                "dmu_min": -3.0, "dmu_max": 3.0,
+                "dmu_dipole_mean": 0., "dmu_dipole_std": 1.0,
+                "sample_dmu_dipole": True,
+                "alpha_min": alpha_min, "alpha_max": alpha_max,
+                "sample_alpha": sample_alpha,
                 }
     else:
         raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
@@ -241,37 +251,46 @@ if __name__ == "__main__":
     # Fixed user parameters #
     ###########################################################################

-    nsteps = 5000
-    nburn = 1500
+    nsteps = 1500
+    nburn = 1000
+    zcmb_min = 0
     zcmb_max = 0.05
     calculate_evidence = False
     nchains_harmonic = 10
     num_epochs = 30
+    inference_method = "mike"
+    maxmag_selection = None
+    sample_alpha = True
+    sample_beta = True
+    sample_Vmono = False

     if nsteps % nchains_harmonic != 0:
         raise ValueError(
             "The number of steps must be divisible by the number of chains.")

-    main_params = {"nsteps": nsteps, "nburn": nburn, "zcmb_max": zcmb_max,
+    main_params = {"nsteps": nsteps, "nburn": nburn,
+                   "zcmb_min": zcmb_min,
+                   "zcmb_max": zcmb_max,
+                   "maxmag_selection": maxmag_selection,
                    "calculate_evidence": calculate_evidence,
                    "nchains_harmonic": nchains_harmonic,
-                   "num_epochs": num_epochs}
+                   "num_epochs": num_epochs,
+                   "inference_method": inference_method}
     print_variables(main_params.keys(), main_params.values())

     calibration_hyperparams = {"Vext_min": -1000, "Vext_max": 1000,
                                "Vmono_min": -1000, "Vmono_max": 1000,
                                "beta_min": -1.0, "beta_max": 3.0,
                                "sigma_v_min": 1.0, "sigma_v_max": 750.,
-                               "sample_Vmono": False,
-                               "sample_beta": True,
-                               "sample_sigma_v_ext": False,
+                               "sample_Vmono": sample_Vmono,
+                               "sample_beta": sample_beta,
                                }
     print_variables(
         calibration_hyperparams.keys(), calibration_hyperparams.values())

     distmod_hyperparams_per_catalogue = []
     for cat in ARGS.catalogue:
-        x = get_distmod_hyperparams(cat)
+        x = get_distmod_hyperparams(cat, sample_alpha)
         print(f"\n{cat} hyperparameters:")
         print_variables(x.keys(), x.values())
         distmod_hyperparams_per_catalogue.append(x)
@@ -280,12 +299,14 @@ if __name__ == "__main__":
         *distmod_hyperparams_per_catalogue)
     ###########################################################################

-    get_model_kwargs = {"zcmb_max": zcmb_max}
+    get_model_kwargs = {"zcmb_min": zcmb_min, "zcmb_max": zcmb_max,
+                        "maxmag_selection": maxmag_selection}
     models = get_models(get_model_kwargs, )
     model_kwargs = {
         "models": models,
         "field_calibration_hyperparams": calibration_hyperparams,
         "distmod_hyperparams_per_model": distmod_hyperparams_per_catalogue,
+        "inference_method": inference_method,
     }

     model = csiborgtools.flow.PV_validation_model

View file

@@ -39,7 +39,8 @@ fi
 # for simname in "Lilow2024" "CF4" "CF4gp" "csiborg1" "csiborg2_main" "csiborg2X"; do
 for simname in "Carrick2015"; do
-    for catalogue in "LOSS,2MTF,SFI_gals"; do
+    for catalogue in "CF4_GroupAll"; do
+    # for catalogue in "CF4_TFR_i"; do
     # for ksim in 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20; do
     for ksim in "none"; do
         pythoncm="$env $file --catalogue $catalogue --simname $simname --ksim $ksim --ksmooth $ksmooth --ndevice $ndevice --device $device"