Combine PV samples (#139)

* Update imports

* Update submission script

* Update script

* Add simultaneous sampling of many catalogues

* Update nb
Richard Stiskalek 2024-07-30 17:02:48 +01:00 committed by GitHub
parent 9756175943
commit 3b46f17ead
5 changed files with 356 additions and 147 deletions
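
The headline change is simultaneous sampling of several PV catalogues: `--catalogue` now takes a comma-separated list that the script splits into a Python list, and one set of distance-modulus parameters is sampled per catalogue. A minimal sketch of the new parsing behaviour (hypothetical argv, mirroring the submission script below):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--catalogue", type=str, required=True,
                    help="Comma-separated list of PV catalogues.")
args = parser.parse_args(["--catalogue", "LOSS,2MTF,SFI_gals"])
args.catalogue = args.catalogue.split(",")
print(args.catalogue)  # ['LOSS', '2MTF', 'SFI_gals']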


@@ -12,8 +12,8 @@
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from .flow_model import (DataLoader, radial_velocity_los, dist2redshift, # noqa
dist2distmodulus, predict_zobs, project_Vext, # noqa
PV_validation_model, get_model, distmodulus2dist, # noqa
Observed2CosmologicalRedshift, # noqa
stack_pzosmo_over_realizations) # noqa
from .flow_model import (DataLoader, PV_LogLikelihood, PV_validation_model, # noqa
dist2distmodulus, dist2redshift, distmodulus2dist, # noqa
get_model, Observed2CosmologicalRedshift, # noqa
predict_zobs, project_Vext, # noqa
radial_velocity_los, stack_pzosmo_over_realizations) # noqa


@@ -24,15 +24,14 @@ References
from abc import ABC, abstractmethod
import numpy as np
import numpyro
from astropy.cosmology import FlatLambdaCDM, z_at_value
from astropy import units as u
from astropy.cosmology import FlatLambdaCDM, z_at_value
from h5py import File
from jax import jit
from jax import numpy as jnp
from jax import vmap
from jax.scipy.special import logsumexp
from numpyro import sample
from numpyro import deterministic, factor, sample
from numpyro.distributions import Normal, Uniform
from quadax import simpson
from scipy.interpolate import interp1d
@@ -558,15 +557,23 @@ def e2_distmod_SN(e2_mB, e2_x1, e2_c, alpha_cal, beta_cal, e_mu_intrinsic):
def sample_SN(e_mu_min, e_mu_max, mag_cal_mean, mag_cal_std, alpha_cal_mean,
alpha_cal_std, beta_cal_mean, beta_cal_std):
alpha_cal_std, beta_cal_mean, beta_cal_std, alpha_min, alpha_max,
sample_alpha, name):
"""Sample SNIe Tripp parameters."""
e_mu = sample("e_mu", Uniform(e_mu_min, e_mu_max))
mag_cal = sample("mag_cal", Normal(mag_cal_mean, mag_cal_std))
alpha_cal = sample("alpha_cal", Normal(alpha_cal_mean, alpha_cal_std))
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
mag_cal = sample(f"mag_cal_{name}", Normal(mag_cal_mean, mag_cal_std))
alpha_cal = sample(
f"alpha_cal_{name}", Normal(alpha_cal_mean, alpha_cal_std))
beta_cal = sample(f"beta_cal_{name}", Normal(beta_cal_mean, beta_cal_std))
alpha = sample(f"alpha_{name}",
Uniform(alpha_min, alpha_max)) if sample_alpha else 1.0
beta_cal = sample("beta_cal", Normal(beta_cal_mean, beta_cal_std))
return e_mu, mag_cal, alpha_cal, beta_cal
return {"e_mu": e_mu,
"mag_cal": mag_cal,
"alpha_cal": alpha_cal,
"beta_cal": beta_cal,
"alpha": alpha
}
###############################################################################
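
The new `name` argument suffixes every NumPyro sample site, so the same calibration parameters can be drawn once per catalogue without site-name collisions. A minimal sketch of the mechanism (toy model, not the repo's full likelihood):

import jax
from numpyro import handlers, sample
from numpyro.distributions import Uniform

def toy_model():
    # One `e_mu_<name>` site per catalogue, as in sample_SN / sample_TFR.
    for name in ("LOSS", "2MTF"):
        sample(f"e_mu_{name}", Uniform(0.001, 1.0))

tr = handlers.trace(handlers.seed(toy_model, jax.random.PRNGKey(0))).get_trace()
print(list(tr))  # ['e_mu_LOSS', 'e_mu_2MTF']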
@@ -583,34 +590,42 @@ def e2_distmod_TFR(e2_mag, e2_eta, b, e_mu_intrinsic):
return e2_mag + b**2 * e2_eta + e_mu_intrinsic**2
def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std):
def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std, alpha_min,
alpha_max, sample_alpha, name):
"""Sample Tully-Fisher calibration parameters."""
e_mu = sample("e_mu", Uniform(e_mu_min, e_mu_max))
a = sample("a", Normal(a_mean, a_std))
b = sample("b", Normal(b_mean, b_std))
return e_mu, a, b
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
a = sample(f"a_{name}", Normal(a_mean, a_std))
b = sample(f"b_{name}", Normal(b_mean, b_std))
alpha = sample(f"alpha_{name}",
Uniform(alpha_min, alpha_max)) if sample_alpha else 1.0
return {"e_mu": e_mu,
"a": a,
"b": b,
"alpha": alpha
}
###############################################################################
# Calibration parameters sampling #
###############################################################################
def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max,
alpha_min, alpha_max, beta_min, beta_max, sigma_v_min,
sigma_v_max, sample_Vmono, sample_alpha, sample_beta,
sample_sigma_v_ext):
def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min,
beta_max, sigma_v_min, sigma_v_max, sample_Vmono,
sample_beta, sample_sigma_v_ext):
"""Sample the flow calibration."""
Vext = sample("Vext", Uniform(Vext_min, Vext_max).expand([3]))
sigma_v = sample("sigma_v", Uniform(sigma_v_min, sigma_v_max))
alpha = sample("alpha", Uniform(alpha_min, alpha_max)) if sample_alpha else 1.0 # noqa
beta = sample("beta", Uniform(beta_min, beta_max)) if sample_beta else 1.0 # noqa
Vmono = sample("Vmono", Uniform(Vmono_min, Vmono_max)) if sample_Vmono else 0.0 # noqa
sigma_v_ext = sample("sigma_v_ext", Uniform(sigma_v_min, sigma_v_max)) if sample_sigma_v_ext else sigma_v # noqa
return Vext, Vmono, sigma_v, sigma_v_ext, alpha, beta
return {"Vext": Vext,
"Vmono": Vmono,
"sigma_v": sigma_v,
"sigma_v_ext": sigma_v_ext,
"beta": beta}
###############################################################################
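
For reference, the distance moduli these samplers calibrate follow the standard Tripp and Tully-Fisher forms; a sketch consistent with the `e2_distmod_SN` and `e2_distmod_TFR` error budgets above (assumed sign conventions, not copied from the repo):

def distmod_SN_sketch(mB, x1, c, mag_cal, alpha_cal, beta_cal):
    # Tripp relation: apparent magnitude corrected for stretch and colour.
    return mB - mag_cal + alpha_cal * x1 - beta_cal * c

def distmod_TFR_sketch(mag, eta, a, b):
    # Apparent magnitude minus the TFR absolute-magnitude prediction a + b * eta.
    return mag - (a + b * eta)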
@@ -635,9 +650,9 @@ def find_extrap_mask(rmax, rdist):
return extrap_mask, extrap_weights
class PV_validation_model(BaseFlowValidationModel):
class PV_LogLikelihood(BaseFlowValidationModel):
"""
Peculiar velocity validation model.
Peculiar velocity validation model log-likelihood.
Parameters
----------
@@ -663,7 +678,7 @@ class PV_validation_model(BaseFlowValidationModel):
"""
def __init__(self, los_density, los_velocity, rmax, RA, dec, z_obs,
e_zobs, calibration_params, r_xrange, Omega_m, kind):
e_zobs, calibration_params, r_xrange, Omega_m, kind, name):
if e_zobs is not None:
e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2)
else:
@@ -681,6 +696,7 @@ class PV_validation_model(BaseFlowValidationModel):
self._set_radial_spacing(r_xrange, Omega_m)
self.kind = kind
self.name = name
self.Omega_m = Omega_m
self.norm = - self.ndata * jnp.log(self.num_sims)
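
The `norm` term pairs with the `logsumexp` over the simulation axis in `__call__`: together they give the log of the likelihood averaged over realisations. A numeric sketch of that identity:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

num_sims, ndata = 3, 4
ll = jnp.zeros((num_sims, ndata))  # toy per-simulation, per-galaxy log-likelihoods
norm = -ndata * jnp.log(num_sims)
print(jnp.sum(logsumexp(ll, axis=0)) + norm)  # 0.0: averaging identical sims is a no-op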
@@ -688,29 +704,37 @@ class PV_validation_model(BaseFlowValidationModel):
self.extrap_mask = jnp.asarray(extrap_mask)
self.extrap_weights = jnp.asarray(extrap_weights)
def __call__(self, calibration_hyperparams, distmod_hyperparams,
store_ll_all=False):
"""NumPyro PV validation model."""
Vext, Vmono, sigma_v, sigma_v_ext, alpha, beta = sample_calibration(**calibration_hyperparams) # noqa
def __call__(self, field_calibration_params, distmod_params,
sample_sigma_v_ext):
"""PV validation model log-likelihood."""
# Build e2_cz with shape (nsims, ndata, nxrange) and apply
# sigma_v_ext where applicable.
sigma_v = field_calibration_params["sigma_v"]
sigma_v_ext = field_calibration_params["sigma_v_ext"]
e2_cz = jnp.full_like(self.extrap_mask, sigma_v**2, dtype=jnp.float32)
if calibration_hyperparams["sample_sigma_v_ext"]:
if sample_sigma_v_ext:
e2_cz = e2_cz.at[self.extrap_mask].set(sigma_v_ext**2)
# Now add the observational errors
e2_cz += self.e2_cz_obs[None, :, None]
Vext = field_calibration_params["Vext"]
Vmono = field_calibration_params["Vmono"]
Vext_rad = project_Vext(Vext[0], Vext[1], Vext[2], self.RA, self.dec)
e_mu = distmod_params["e_mu"]
alpha = distmod_params["alpha"]
if self.kind == "SN":
e_mu, mag_cal, alpha_cal, beta_cal = sample_SN(**distmod_hyperparams) # noqa
mag_cal = distmod_params["mag_cal"]
alpha_cal = distmod_params["alpha_cal"]
beta_cal = distmod_params["beta_cal"]
mu = distmod_SN(
self.mB, self.x1, self.c, mag_cal, alpha_cal, beta_cal)
squared_e_mu = e2_distmod_SN(
self.e2_mB, self.e2_x1, self.e2_c, alpha_cal, beta_cal, e_mu)
elif self.kind == "TFR":
e_mu, a, b = sample_TFR(**distmod_hyperparams)
a = distmod_params["a"]
b = distmod_params["b"]
mu = distmod_TFR(self.mag, self.eta, a, b)
squared_e_mu = e2_distmod_TFR(self.e2_mag, self.e2_eta, b, e_mu)
else:
@@ -725,20 +749,51 @@ class PV_validation_model(BaseFlowValidationModel):
# Calculate z_obs at each distance. Shape is (n_sims, ndata, nxrange)
# The weights are related to the extrapolation of the velocity field.
vrad = beta * self.los_velocity
vrad = field_calibration_params["beta"] * self.los_velocity
vrad += (Vext_rad[None, :, None] + Vmono) * self.extrap_weights
zobs = (1 + self.z_xrange[None, None, :]) * (1 + vrad / SPEED_OF_LIGHT) - 1 # noqa
ptilde *= calculate_likelihood_zobs(self.z_obs, zobs, e2_cz)
# ptilde *= calculate_likelihood_zobs(self.z_obs, zobs, sigma_v)
ll = jnp.log(simpson(ptilde, dx=self.dr, axis=-1)) - jnp.log(pnorm)
if store_ll_all:
numpyro.deterministic("ll_all", ll)
return jnp.sum(logsumexp(ll, axis=0)) + self.norm
ll = jnp.sum(logsumexp(ll, axis=0)) + self.norm
numpyro.deterministic("ll_values", ll)
numpyro.factor("ll", ll)
def PV_validation_model(models, distmod_hyperparams_per_model,
field_calibration_hyperparams):
"""
Peculiar velocity validation NumPyro model.
Parameters
----------
models : list of `PV_LogLikelihood`
List of PV validation log-likelihoods for each catalogue.
distmod_hyperparams_per_model : list of dict
Distance modulus hyperparameters for each model/catalogue.
field_calibration_hyperparams : dict
Field calibration hyperparameters.
"""
field_calibration_params = sample_calibration(
**field_calibration_hyperparams)
sample_sigma_v_ext = field_calibration_hyperparams["sample_sigma_v_ext"]
ll = 0.0
for n in range(len(models)):
model = models[n]
distmod_hyperparams = distmod_hyperparams_per_model[n]
if model.kind == "TFR":
distmod_params = sample_TFR(**distmod_hyperparams, name=model.name)
elif model.kind == "SN":
distmod_params = sample_SN(**distmod_hyperparams, name=model.name)
else:
raise ValueError(f"Unknown kind: `{model.kind}`.")
ll += model(
field_calibration_params, distmod_params, sample_sigma_v_ext)
deterministic("ll_values", ll)
factor("ll", ll)
###############################################################################
@@ -780,10 +835,10 @@ def get_model(loader, zcmb_max=None, verbose=True):
"e_mB": e_mB[mask], "e_x1": e_x1[mask],
"e_c": e_c[mask]}
model = PV_validation_model(
model = PV_LogLikelihood(
los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
RA[mask], dec[mask], zCMB[mask], e_zCMB, calibration_params,
loader.rdist, loader._Omega_m, "SN")
loader.rdist, loader._Omega_m, "SN", name=kind)
elif "Pantheon+" in kind:
keys = ["RA", "DEC", "zCMB", "mB", "x1", "c", "biasCor_m_b", "mBERR",
"x1ERR", "cERR", "biasCorErr_m_b", "zCMB_SN", "zCMB_Group",
@@ -808,10 +863,10 @@ def get_model(loader, zcmb_max=None, verbose=True):
calibration_params = {"mB": mB[mask], "x1": x1[mask], "c": c[mask],
"e_mB": e_mB[mask], "e_x1": e_x1[mask],
"e_c": e_c[mask]}
model = PV_validation_model(
model = PV_LogLikelihood(
los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
RA[mask], dec[mask], zCMB[mask], e_zCMB[mask], calibration_params,
loader.rdist, loader._Omega_m, "SN")
loader.rdist, loader._Omega_m, "SN", name=kind)
elif kind in ["SFI_gals", "2MTF", "SFI_gals_masked"]:
keys = ["RA", "DEC", "z_CMB", "mag", "eta", "e_mag", "e_eta"]
RA, dec, zCMB, mag, eta, e_mag, e_eta = (loader.cat[k] for k in keys)
@@ -824,14 +879,14 @@ def get_model(loader, zcmb_max=None, verbose=True):
calibration_params = {"mag": mag[mask], "eta": eta[mask],
"e_mag": e_mag[mask], "e_eta": e_eta[mask]}
model = PV_validation_model(
model = PV_LogLikelihood(
los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
RA[mask], dec[mask], zCMB[mask], None, calibration_params,
loader.rdist, loader._Omega_m, "TFR")
loader.rdist, loader._Omega_m, "TFR", name=kind)
else:
raise ValueError(f"Catalogue `{kind}` not recognized.")
fprint(f"selected {np.sum(mask)}/{len(mask)} galaxies.")
fprint(f"selected {np.sum(mask)}/{len(mask)} galaxies in catalogue `{kind}`") # noqa
return model
@@ -848,9 +903,12 @@ def _posterior_element(r, beta, Vext_radial, los_velocity, Omega_m, zobs,
`Observed2CosmologicalRedshift`.
"""
zobs_pred = predict_zobs(r, beta, Vext_radial, los_velocity, Omega_m)
# Likelihood term
dcz = SPEED_OF_LIGHT * (zobs - zobs_pred)
posterior = jnp.exp(-0.5 * dcz**2 / sigma_v**2) / jnp.sqrt(2 * jnp.pi * sigma_v**2) # noqa
posterior = jnp.exp(-0.5 * dcz**2 / sigma_v**2)
posterior /= jnp.sqrt(2 * jnp.pi * sigma_v**2)
# Prior term
posterior *= dVdOmega * los_density**alpha
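
The reformatted `_posterior_element` is unchanged numerically: a Gaussian likelihood in cz times a volume-and-density prior. Restated as a self-contained sketch:

import jax.numpy as jnp

def posterior_element_sketch(dcz, sigma_v, dVdOmega, los_density, alpha):
    # Gaussian redshift likelihood, then the dV/dOmega * rho^alpha prior weight.
    likelihood = jnp.exp(-0.5 * dcz**2 / sigma_v**2) / jnp.sqrt(2 * jnp.pi * sigma_v**2)
    return likelihood * dVdOmega * los_density**alpha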
@@ -864,7 +922,8 @@ class BaseObserved2CosmologicalRedshift(ABC):
for i, key in enumerate(calibration_samples.keys()):
x = calibration_samples[key]
if not isinstance(x, (np.ndarray, jnp.ndarray)):
raise ValueError(f"Calibration sample `{key}` must be an array.") # noqa
raise ValueError(
f"Calibration sample `{key}` must be an array.")
if x.ndim != 1 and key != "Vext":
raise ValueError(f"Calibration samples `{key}` must be 1D.")
@@ -873,14 +932,19 @@ class BaseObserved2CosmologicalRedshift(ABC):
ncalibration = len(x)
if len(x) != ncalibration:
raise ValueError("Calibration samples do not have the same length.") # noqa
raise ValueError(
"Calibration samples do not have the same length.")
calibration_samples[key] = jnp.asarray(x)
if "alpha" not in calibration_samples:
print("No `alpha` calibration sample found. Setting it to 1.",
flush=True)
calibration_samples["alpha"] = jnp.ones(ncalibratrion)
if "beta" not in calibration_samples:
print("No `beta` calibration sample found. Setting it to 1.",
flush=True)
calibration_samples["beta"] = jnp.ones(ncalibratrion)
# Get the stepsize, we need it to be constant for Simpson's rule.
@@ -898,7 +962,8 @@ class BaseObserved2CosmologicalRedshift(ABC):
def get_calibration_samples(self, key):
"""Get calibration samples for a given key."""
if key not in self._calibration_samples:
raise ValueError(f"Key `{key}` not found in calibration samples. Available keys are: `{self.calibration_keys}`.") # noqa
raise ValueError(f"Key `{key}` not found in calibration samples. "
f"Available keys are: `{self.calibration_keys}`.")
return self._calibration_samples[key]

File diff suppressed because one or more lines are too long


@@ -33,7 +33,7 @@ def parse_args():
parser.add_argument("--simname", type=str, required=True,
help="Simulation name.")
parser.add_argument("--catalogue", type=str, required=True,
help="PV catalogue.")
help="PV catalogues.")
parser.add_argument("--ksmooth", type=int, default=1,
help="Smoothing index.")
parser.add_argument("--ksim", type=none_or_int, default=None,
@@ -42,7 +42,12 @@ def parse_args():
help="Number of devices to request.")
parser.add_argument("--device", type=str, default="cpu",
help="Device to use.")
return parser.parse_args()
args = parser.parse_args()
# Convert the catalogue to a list of catalogues
args.catalogue = args.catalogue.split(",")
return args
ARGS = parse_args()
@@ -58,8 +63,7 @@ from os.path import join
import csiborgtools # noqa
import jax # noqa
from h5py import File # noqa
from mpi4py import MPI # noqa
from numpyro.infer import MCMC, NUTS, Predictive, init_to_median # noqa
from numpyro.infer import MCMC, NUTS, init_to_median # noqa
def print_variables(names, variables):
@@ -68,8 +72,8 @@ def print_variables(names, variables):
print(flush=True)
def get_model(paths, get_model_kwargs, verbose=True):
"""Load the data and create the NumPyro model."""
def get_models(get_model_kwargs, verbose=True):
"""Load the data and create the NumPyro models."""
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
folder = "/mnt/extraspace/rstiskalek/catalogs/"
@@ -86,55 +90,45 @@ def get_model(paths, get_model_kwargs, verbose=True):
print(f"{'Num. realisations:':<20} {len(nsims)}")
print(flush=True)
if ARGS.catalogue == "A2":
fpath = join(folder, "A2.h5")
elif ARGS.catalogue in ["LOSS", "Foundation", "Pantheon+", "SFI_gals",
"2MTF", "SFI_groups", "SFI_gals_masked",
"Pantheon+_groups", "Pantheon+_groups_zSN",
"Pantheon+_zSN"]:
fpath = join(folder, "PV_compilation.hdf5")
else:
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
# Get models
models = [None] * len(ARGS.catalogue)
for i, cat in enumerate(ARGS.catalogue):
if cat == "A2":
fpath = join(folder, "A2.h5")
elif cat in ["LOSS", "Foundation", "Pantheon+", "SFI_gals",
"2MTF", "SFI_groups", "SFI_gals_masked",
"Pantheon+_groups", "Pantheon+_groups_zSN",
"Pantheon+_zSN"]:
fpath = join(folder, "PV_compilation.hdf5")
else:
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
loader = csiborgtools.flow.DataLoader(ARGS.simname, nsim_iterator,
ARGS.catalogue, fpath, paths,
ksmooth=ARGS.ksmooth)
loader = csiborgtools.flow.DataLoader(ARGS.simname, nsim_iterator,
cat, fpath, paths,
ksmooth=ARGS.ksmooth)
models[i] = csiborgtools.flow.get_model(loader, **get_model_kwargs)
print(f"\n{'Num. radial steps':<20} {len(loader.rdist)}\n", flush=True)
return csiborgtools.flow.get_model(loader, **get_model_kwargs)
return models
def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num):
"""Compute evidence using the `harmonic` package."""
data, names = csiborgtools.dict_samples_to_array(samples)
data = data.reshape(nchains_harmonic, -1, len(names))
log_posterior = log_posterior.reshape(10, -1)
log_posterior = log_posterior.reshape(nchains_harmonic, -1)
return csiborgtools.harmonic_evidence(
data, log_posterior, return_flow_samples=False, epochs_num=epoch_num)
def get_simulation_weights(samples, model, model_kwargs):
"""Get the weights per posterior samples for each simulation."""
predictive = Predictive(model, samples)
ll_all = predictive(
jax.random.PRNGKey(1), store_ll_all=True, **model_kwargs)["ll_all"]
# Multiply the likelihood of galaxies
ll_per_simulation = jax.numpy.sum(ll_all, axis=-1)
# Normalization by summing the likelihood over simulations
norm = jax.scipy.special.logsumexp(ll_per_simulation, axis=-1)
return ll_per_simulation - norm[:, None]
def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
calculate_evidence, nchains_harmonic, epoch_num, kwargs_print):
"""Run the NumPyro model and save output to a file."""
try:
ndata = model.ndata
ndata = sum(model.ndata for model in model_kwargs["models"])
except AttributeError as e:
raise AttributeError("The model must have an attribute `ndata` "
raise AttributeError("The models must have an attribute `ndata` "
"indicating the number of data points.") from e
nuts_kernel = NUTS(model, init_strategy=init_to_median(num_samples=1000))
@@ -143,7 +137,6 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
mcmc.run(rng_key, extra_fields=("potential_energy",), **model_kwargs)
samples = mcmc.get_samples()
simulation_weights = get_simulation_weights(samples, model, model_kwargs)
log_posterior = -mcmc.get_extra_fields()["potential_energy"]
log_likelihood = samples.pop("ll_values")
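
`run_model` keeps the standard NumPyro loop: the negated potential energy gives the log-posterior and the `ll_values` deterministic site gives the log-likelihood. A self-contained sketch of the same pattern (toy model, hypothetical sizes):

import jax
import numpyro
from numpyro.distributions import Normal
from numpyro.infer import MCMC, NUTS, init_to_median

def toy_model():
    x = numpyro.sample("x", Normal(0.0, 1.0))
    numpyro.deterministic("ll_values", -0.5 * x**2)

kernel = NUTS(toy_model, init_strategy=init_to_median(num_samples=1000))
mcmc = MCMC(kernel, num_warmup=100, num_samples=200)
mcmc.run(jax.random.PRNGKey(0), extra_fields=("potential_energy",))
samples = mcmc.get_samples()
log_posterior = -mcmc.get_extra_fields()["potential_energy"]
log_likelihood = samples.pop("ll_values")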
@@ -165,7 +158,7 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
neg_ln_evidence = jax.numpy.nan
neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)
fname = f"samples_{ARGS.simname}_{ARGS.catalogue}_ksmooth{ARGS.ksmooth}.hdf5" # noqa
fname = f"samples_{ARGS.simname}_{'+'.join(ARGS.catalogue)}_ksmooth{ARGS.ksmooth}.hdf5" # noqa
if ARGS.ksim is not None:
fname = fname.replace(".hdf5", f"_nsim{ARGS.ksim}.hdf5")
@@ -183,7 +176,6 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
# Write log likelihood and posterior
f.create_dataset("log_likelihood", data=log_likelihood)
f.create_dataset("log_posterior", data=log_posterior)
f.create_dataset("simulation_weights", data=simulation_weights)
# Write goodness of fit
grp = f.create_group("gof")
@@ -215,6 +207,29 @@
# Command line interface #
###############################################################################
def get_distmod_hyperparams(catalogue):
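"""Return the distance-modulus hyperparameters for the given catalogue."""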
alpha_min = -1.0
alpha_max = 3.0
sample_alpha = True
if catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]: # noqa
return {"e_mu_min": 0.001, "e_mu_max": 1.0,
"mag_cal_mean": -18.25, "mag_cal_std": 2.0,
"alpha_cal_mean": 0.148, "alpha_cal_std": 1.0,
"beta_cal_mean": 3.112, "beta_cal_std": 2.0,
"alpha_min": alpha_min, "alpha_max": alpha_max,
"sample_alpha": sample_alpha
}
elif catalogue in ["SFI_gals", "2MTF"]:
return {"e_mu_min": 0.001, "e_mu_max": 1.0,
"a_mean": -21., "a_std": 5.0,
"b_mean": -5.95, "b_std": 3.0,
"alpha_min": alpha_min, "alpha_max": alpha_max,
"sample_alpha": sample_alpha
}
else:
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
if __name__ == "__main__":
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
@@ -227,14 +242,15 @@ if __name__ == "__main__":
###########################################################################
nsteps = 5000
nburn = 1000
zcmb_max = 0.06
nburn = 1500
zcmb_max = 0.05
calculate_evidence = False
nchains_harmonic = 10
num_epochs = 30
if nsteps % nchains_harmonic != 0:
raise ValueError("The number of steps must be divisible by the number of chains.") # noqa
raise ValueError(
"The number of steps must be divisible by the number of chains.")
main_params = {"nsteps": nsteps, "nburn": nburn, "zcmb_max": zcmb_max,
"calculate_evidence": calculate_evidence,
@@ -244,42 +260,36 @@
calibration_hyperparams = {"Vext_min": -1000, "Vext_max": 1000,
"Vmono_min": -1000, "Vmono_max": 1000,
"alpha_min": -1.0, "alpha_max": 3.0,
"beta_min": -1.0, "beta_max": 3.0,
"sigma_v_min": 1.0, "sigma_v_max": 750.,
"sample_Vmono": False,
"sample_alpha": True,
"sample_beta": True,
"sample_sigma_v_ext": False,
}
print_variables(
calibration_hyperparams.keys(), calibration_hyperparams.values())
if ARGS.catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]: # noqa
distmod_hyperparams = {"e_mu_min": 0.001, "e_mu_max": 1.0,
"mag_cal_mean": -18.25, "mag_cal_std": 2.0,
"alpha_cal_mean": 0.148, "alpha_cal_std": 1.0,
"beta_cal_mean": 3.112, "beta_cal_std": 2.0,
}
elif ARGS.catalogue in ["SFI_gals", "2MTF"]:
distmod_hyperparams = {"e_mu_min": 0.001, "e_mu_max": 1.0,
"a_mean": -21., "a_std": 5.0,
"b_mean": -5.95, "b_std": 3.0,
}
else:
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
distmod_hyperparams_per_catalogue = []
for cat in ARGS.catalogue:
x = get_distmod_hyperparams(cat)
print(f"\n{cat} hyperparameters:")
print_variables(x.keys(), x.values())
distmod_hyperparams_per_catalogue.append(x)
print_variables(
distmod_hyperparams.keys(), distmod_hyperparams.values())
kwargs_print = (main_params, calibration_hyperparams, distmod_hyperparams)
kwargs_print = (main_params, calibration_hyperparams,
*distmod_hyperparams_per_catalogue)
###########################################################################
model_kwargs = {"calibration_hyperparams": calibration_hyperparams,
"distmod_hyperparams": distmod_hyperparams}
get_model_kwargs = {"zcmb_max": zcmb_max}
models = get_models(get_model_kwargs)
model_kwargs = {
"models": models,
"field_calibration_hyperparams": calibration_hyperparams,
"distmod_hyperparams_per_model": distmod_hyperparams_per_catalogue,
}
model = csiborgtools.flow.PV_validation_model
model = get_model(paths, get_model_kwargs, )
run_model(model, nsteps, nburn, model_kwargs, out_folder,
calibration_hyperparams["sample_beta"], calculate_evidence,
nchains_harmonic, num_epochs, kwargs_print)


@@ -1,44 +1,66 @@
#!/bin/bash
memory=8
memory=7
on_login=${1}
queue=${2}
ndevice=1
device="gpu"
queue="gpulong"
gputype="rtx2080with12gb"
env="/mnt/users/rstiskalek/csiborgtools/venv_gpu_csiborgtools/bin/python"
file="flow_validation.py"
ksmooth=0
if [ "$on_login" != "1" ] && [ "$on_login" != "0" ]; then
echo "Invalid input: 'on_login' (1). Please provide 1 or 0."
if [ "$on_login" != "1" ] && [ "$on_login" != "0" ]
then
echo "'on_login' (1) must be either 0 or 1."
exit 1
fi
if [ "$queue" != "redwood" ] && [ "$queue" != "berg" ] && [ "$queue" != "cmb" ] && [ "$queue" != "gpulong" ] && [ "$queue" != "cmbgpu" ]; then
echo "Invalid queue: $queue (2). Please provide one of 'redwood', 'berg', 'cmb', 'gpulong', 'cmbgpu'."
exit 1
fi
# Submit a job for each combination of simname, catalogue, ksim
if [ "$queue" == "gpulong" ]
then
device="gpu"
gputype="rtx2080with12gb"
# gputype="rtx3070with8gb"
env="/mnt/users/rstiskalek/csiborgtools/venv_gpu_csiborgtools/bin/python"
elif [ "$queue" == "cmbgpu" ]
then
device="gpu"
gputype="rtx3090with24gb"
env="/mnt/users/rstiskalek/csiborgtools/venv_gpu_csiborgtools/bin/python"
else
device="cpu"
env="/mnt/users/rstiskalek/csiborgtools/venv_csiborg/bin/python"
fi
# for simname in "Lilow2024" "CF4" "CF4gp" "csiborg1" "csiborg2_main" "csiborg2X"; do
for simname in "Carrick2015"; do
# for simname in "csiborg1" "csiborg2_main" "csiborg2X"; do
for catalogue in "Foundation"; do
# for catalogue in "2MTF"; do
# for ksim in 0 1 2; do
for catalogue in "LOSS,2MTF,SFI_gals"; do
# for ksim in 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20; do
for ksim in "none"; do
pythoncm="$env $file --catalogue $catalogue --simname $simname --ksim $ksim --ksmooth $ksmooth --ndevice $ndevice --device $device"
if [ $on_login -eq 1 ]; then
if [ "$on_login" == "1" ]; then
echo $pythoncm
$pythoncm
eval $pythoncm
else
cm="addqueue -q $queue -s -m $memory --gpus 1 --gputype $gputype $pythoncm"
if [ "$device" == "gpu" ]; then
cm="addqueue -q $queue -s -m $memory --gpus 1 --gputype $gputype $pythoncm"
else
cm="addqueue -s -q $queue -n 1 -m $memory $pythoncm"
fi
echo "Submitting:"
echo $cm
eval $cm
fi
echo
sleep 0.001
done
done
done
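
For reference, the script's two positional arguments are `on_login` (0 or 1) and the queue name; with `on_login` set to 1 the command runs directly on the login node, otherwise it is submitted through `addqueue` with the queue-appropriate device, memory, and GPU type.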