Observer velocity script (#120)

* Rename script

* Delete scripts

* Add script

* Edit script

* Add script

* Update nb

* Update plotting

* Update .gitignore

* Update nb

* Update nb

* Add option to keep beta fixed
This commit is contained in:
Richard Stiskalek 2024-03-26 10:42:53 +01:00 committed by GitHub
parent 4093186f9a
commit 27c1f9249b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 361 additions and 723 deletions

View file

@ -951,7 +951,7 @@ class SN_PV_validation_model(BaseFlowValidationModel):
return zobs_mean, zobs_var
def __call__(self, sample_alpha=True, fix_calibration=False):
def __call__(self, sample_alpha=True, sample_beta=True):
"""
The supernova NumPyro PV validation model with SALT2 calibration.
@ -960,38 +960,25 @@ class SN_PV_validation_model(BaseFlowValidationModel):
sample_alpha : bool, optional
Whether to sample the density bias parameter `alpha`, otherwise
it is fixed to 1.
fix_calibration : str, optional
Whether to fix the calibration parameters. If not provided, they
are sampled. If "Foundation" or "LOSS" is provided, the parameters
are fixed to the best inverse parameters for the Foundation or LOSS
catalogues.
sample_beta : bool, optional
Whether to sample the velocity bias parameter `beta`, otherwise
it is fixed to 1.
Returns
-------
None
"""
Vx = numpyro.sample("Vext_x", self._Vext)
Vy = numpyro.sample("Vext_y", self._Vext)
Vz = numpyro.sample("Vext_z", self._Vext)
alpha = numpyro.sample("alpha", self._alpha) if sample_alpha else 1.0
beta = numpyro.sample("beta", self._beta)
beta = numpyro.sample("beta", self._beta) if sample_beta else 1.0
sigma_v = numpyro.sample("sigma_v", self._sigma_v)
if fix_calibration == "Foundation":
# Foundation inverse best parameters
e_mu_intrinsic = 0.064
alpha_cal = 0.135
beta_cal = 2.9
sigma_v = 149
mag_cal = -18.555
elif fix_calibration == "LOSS":
# LOSS inverse best parameters
e_mu_intrinsic = 0.123
alpha_cal = 0.123
beta_cal = 3.52
mag_cal = -18.195
sigma_v = 149
else:
e_mu_intrinsic = numpyro.sample("e_mu_intrinsic", self._e_mu)
mag_cal = numpyro.sample("mag_cal", self._mag_cal)
alpha_cal = numpyro.sample("alpha_cal", self._alpha_cal)
beta_cal = numpyro.sample("beta_cal", self._beta_cal)
e_mu_intrinsic = numpyro.sample("e_mu_intrinsic", self._e_mu)
mag_cal = numpyro.sample("mag_cal", self._mag_cal)
alpha_cal = numpyro.sample("alpha_cal", self._alpha_cal)
beta_cal = numpyro.sample("beta_cal", self._beta_cal)
Vext_rad = project_Vext(Vx, Vy, Vz, self._RA, self._dec)
@ -1168,7 +1155,7 @@ class TF_PV_validation_model(BaseFlowValidationModel):
return zobs_mean, zobs_var
def __call__(self, sample_alpha=True):
def __call__(self, sample_alpha=True, sample_beta=True):
"""
The Tully-Fisher NumPyro PV validation model.
@ -1177,12 +1164,19 @@ class TF_PV_validation_model(BaseFlowValidationModel):
sample_alpha : bool, optional
Whether to sample the density bias parameter `alpha`, otherwise
it is fixed to 1.
sample_beta : bool, optional
Whether to sample the velocity bias parameter `beta`, otherwise
it is fixed to 1.
Returns
-------
None
"""
Vx = numpyro.sample("Vext_x", self._Vext)
Vy = numpyro.sample("Vext_y", self._Vext)
Vz = numpyro.sample("Vext_z", self._Vext)
alpha = numpyro.sample("alpha", self._alpha) if sample_alpha else 1.0
beta = numpyro.sample("beta", self._beta)
beta = numpyro.sample("beta", self._beta) if sample_beta else 1.0
sigma_v = numpyro.sample("sigma_v", self._sigma_v)
e_mu_intrinsic = numpyro.sample("e_mu_intrinsic", self._e_mu)
@ -1291,7 +1285,7 @@ def get_model(loader, zcmb_max=None, verbose=True):
###############################################################################
def sample_prior(model, seed, sample_alpha, as_dict=False):
def sample_prior(model, seed, model_kwargs, as_dict=False):
"""
Sample a single set of parameters from the prior of the model.
@ -1301,8 +1295,8 @@ def sample_prior(model, seed, sample_alpha, as_dict=False):
NumPyro model.
seed : int
Random seed.
sample_alpha : bool
Whether to sample the density bias parameter `alpha`.
model_kwargs : dict
Additional keyword arguments to pass to the model.
as_dict : bool, optional
Whether to return the parameters as a dictionary or a list of
parameters.
@ -1314,7 +1308,7 @@ def sample_prior(model, seed, sample_alpha, as_dict=False):
only a dictionary.
"""
predictive = Predictive(model, num_samples=1)
samples = predictive(PRNGKey(seed), sample_alpha=sample_alpha)
samples = predictive(PRNGKey(seed), **model_kwargs)
if as_dict:
return samples
@ -1327,7 +1321,7 @@ def sample_prior(model, seed, sample_alpha, as_dict=False):
return x, keys
def make_loss(model, keys, sample_alpha=True, to_jit=True):
def make_loss(model, keys, model_kwargs, to_jit=True):
"""
Generate a loss function for the NumPyro model, that is the negative
log-likelihood. Note that this loss function cannot be automatically
@ -1339,8 +1333,8 @@ def make_loss(model, keys, sample_alpha=True, to_jit=True):
NumPyro model.
keys : list
List of parameter names.
sample_alpha : bool, optional
Whether to sample the density bias parameter `alpha`.
model_kwargs : dict
Additional keyword arguments to pass to the model.
to_jit : bool, optional
Whether to JIT the loss function.
@ -1353,8 +1347,7 @@ def make_loss(model, keys, sample_alpha=True, to_jit=True):
def f(x):
samples = {key: x[i] for i, key in enumerate(keys)}
loss = -util.log_likelihood(
model, samples, sample_alpha=sample_alpha)["ll"]
loss = -util.log_likelihood(model, samples, **model_kwargs)["ll"]
loss += cond(samples["sigma_v"] > 0, lambda: 0., lambda: jnp.inf)
loss += cond(samples["e_mu_intrinsic"] > 0, lambda: 0., lambda: jnp.inf) # noqa