Pass marg arg

This commit is contained in:
rstiskalek 2024-10-07 00:56:59 +01:00
parent c81d39ee81
commit 05123ec868

View file

@ -354,8 +354,8 @@ def mask_fields(density, velocity, mask, return_none):
def get_model(loader, zcmb_min=None, zcmb_max=None, mag_selection=None,
absolute_calibration=None, calibration_fpath=None,
void_kwargs=None):
wo_num_dist_marginalisation=False, absolute_calibration=None,
calibration_fpath=None, void_kwargs=None):
"""
Get a model and extract the relevant data from the loader.
@ -369,9 +369,14 @@ def get_model(loader, zcmb_min=None, zcmb_max=None, mag_selection=None,
Maximum observed redshift in the CMB frame to include.
mag_selection : dict, optional
Magnitude selection parameters.
wo_num_dist_marginalisation : bool, optional
Whether to directly sample the distance without numerical
marginalisation. in which case the tracers can be coupled by a
covariance matrix. By default `False`.
add_absolute_calibration : bool, optional
Whether to add an absolute calibration for CF4 TFRs.
calibration_fpath : str, optional
Path to the file containing the absolute calibration of CF4 TFR.
Returns
-------
@ -418,7 +423,8 @@ def get_model(loader, zcmb_min=None, zcmb_max=None, mag_selection=None,
los_overdensity, los_velocity,
RA[mask], dec[mask], zCMB[mask], e_zCMB, calibration_params,
None, mag_selection, loader.rdist, loader._Omega_m, "SN",
name=kind, void_kwargs=void_kwargs)
name=kind, void_kwargs=void_kwargs,
wo_num_dist_marginalisation=wo_num_dist_marginalisation)
elif "Pantheon+" in kind:
keys = ["RA", "DEC", "zCMB", "mB", "x1", "c", "biasCor_m_b", "mBERR",
"x1ERR", "cERR", "biasCorErr_m_b", "zCMB_SN", "zCMB_Group",
@ -451,7 +457,8 @@ def get_model(loader, zcmb_min=None, zcmb_max=None, mag_selection=None,
los_overdensity, los_velocity,
RA[mask], dec[mask], zCMB[mask], e_zCMB[mask], calibration_params,
None, mag_selection, loader.rdist, loader._Omega_m, "SN",
name=kind, void_kwargs=void_kwargs)
name=kind, void_kwargs=void_kwargs,
wo_num_dist_marginalisation=wo_num_dist_marginalisation)
elif kind in ["SFI_gals", "2MTF", "SFI_gals_masked"]:
keys = ["RA", "DEC", "z_CMB", "mag", "eta", "e_mag", "e_eta"]
RA, dec, zCMB, mag, eta, e_mag, e_eta = (loader.cat[k] for k in keys)
@ -467,7 +474,8 @@ def get_model(loader, zcmb_min=None, zcmb_max=None, mag_selection=None,
los_overdensity, los_velocity,
RA[mask], dec[mask], zCMB[mask], None, calibration_params, None,
mag_selection, loader.rdist, loader._Omega_m, "TFR", name=kind,
void_kwargs=void_kwargs)
void_kwargs=void_kwargs,
wo_num_dist_marginalisation=wo_num_dist_marginalisation)
elif "CF4_TFR_" in kind:
# The full name can be e.g. "CF4_TFR_not2MTForSFI_i" or "CF4_TFR_i".
band = kind.split("_")[-1]
@ -535,7 +543,8 @@ def get_model(loader, zcmb_min=None, zcmb_max=None, mag_selection=None,
los_overdensity, los_velocity,
RA[mask], dec[mask], z_obs[mask], None, calibration_params,
abs_calibration_params, mag_selection, loader.rdist,
loader._Omega_m, "TFR", name=kind, void_kwargs=void_kwargs)
loader._Omega_m, "TFR", name=kind, void_kwargs=void_kwargs,
wo_num_dist_marginalisation=wo_num_dist_marginalisation)
elif kind in ["CF4_GroupAll"]:
# Note, this for some reason works terribly.
keys = ["RA", "DE", "Vcmb", "DMzp", "eDM"]
@ -556,7 +565,8 @@ def get_model(loader, zcmb_min=None, zcmb_max=None, mag_selection=None,
los_overdensity, los_velocity,
RA[mask], dec[mask], zCMB[mask], None, calibration_params, None,
mag_selection, loader.rdist, loader._Omega_m, "simple",
name=kind, void_kwargs=void_kwargs)
name=kind, void_kwargs=void_kwargs,
wo_num_dist_marginalisation=wo_num_dist_marginalisation)
else:
raise ValueError(f"Catalogue `{kind}` not recognized.")