mirror of
https://github.com/Richard-Sti/csiborgtools_public.git
synced 2025-05-13 06:01:13 +00:00
Observer velocity script (#120)
* Rename script * Delete scripts * Add script * Edit script * Add script * Update nb * Update plotting * Update .gitignore * Update nb * Update nb * Add option to keep beta fixed
This commit is contained in:
parent
4093186f9a
commit
27c1f9249b
10 changed files with 361 additions and 723 deletions
|
@ -951,7 +951,7 @@ class SN_PV_validation_model(BaseFlowValidationModel):
|
|||
|
||||
return zobs_mean, zobs_var
|
||||
|
||||
def __call__(self, sample_alpha=True, fix_calibration=False):
|
||||
def __call__(self, sample_alpha=True, sample_beta=True):
|
||||
"""
|
||||
The supernova NumPyro PV validation model with SALT2 calibration.
|
||||
|
||||
|
@ -960,38 +960,25 @@ class SN_PV_validation_model(BaseFlowValidationModel):
|
|||
sample_alpha : bool, optional
|
||||
Whether to sample the density bias parameter `alpha`, otherwise
|
||||
it is fixed to 1.
|
||||
fix_calibration : str, optional
|
||||
Whether to fix the calibration parameters. If not provided, they
|
||||
are sampled. If "Foundation" or "LOSS" is provided, the parameters
|
||||
are fixed to the best inverse parameters for the Foundation or LOSS
|
||||
catalogues.
|
||||
sample_beta : bool, optional
|
||||
Whether to sample the velocity bias parameter `beta`, otherwise
|
||||
it is fixed to 1.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
Vx = numpyro.sample("Vext_x", self._Vext)
|
||||
Vy = numpyro.sample("Vext_y", self._Vext)
|
||||
Vz = numpyro.sample("Vext_z", self._Vext)
|
||||
alpha = numpyro.sample("alpha", self._alpha) if sample_alpha else 1.0
|
||||
beta = numpyro.sample("beta", self._beta)
|
||||
beta = numpyro.sample("beta", self._beta) if sample_beta else 1.0
|
||||
sigma_v = numpyro.sample("sigma_v", self._sigma_v)
|
||||
|
||||
if fix_calibration == "Foundation":
|
||||
# Foundation inverse best parameters
|
||||
e_mu_intrinsic = 0.064
|
||||
alpha_cal = 0.135
|
||||
beta_cal = 2.9
|
||||
sigma_v = 149
|
||||
mag_cal = -18.555
|
||||
elif fix_calibration == "LOSS":
|
||||
# LOSS inverse best parameters
|
||||
e_mu_intrinsic = 0.123
|
||||
alpha_cal = 0.123
|
||||
beta_cal = 3.52
|
||||
mag_cal = -18.195
|
||||
sigma_v = 149
|
||||
else:
|
||||
e_mu_intrinsic = numpyro.sample("e_mu_intrinsic", self._e_mu)
|
||||
mag_cal = numpyro.sample("mag_cal", self._mag_cal)
|
||||
alpha_cal = numpyro.sample("alpha_cal", self._alpha_cal)
|
||||
beta_cal = numpyro.sample("beta_cal", self._beta_cal)
|
||||
e_mu_intrinsic = numpyro.sample("e_mu_intrinsic", self._e_mu)
|
||||
mag_cal = numpyro.sample("mag_cal", self._mag_cal)
|
||||
alpha_cal = numpyro.sample("alpha_cal", self._alpha_cal)
|
||||
beta_cal = numpyro.sample("beta_cal", self._beta_cal)
|
||||
|
||||
Vext_rad = project_Vext(Vx, Vy, Vz, self._RA, self._dec)
|
||||
|
||||
|
@ -1168,7 +1155,7 @@ class TF_PV_validation_model(BaseFlowValidationModel):
|
|||
|
||||
return zobs_mean, zobs_var
|
||||
|
||||
def __call__(self, sample_alpha=True):
|
||||
def __call__(self, sample_alpha=True, sample_beta=True):
|
||||
"""
|
||||
The Tully-Fisher NumPyro PV validation model.
|
||||
|
||||
|
@ -1177,12 +1164,19 @@ class TF_PV_validation_model(BaseFlowValidationModel):
|
|||
sample_alpha : bool, optional
|
||||
Whether to sample the density bias parameter `alpha`, otherwise
|
||||
it is fixed to 1.
|
||||
sample_beta : bool, optional
|
||||
Whether to sample the velocity bias parameter `beta`, otherwise
|
||||
it is fixed to 1.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
Vx = numpyro.sample("Vext_x", self._Vext)
|
||||
Vy = numpyro.sample("Vext_y", self._Vext)
|
||||
Vz = numpyro.sample("Vext_z", self._Vext)
|
||||
alpha = numpyro.sample("alpha", self._alpha) if sample_alpha else 1.0
|
||||
beta = numpyro.sample("beta", self._beta)
|
||||
beta = numpyro.sample("beta", self._beta) if sample_beta else 1.0
|
||||
sigma_v = numpyro.sample("sigma_v", self._sigma_v)
|
||||
|
||||
e_mu_intrinsic = numpyro.sample("e_mu_intrinsic", self._e_mu)
|
||||
|
@ -1291,7 +1285,7 @@ def get_model(loader, zcmb_max=None, verbose=True):
|
|||
###############################################################################
|
||||
|
||||
|
||||
def sample_prior(model, seed, sample_alpha, as_dict=False):
|
||||
def sample_prior(model, seed, model_kwargs, as_dict=False):
|
||||
"""
|
||||
Sample a single set of parameters from the prior of the model.
|
||||
|
||||
|
@ -1301,8 +1295,8 @@ def sample_prior(model, seed, sample_alpha, as_dict=False):
|
|||
NumPyro model.
|
||||
seed : int
|
||||
Random seed.
|
||||
sample_alpha : bool
|
||||
Whether to sample the density bias parameter `alpha`.
|
||||
model_kwargs : dict
|
||||
Additional keyword arguments to pass to the model.
|
||||
as_dict : bool, optional
|
||||
Whether to return the parameters as a dictionary or a list of
|
||||
parameters.
|
||||
|
@ -1314,7 +1308,7 @@ def sample_prior(model, seed, sample_alpha, as_dict=False):
|
|||
only a dictionary.
|
||||
"""
|
||||
predictive = Predictive(model, num_samples=1)
|
||||
samples = predictive(PRNGKey(seed), sample_alpha=sample_alpha)
|
||||
samples = predictive(PRNGKey(seed), **model_kwargs)
|
||||
|
||||
if as_dict:
|
||||
return samples
|
||||
|
@ -1327,7 +1321,7 @@ def sample_prior(model, seed, sample_alpha, as_dict=False):
|
|||
return x, keys
|
||||
|
||||
|
||||
def make_loss(model, keys, sample_alpha=True, to_jit=True):
|
||||
def make_loss(model, keys, model_kwargs, to_jit=True):
|
||||
"""
|
||||
Generate a loss function for the NumPyro model, that is the negative
|
||||
log-likelihood. Note that this loss function cannot be automatically
|
||||
|
@ -1339,8 +1333,8 @@ def make_loss(model, keys, sample_alpha=True, to_jit=True):
|
|||
NumPyro model.
|
||||
keys : list
|
||||
List of parameter names.
|
||||
sample_alpha : bool, optional
|
||||
Whether to sample the density bias parameter `alpha`.
|
||||
model_kwargs : dict
|
||||
Additional keyword arguments to pass to the model.
|
||||
to_jit : bool, optional
|
||||
Whether to JIT the loss function.
|
||||
|
||||
|
@ -1353,8 +1347,7 @@ def make_loss(model, keys, sample_alpha=True, to_jit=True):
|
|||
def f(x):
|
||||
samples = {key: x[i] for i, key in enumerate(keys)}
|
||||
|
||||
loss = -util.log_likelihood(
|
||||
model, samples, sample_alpha=sample_alpha)["ll"]
|
||||
loss = -util.log_likelihood(model, samples, **model_kwargs)["ll"]
|
||||
|
||||
loss += cond(samples["sigma_v"] > 0, lambda: 0., lambda: jnp.inf)
|
||||
loss += cond(samples["e_mu_intrinsic"] > 0, lambda: 0., lambda: jnp.inf) # noqa
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue