mirror of
https://github.com/Richard-Sti/csiborgtools.git
synced 2024-12-22 11:58:02 +00:00
Combine PV samples (#139)
* Update imports * Update submission script * Update script * Add simulataenous sampling of many catalogues * Update nb
This commit is contained in:
parent
9756175943
commit
3b46f17ead
5 changed files with 356 additions and 147 deletions
|
@ -12,8 +12,8 @@
|
|||
# You should have received a copy of the GNU General Public License along
|
||||
# with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
||||
from .flow_model import (DataLoader, radial_velocity_los, dist2redshift, # noqa
|
||||
dist2distmodulus, predict_zobs, project_Vext, # noqa
|
||||
PV_validation_model, get_model, distmodulus2dist, # noqa
|
||||
Observed2CosmologicalRedshift, # noqa
|
||||
stack_pzosmo_over_realizations) # noqa
|
||||
from .flow_model import (DataLoader, PV_LogLikelihood, PV_validation_model, # noqa
|
||||
dist2distmodulus, dist2redshift, distmodulus2dist, # noqa
|
||||
get_model, Observed2CosmologicalRedshift, # noqa
|
||||
predict_zobs, project_Vext, # noqa
|
||||
radial_velocity_los, stack_pzosmo_over_realizations) # noqa
|
||||
|
|
|
@ -24,15 +24,14 @@ References
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
import numpyro
|
||||
from astropy.cosmology import FlatLambdaCDM, z_at_value
|
||||
from astropy import units as u
|
||||
from astropy.cosmology import FlatLambdaCDM, z_at_value
|
||||
from h5py import File
|
||||
from jax import jit
|
||||
from jax import numpy as jnp
|
||||
from jax import vmap
|
||||
from jax.scipy.special import logsumexp
|
||||
from numpyro import sample
|
||||
from numpyro import deterministic, factor, sample
|
||||
from numpyro.distributions import Normal, Uniform
|
||||
from quadax import simpson
|
||||
from scipy.interpolate import interp1d
|
||||
|
@ -558,15 +557,23 @@ def e2_distmod_SN(e2_mB, e2_x1, e2_c, alpha_cal, beta_cal, e_mu_intrinsic):
|
|||
|
||||
|
||||
def sample_SN(e_mu_min, e_mu_max, mag_cal_mean, mag_cal_std, alpha_cal_mean,
|
||||
alpha_cal_std, beta_cal_mean, beta_cal_std):
|
||||
alpha_cal_std, beta_cal_mean, beta_cal_std, alpha_min, alpha_max,
|
||||
sample_alpha, name):
|
||||
"""Sample SNIe Tripp parameters."""
|
||||
e_mu = sample("e_mu", Uniform(e_mu_min, e_mu_max))
|
||||
mag_cal = sample("mag_cal", Normal(mag_cal_mean, mag_cal_std))
|
||||
alpha_cal = sample("alpha_cal", Normal(alpha_cal_mean, alpha_cal_std))
|
||||
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
|
||||
mag_cal = sample(f"mag_cal_{name}", Normal(mag_cal_mean, mag_cal_std))
|
||||
alpha_cal = sample(
|
||||
f"alpha_cal_{name}", Normal(alpha_cal_mean, alpha_cal_std))
|
||||
beta_cal = sample(f"beta_cal_{name}", Normal(beta_cal_mean, beta_cal_std))
|
||||
alpha = sample(f"alpha_{name}",
|
||||
Uniform(alpha_min, alpha_max)) if sample_alpha else 1.0
|
||||
|
||||
beta_cal = sample("beta_cal", Normal(beta_cal_mean, beta_cal_std))
|
||||
|
||||
return e_mu, mag_cal, alpha_cal, beta_cal
|
||||
return {"e_mu": e_mu,
|
||||
"mag_cal": mag_cal,
|
||||
"alpha_cal": alpha_cal,
|
||||
"beta_cal": beta_cal,
|
||||
"alpha": alpha
|
||||
}
|
||||
|
||||
|
||||
###############################################################################
|
||||
|
@ -583,34 +590,42 @@ def e2_distmod_TFR(e2_mag, e2_eta, b, e_mu_intrinsic):
|
|||
return e2_mag + b**2 * e2_eta + e_mu_intrinsic**2
|
||||
|
||||
|
||||
def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std):
|
||||
def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std, alpha_min,
|
||||
alpha_max, sample_alpha, name):
|
||||
"""Sample Tully-Fisher calibration parameters."""
|
||||
e_mu = sample("e_mu", Uniform(e_mu_min, e_mu_max))
|
||||
a = sample("a", Normal(a_mean, a_std))
|
||||
b = sample("b", Normal(b_mean, b_std))
|
||||
|
||||
return e_mu, a, b
|
||||
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
|
||||
a = sample(f"a_{name}", Normal(a_mean, a_std))
|
||||
b = sample(f"b_{name}", Normal(b_mean, b_std))
|
||||
alpha = sample(f"alpha_{name}",
|
||||
Uniform(alpha_min, alpha_max)) if sample_alpha else 1.0
|
||||
|
||||
return {"e_mu": e_mu,
|
||||
"a": a,
|
||||
"b": b,
|
||||
"alpha": alpha
|
||||
}
|
||||
|
||||
###############################################################################
|
||||
# Calibration parameters sampling #
|
||||
###############################################################################
|
||||
|
||||
|
||||
def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max,
|
||||
alpha_min, alpha_max, beta_min, beta_max, sigma_v_min,
|
||||
sigma_v_max, sample_Vmono, sample_alpha, sample_beta,
|
||||
sample_sigma_v_ext):
|
||||
def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min,
|
||||
beta_max, sigma_v_min, sigma_v_max, sample_Vmono,
|
||||
sample_beta, sample_sigma_v_ext):
|
||||
"""Sample the flow calibration."""
|
||||
Vext = sample("Vext", Uniform(Vext_min, Vext_max).expand([3]))
|
||||
sigma_v = sample("sigma_v", Uniform(sigma_v_min, sigma_v_max))
|
||||
|
||||
alpha = sample("alpha", Uniform(alpha_min, alpha_max)) if sample_alpha else 1.0 # noqa
|
||||
beta = sample("beta", Uniform(beta_min, beta_max)) if sample_beta else 1.0 # noqa
|
||||
Vmono = sample("Vmono", Uniform(Vmono_min, Vmono_max)) if sample_Vmono else 0.0 # noqa
|
||||
sigma_v_ext = sample("sigma_v_ext", Uniform(sigma_v_min, sigma_v_max)) if sample_sigma_v_ext else sigma_v # noqa
|
||||
|
||||
return Vext, Vmono, sigma_v, sigma_v_ext, alpha, beta
|
||||
return {"Vext": Vext,
|
||||
"Vmono": Vmono,
|
||||
"sigma_v": sigma_v,
|
||||
"sigma_v_ext": sigma_v_ext,
|
||||
"beta": beta}
|
||||
|
||||
|
||||
###############################################################################
|
||||
|
@ -635,9 +650,9 @@ def find_extrap_mask(rmax, rdist):
|
|||
return extrap_mask, extrap_weights
|
||||
|
||||
|
||||
class PV_validation_model(BaseFlowValidationModel):
|
||||
class PV_LogLikelihood(BaseFlowValidationModel):
|
||||
"""
|
||||
Peculiar velocity validation model.
|
||||
Peculiar velocity validation model log-likelihood.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -663,7 +678,7 @@ class PV_validation_model(BaseFlowValidationModel):
|
|||
"""
|
||||
|
||||
def __init__(self, los_density, los_velocity, rmax, RA, dec, z_obs,
|
||||
e_zobs, calibration_params, r_xrange, Omega_m, kind):
|
||||
e_zobs, calibration_params, r_xrange, Omega_m, kind, name):
|
||||
if e_zobs is not None:
|
||||
e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2)
|
||||
else:
|
||||
|
@ -681,6 +696,7 @@ class PV_validation_model(BaseFlowValidationModel):
|
|||
self._set_radial_spacing(r_xrange, Omega_m)
|
||||
|
||||
self.kind = kind
|
||||
self.name = name
|
||||
self.Omega_m = Omega_m
|
||||
self.norm = - self.ndata * jnp.log(self.num_sims)
|
||||
|
||||
|
@ -688,29 +704,37 @@ class PV_validation_model(BaseFlowValidationModel):
|
|||
self.extrap_mask = jnp.asarray(extrap_mask)
|
||||
self.extrap_weights = jnp.asarray(extrap_weights)
|
||||
|
||||
def __call__(self, calibration_hyperparams, distmod_hyperparams,
|
||||
store_ll_all=False):
|
||||
"""NumPyro PV validation model."""
|
||||
Vext, Vmono, sigma_v, sigma_v_ext, alpha, beta = sample_calibration(**calibration_hyperparams) # noqa
|
||||
def __call__(self, field_calibration_params, distmod_params,
|
||||
sample_sigma_v_ext):
|
||||
"""PV validation model log-likelihood."""
|
||||
# Turn e2_cz to be of shape (nsims, ndata, nxrange) and apply
|
||||
# sigma_v_ext where applicable
|
||||
sigma_v = field_calibration_params["sigma_v"]
|
||||
sigma_v_ext = field_calibration_params["sigma_v_ext"]
|
||||
e2_cz = jnp.full_like(self.extrap_mask, sigma_v**2, dtype=jnp.float32)
|
||||
if calibration_hyperparams["sample_sigma_v_ext"]:
|
||||
if sample_sigma_v_ext:
|
||||
e2_cz = e2_cz.at[self.extrap_mask].set(sigma_v_ext**2)
|
||||
|
||||
# Now add the observational errors
|
||||
e2_cz += self.e2_cz_obs[None, :, None]
|
||||
|
||||
Vext = field_calibration_params["Vext"]
|
||||
Vmono = field_calibration_params["Vmono"]
|
||||
Vext_rad = project_Vext(Vext[0], Vext[1], Vext[2], self.RA, self.dec)
|
||||
|
||||
e_mu = distmod_params["e_mu"]
|
||||
alpha = distmod_params["alpha"]
|
||||
if self.kind == "SN":
|
||||
e_mu, mag_cal, alpha_cal, beta_cal = sample_SN(**distmod_hyperparams) # noqa
|
||||
mag_cal = distmod_params["mag_cal"]
|
||||
alpha_cal = distmod_params["alpha_cal"]
|
||||
beta_cal = distmod_params["beta_cal"]
|
||||
mu = distmod_SN(
|
||||
self.mB, self.x1, self.c, mag_cal, alpha_cal, beta_cal)
|
||||
squared_e_mu = e2_distmod_SN(
|
||||
self.e2_mB, self.e2_x1, self.e2_c, alpha_cal, beta_cal, e_mu)
|
||||
elif self.kind == "TFR":
|
||||
e_mu, a, b = sample_TFR(**distmod_hyperparams)
|
||||
a = distmod_params["a"]
|
||||
b = distmod_params["b"]
|
||||
mu = distmod_TFR(self.mag, self.eta, a, b)
|
||||
squared_e_mu = e2_distmod_TFR(self.e2_mag, self.e2_eta, b, e_mu)
|
||||
else:
|
||||
|
@ -725,20 +749,51 @@ class PV_validation_model(BaseFlowValidationModel):
|
|||
|
||||
# Calculate z_obs at each distance. Shape is (n_sims, ndata, nxrange)
|
||||
# The weights are related to the extrapolation of the velocity field.
|
||||
vrad = beta * self.los_velocity
|
||||
vrad = field_calibration_params["beta"] * self.los_velocity
|
||||
vrad += (Vext_rad[None, :, None] + Vmono) * self.extrap_weights
|
||||
zobs = (1 + self.z_xrange[None, None, :]) * (1 + vrad / SPEED_OF_LIGHT) - 1 # noqa
|
||||
|
||||
ptilde *= calculate_likelihood_zobs(self.z_obs, zobs, e2_cz)
|
||||
# ptilde *= calculate_likelihood_zobs(self.z_obs, zobs, sigma_v)
|
||||
ll = jnp.log(simpson(ptilde, dx=self.dr, axis=-1)) - jnp.log(pnorm)
|
||||
|
||||
if store_ll_all:
|
||||
numpyro.deterministic("ll_all", ll)
|
||||
return jnp.sum(logsumexp(ll, axis=0)) + self.norm
|
||||
|
||||
ll = jnp.sum(logsumexp(ll, axis=0)) + self.norm
|
||||
numpyro.deterministic("ll_values", ll)
|
||||
numpyro.factor("ll", ll)
|
||||
|
||||
def PV_validation_model(models, distmod_hyperparams_per_model,
|
||||
field_calibration_hyperparams):
|
||||
"""
|
||||
Peculiar velocity validation NumPyro model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
models : list of `PV_LogLikelihood`
|
||||
List of PV validation log-likelihoods for each catalogue.
|
||||
distmod_hyperparams_per_model : list of dict
|
||||
Distance modulus hyperparameters for each model/catalogue.
|
||||
field_calibration_hyperparams : dict
|
||||
Field calibration hyperparameters.
|
||||
"""
|
||||
field_calibration_params = sample_calibration(
|
||||
**field_calibration_hyperparams)
|
||||
sample_sigma_v_ext = field_calibration_hyperparams["sample_sigma_v_ext"]
|
||||
|
||||
ll = 0.0
|
||||
for n in range(len(models)):
|
||||
model = models[n]
|
||||
distmod_hyperparams = distmod_hyperparams_per_model[n]
|
||||
|
||||
if model.kind == "TFR":
|
||||
distmod_params = sample_TFR(**distmod_hyperparams, name=model.name)
|
||||
elif model.kind == "SN":
|
||||
distmod_params = sample_SN(**distmod_hyperparams, name=model.name)
|
||||
else:
|
||||
raise ValueError(f"Unknown kind: `{model.kind}`.")
|
||||
|
||||
ll += model(
|
||||
field_calibration_params, distmod_params, sample_sigma_v_ext)
|
||||
|
||||
deterministic("ll_values", ll)
|
||||
factor("ll", ll)
|
||||
|
||||
|
||||
###############################################################################
|
||||
|
@ -780,10 +835,10 @@ def get_model(loader, zcmb_max=None, verbose=True):
|
|||
"e_mB": e_mB[mask], "e_x1": e_x1[mask],
|
||||
"e_c": e_c[mask]}
|
||||
|
||||
model = PV_validation_model(
|
||||
model = PV_LogLikelihood(
|
||||
los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
|
||||
RA[mask], dec[mask], zCMB[mask], e_zCMB, calibration_params,
|
||||
loader.rdist, loader._Omega_m, "SN")
|
||||
loader.rdist, loader._Omega_m, "SN", name=kind)
|
||||
elif "Pantheon+" in kind:
|
||||
keys = ["RA", "DEC", "zCMB", "mB", "x1", "c", "biasCor_m_b", "mBERR",
|
||||
"x1ERR", "cERR", "biasCorErr_m_b", "zCMB_SN", "zCMB_Group",
|
||||
|
@ -808,10 +863,10 @@ def get_model(loader, zcmb_max=None, verbose=True):
|
|||
calibration_params = {"mB": mB[mask], "x1": x1[mask], "c": c[mask],
|
||||
"e_mB": e_mB[mask], "e_x1": e_x1[mask],
|
||||
"e_c": e_c[mask]}
|
||||
model = PV_validation_model(
|
||||
model = PV_LogLikelihood(
|
||||
los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
|
||||
RA[mask], dec[mask], zCMB[mask], e_zCMB[mask], calibration_params,
|
||||
loader.rdist, loader._Omega_m, "SN")
|
||||
loader.rdist, loader._Omega_m, "SN", name=kind)
|
||||
elif kind in ["SFI_gals", "2MTF", "SFI_gals_masked"]:
|
||||
keys = ["RA", "DEC", "z_CMB", "mag", "eta", "e_mag", "e_eta"]
|
||||
RA, dec, zCMB, mag, eta, e_mag, e_eta = (loader.cat[k] for k in keys)
|
||||
|
@ -824,14 +879,14 @@ def get_model(loader, zcmb_max=None, verbose=True):
|
|||
|
||||
calibration_params = {"mag": mag[mask], "eta": eta[mask],
|
||||
"e_mag": e_mag[mask], "e_eta": e_eta[mask]}
|
||||
model = PV_validation_model(
|
||||
model = PV_LogLikelihood(
|
||||
los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
|
||||
RA[mask], dec[mask], zCMB[mask], None, calibration_params,
|
||||
loader.rdist, loader._Omega_m, "TFR")
|
||||
loader.rdist, loader._Omega_m, "TFR", name=kind)
|
||||
else:
|
||||
raise ValueError(f"Catalogue `{kind}` not recognized.")
|
||||
|
||||
fprint(f"selected {np.sum(mask)}/{len(mask)} galaxies.")
|
||||
fprint(f"selected {np.sum(mask)}/{len(mask)} galaxies in catalogue `{kind}`") # noqa
|
||||
|
||||
return model
|
||||
|
||||
|
@ -848,9 +903,12 @@ def _posterior_element(r, beta, Vext_radial, los_velocity, Omega_m, zobs,
|
|||
`Observed2CosmologicalRedshift`.
|
||||
"""
|
||||
zobs_pred = predict_zobs(r, beta, Vext_radial, los_velocity, Omega_m)
|
||||
|
||||
# Likelihood term
|
||||
dcz = SPEED_OF_LIGHT * (zobs - zobs_pred)
|
||||
posterior = jnp.exp(-0.5 * dcz**2 / sigma_v**2) / jnp.sqrt(2 * jnp.pi * sigma_v**2) # noqa
|
||||
posterior = jnp.exp(-0.5 * dcz**2 / sigma_v**2)
|
||||
posterior /= jnp.sqrt(2 * jnp.pi * sigma_v**2)
|
||||
|
||||
# Prior term
|
||||
posterior *= dVdOmega * los_density**alpha
|
||||
|
||||
|
@ -864,7 +922,8 @@ class BaseObserved2CosmologicalRedshift(ABC):
|
|||
for i, key in enumerate(calibration_samples.keys()):
|
||||
x = calibration_samples[key]
|
||||
if not isinstance(x, (np.ndarray, jnp.ndarray)):
|
||||
raise ValueError(f"Calibration sample `{key}` must be an array.") # noqa
|
||||
raise ValueError(
|
||||
f"Calibration sample `{key}` must be an array.")
|
||||
|
||||
if x.ndim != 1 and key != "Vext":
|
||||
raise ValueError(f"Calibration samples `{key}` must be 1D.")
|
||||
|
@ -873,14 +932,19 @@ class BaseObserved2CosmologicalRedshift(ABC):
|
|||
ncalibratrion = len(x)
|
||||
|
||||
if len(x) != ncalibratrion:
|
||||
raise ValueError("Calibration samples do not have the same length.") # noqa
|
||||
raise ValueError(
|
||||
"Calibration samples do not have the same length.")
|
||||
|
||||
calibration_samples[key] = jnp.asarray(x)
|
||||
|
||||
if "alpha" not in calibration_samples:
|
||||
print("No `alpha` calibration sample found. Setting it to 1.",
|
||||
flush=True)
|
||||
calibration_samples["alpha"] = jnp.ones(ncalibratrion)
|
||||
|
||||
if "beta" not in calibration_samples:
|
||||
print("No `beta` calibration sample found. Setting it to 1.",
|
||||
flush=True)
|
||||
calibration_samples["beta"] = jnp.ones(ncalibratrion)
|
||||
|
||||
# Get the stepsize, we need it to be constant for Simpson's rule.
|
||||
|
@ -898,7 +962,8 @@ class BaseObserved2CosmologicalRedshift(ABC):
|
|||
def get_calibration_samples(self, key):
|
||||
"""Get calibration samples for a given key."""
|
||||
if key not in self._calibration_samples:
|
||||
raise ValueError(f"Key `{key}` not found in calibration samples. Available keys are: `{self.calibration_keys}`.") # noqa
|
||||
raise ValueError(f"Key `{key}` not found in calibration samples. "
|
||||
f"Available keys are: `{self.calibration_keys}`.")
|
||||
|
||||
return self._calibration_samples[key]
|
||||
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -33,7 +33,7 @@ def parse_args():
|
|||
parser.add_argument("--simname", type=str, required=True,
|
||||
help="Simulation name.")
|
||||
parser.add_argument("--catalogue", type=str, required=True,
|
||||
help="PV catalogue.")
|
||||
help="PV catalogues.")
|
||||
parser.add_argument("--ksmooth", type=int, default=1,
|
||||
help="Smoothing index.")
|
||||
parser.add_argument("--ksim", type=none_or_int, default=None,
|
||||
|
@ -42,7 +42,12 @@ def parse_args():
|
|||
help="Number of devices to request.")
|
||||
parser.add_argument("--device", type=str, default="cpu",
|
||||
help="Device to use.")
|
||||
return parser.parse_args()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Convert the catalogue to a list of catalogues
|
||||
args.catalogue = args.catalogue.split(",")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
ARGS = parse_args()
|
||||
|
@ -58,8 +63,7 @@ from os.path import join
|
|||
import csiborgtools # noqa
|
||||
import jax # noqa
|
||||
from h5py import File # noqa
|
||||
from mpi4py import MPI # noqa
|
||||
from numpyro.infer import MCMC, NUTS, Predictive, init_to_median # noqa
|
||||
from numpyro.infer import MCMC, NUTS, init_to_median # noqa
|
||||
|
||||
|
||||
def print_variables(names, variables):
|
||||
|
@ -68,8 +72,8 @@ def print_variables(names, variables):
|
|||
print(flush=True)
|
||||
|
||||
|
||||
def get_model(paths, get_model_kwargs, verbose=True):
|
||||
"""Load the data and create the NumPyro model."""
|
||||
def get_models(get_model_kwargs, verbose=True):
|
||||
"""Load the data and create the NumPyro models."""
|
||||
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
||||
folder = "/mnt/extraspace/rstiskalek/catalogs/"
|
||||
|
||||
|
@ -86,55 +90,45 @@ def get_model(paths, get_model_kwargs, verbose=True):
|
|||
print(f"{'Num. realisations:':<20} {len(nsims)}")
|
||||
print(flush=True)
|
||||
|
||||
if ARGS.catalogue == "A2":
|
||||
fpath = join(folder, "A2.h5")
|
||||
elif ARGS.catalogue in ["LOSS", "Foundation", "Pantheon+", "SFI_gals",
|
||||
"2MTF", "SFI_groups", "SFI_gals_masked",
|
||||
"Pantheon+_groups", "Pantheon+_groups_zSN",
|
||||
"Pantheon+_zSN"]:
|
||||
fpath = join(folder, "PV_compilation.hdf5")
|
||||
else:
|
||||
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
|
||||
# Get models
|
||||
models = [None] * len(ARGS.catalogue)
|
||||
for i, cat in enumerate(ARGS.catalogue):
|
||||
if cat == "A2":
|
||||
fpath = join(folder, "A2.h5")
|
||||
elif cat in ["LOSS", "Foundation", "Pantheon+", "SFI_gals",
|
||||
"2MTF", "SFI_groups", "SFI_gals_masked",
|
||||
"Pantheon+_groups", "Pantheon+_groups_zSN",
|
||||
"Pantheon+_zSN"]:
|
||||
fpath = join(folder, "PV_compilation.hdf5")
|
||||
else:
|
||||
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
|
||||
|
||||
loader = csiborgtools.flow.DataLoader(ARGS.simname, nsim_iterator,
|
||||
ARGS.catalogue, fpath, paths,
|
||||
ksmooth=ARGS.ksmooth)
|
||||
loader = csiborgtools.flow.DataLoader(ARGS.simname, nsim_iterator,
|
||||
cat, fpath, paths,
|
||||
ksmooth=ARGS.ksmooth)
|
||||
models[i] = csiborgtools.flow.get_model(loader, **get_model_kwargs)
|
||||
|
||||
print(f"\n{'Num. radial steps':<20} {len(loader.rdist)}\n", flush=True)
|
||||
|
||||
return csiborgtools.flow.get_model(loader, **get_model_kwargs)
|
||||
return models
|
||||
|
||||
|
||||
def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num):
|
||||
"""Compute evidence using the `harmonic` package."""
|
||||
data, names = csiborgtools.dict_samples_to_array(samples)
|
||||
data = data.reshape(nchains_harmonic, -1, len(names))
|
||||
log_posterior = log_posterior.reshape(10, -1)
|
||||
log_posterior = log_posterior.reshape(nchains_harmonic, -1)
|
||||
|
||||
return csiborgtools.harmonic_evidence(
|
||||
data, log_posterior, return_flow_samples=False, epochs_num=epoch_num)
|
||||
|
||||
|
||||
def get_simulation_weights(samples, model, model_kwargs):
|
||||
"""Get the weights per posterior samples for each simulation."""
|
||||
predictive = Predictive(model, samples)
|
||||
ll_all = predictive(
|
||||
jax.random.PRNGKey(1), store_ll_all=True, **model_kwargs)["ll_all"]
|
||||
|
||||
# Multiply the likelihood of galaxies
|
||||
ll_per_simulation = jax.numpy.sum(ll_all, axis=-1)
|
||||
# Normalization by summing the likelihood over simulations
|
||||
norm = jax.scipy.special.logsumexp(ll_per_simulation, axis=-1)
|
||||
return ll_per_simulation - norm[:, None]
|
||||
|
||||
|
||||
def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
||||
calculate_evidence, nchains_harmonic, epoch_num, kwargs_print):
|
||||
"""Run the NumPyro model and save output to a file."""
|
||||
try:
|
||||
ndata = model.ndata
|
||||
ndata = sum(model.ndata for model in model_kwargs["models"])
|
||||
except AttributeError as e:
|
||||
raise AttributeError("The model must have an attribute `ndata` "
|
||||
raise AttributeError("The models must have an attribute `ndata` "
|
||||
"indicating the number of data points.") from e
|
||||
|
||||
nuts_kernel = NUTS(model, init_strategy=init_to_median(num_samples=1000))
|
||||
|
@ -143,7 +137,6 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
|||
|
||||
mcmc.run(rng_key, extra_fields=("potential_energy",), **model_kwargs)
|
||||
samples = mcmc.get_samples()
|
||||
simulation_weights = get_simulation_weights(samples, model, model_kwargs)
|
||||
|
||||
log_posterior = -mcmc.get_extra_fields()["potential_energy"]
|
||||
log_likelihood = samples.pop("ll_values")
|
||||
|
@ -165,7 +158,7 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
|||
neg_ln_evidence = jax.numpy.nan
|
||||
neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)
|
||||
|
||||
fname = f"samples_{ARGS.simname}_{ARGS.catalogue}_ksmooth{ARGS.ksmooth}.hdf5" # noqa
|
||||
fname = f"samples_{ARGS.simname}_{'+'.join(ARGS.catalogue)}_ksmooth{ARGS.ksmooth}.hdf5" # noqa
|
||||
if ARGS.ksim is not None:
|
||||
fname = fname.replace(".hdf5", f"_nsim{ARGS.ksim}.hdf5")
|
||||
|
||||
|
@ -183,7 +176,6 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
|||
# Write log likelihood and posterior
|
||||
f.create_dataset("log_likelihood", data=log_likelihood)
|
||||
f.create_dataset("log_posterior", data=log_posterior)
|
||||
f.create_dataset("simulation_weights", data=simulation_weights)
|
||||
|
||||
# Write goodness of fit
|
||||
grp = f.create_group("gof")
|
||||
|
@ -215,6 +207,29 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
|||
# Command line interface #
|
||||
###############################################################################
|
||||
|
||||
def get_distmod_hyperparams(catalogue):
|
||||
alpha_min = -1.0
|
||||
alpha_max = 3.0
|
||||
sample_alpha = True
|
||||
|
||||
if catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]: # noqa
|
||||
return {"e_mu_min": 0.001, "e_mu_max": 1.0,
|
||||
"mag_cal_mean": -18.25, "mag_cal_std": 2.0,
|
||||
"alpha_cal_mean": 0.148, "alpha_cal_std": 1.0,
|
||||
"beta_cal_mean": 3.112, "beta_cal_std": 2.0,
|
||||
"alpha_min": alpha_min, "alpha_max": alpha_max,
|
||||
"sample_alpha": sample_alpha
|
||||
}
|
||||
elif catalogue in ["SFI_gals", "2MTF"]:
|
||||
return {"e_mu_min": 0.001, "e_mu_max": 1.0,
|
||||
"a_mean": -21., "a_std": 5.0,
|
||||
"b_mean": -5.95, "b_std": 3.0,
|
||||
"alpha_min": alpha_min, "alpha_max": alpha_max,
|
||||
"sample_alpha": sample_alpha
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
||||
|
@ -227,14 +242,15 @@ if __name__ == "__main__":
|
|||
###########################################################################
|
||||
|
||||
nsteps = 5000
|
||||
nburn = 1000
|
||||
zcmb_max = 0.06
|
||||
nburn = 1500
|
||||
zcmb_max = 0.05
|
||||
calculate_evidence = False
|
||||
nchains_harmonic = 10
|
||||
num_epochs = 30
|
||||
|
||||
if nsteps % nchains_harmonic != 0:
|
||||
raise ValueError("The number of steps must be divisible by the number of chains.") # noqa
|
||||
raise ValueError(
|
||||
"The number of steps must be divisible by the number of chains.")
|
||||
|
||||
main_params = {"nsteps": nsteps, "nburn": nburn, "zcmb_max": zcmb_max,
|
||||
"calculate_evidence": calculate_evidence,
|
||||
|
@ -244,42 +260,36 @@ if __name__ == "__main__":
|
|||
|
||||
calibration_hyperparams = {"Vext_min": -1000, "Vext_max": 1000,
|
||||
"Vmono_min": -1000, "Vmono_max": 1000,
|
||||
"alpha_min": -1.0, "alpha_max": 3.0,
|
||||
"beta_min": -1.0, "beta_max": 3.0,
|
||||
"sigma_v_min": 1.0, "sigma_v_max": 750.,
|
||||
"sample_Vmono": False,
|
||||
"sample_alpha": True,
|
||||
"sample_beta": True,
|
||||
"sample_sigma_v_ext": False,
|
||||
}
|
||||
print_variables(
|
||||
calibration_hyperparams.keys(), calibration_hyperparams.values())
|
||||
|
||||
if ARGS.catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]: # noqa
|
||||
distmod_hyperparams = {"e_mu_min": 0.001, "e_mu_max": 1.0,
|
||||
"mag_cal_mean": -18.25, "mag_cal_std": 2.0,
|
||||
"alpha_cal_mean": 0.148, "alpha_cal_std": 1.0,
|
||||
"beta_cal_mean": 3.112, "beta_cal_std": 2.0,
|
||||
}
|
||||
elif ARGS.catalogue in ["SFI_gals", "2MTF"]:
|
||||
distmod_hyperparams = {"e_mu_min": 0.001, "e_mu_max": 1.0,
|
||||
"a_mean": -21., "a_std": 5.0,
|
||||
"b_mean": -5.95, "b_std": 3.0,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
|
||||
distmod_hyperparams_per_catalogue = []
|
||||
for cat in ARGS.catalogue:
|
||||
x = get_distmod_hyperparams(cat)
|
||||
print(f"\n{cat} hyperparameters:")
|
||||
print_variables(x.keys(), x.values())
|
||||
distmod_hyperparams_per_catalogue.append(x)
|
||||
|
||||
print_variables(
|
||||
distmod_hyperparams.keys(), distmod_hyperparams.values())
|
||||
|
||||
kwargs_print = (main_params, calibration_hyperparams, distmod_hyperparams)
|
||||
kwargs_print = (main_params, calibration_hyperparams,
|
||||
*distmod_hyperparams_per_catalogue)
|
||||
###########################################################################
|
||||
|
||||
model_kwargs = {"calibration_hyperparams": calibration_hyperparams,
|
||||
"distmod_hyperparams": distmod_hyperparams}
|
||||
get_model_kwargs = {"zcmb_max": zcmb_max}
|
||||
models = get_models(get_model_kwargs, )
|
||||
model_kwargs = {
|
||||
"models": models,
|
||||
"field_calibration_hyperparams": calibration_hyperparams,
|
||||
"distmod_hyperparams_per_model": distmod_hyperparams_per_catalogue,
|
||||
}
|
||||
|
||||
model = csiborgtools.flow.PV_validation_model
|
||||
|
||||
model = get_model(paths, get_model_kwargs, )
|
||||
run_model(model, nsteps, nburn, model_kwargs, out_folder,
|
||||
calibration_hyperparams["sample_beta"], calculate_evidence,
|
||||
nchains_harmonic, num_epochs, kwargs_print)
|
||||
|
|
|
@ -1,44 +1,66 @@
|
|||
#!/bin/bash
|
||||
memory=8
|
||||
memory=7
|
||||
on_login=${1}
|
||||
queue=${2}
|
||||
ndevice=1
|
||||
|
||||
device="gpu"
|
||||
queue="gpulong"
|
||||
gputype="rtx2080with12gb"
|
||||
env="/mnt/users/rstiskalek/csiborgtools/venv_gpu_csiborgtools/bin/python"
|
||||
file="flow_validation.py"
|
||||
ksmooth=0
|
||||
|
||||
|
||||
if [ "$on_login" != "1" ] && [ "$on_login" != "0" ]; then
|
||||
echo "Invalid input: 'on_login' (1). Please provide 1 or 0."
|
||||
if [ "$on_login" != "1" ] && [ "$on_login" != "0" ]
|
||||
then
|
||||
echo "'on_login' (1) must be either 0 or 1."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
if [ "$queue" != "redwood" ] && [ "$queue" != "berg" ] && [ "$queue" != "cmb" ] && [ "$queue" != "gpulong" ] && [ "$queue" != "cmbgpu" ]; then
|
||||
echo "Invalid queue: $queue (2). Please provide one of 'redwood', 'berg', 'cmb', 'gpulong', 'cmbgpu'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
# Submit a job for each combination of simname, catalogue, ksim
|
||||
if [ "$queue" == "gpulong" ]
|
||||
then
|
||||
device="gpu"
|
||||
gputype="rtx2080with12gb"
|
||||
# gputype="rtx3070with8gb"
|
||||
env="/mnt/users/rstiskalek/csiborgtools/venv_gpu_csiborgtools/bin/python"
|
||||
elif [ "$queue" == "cmbgpu" ]
|
||||
then
|
||||
device="gpu"
|
||||
gputype="rtx3090with24gb"
|
||||
env="/mnt/users/rstiskalek/csiborgtools/venv_gpu_csiborgtools/bin/python"
|
||||
else
|
||||
device="cpu"
|
||||
env="/mnt/users/rstiskalek/csiborgtools/venv_csiborg/bin/python"
|
||||
fi
|
||||
|
||||
|
||||
# for simname in "Lilow2024" "CF4" "CF4gp" "csiborg1" "csiborg2_main" "csiborg2X"; do
|
||||
for simname in "Carrick2015"; do
|
||||
# for simname in "csiborg1" "csiborg2_main" "csiborg2X"; do
|
||||
for catalogue in "Foundation"; do
|
||||
# for catalogue in "2MTF"; do
|
||||
# for ksim in 0 1 2; do
|
||||
for catalogue in "LOSS,2MTF,SFI_gals"; do
|
||||
# for ksim in 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20; do
|
||||
for ksim in "none"; do
|
||||
pythoncm="$env $file --catalogue $catalogue --simname $simname --ksim $ksim --ksmooth $ksmooth --ndevice $ndevice --device $device"
|
||||
|
||||
if [ $on_login -eq 1 ]; then
|
||||
if [ "$on_login" == "1" ]; then
|
||||
echo $pythoncm
|
||||
$pythoncm
|
||||
eval $pythoncm
|
||||
else
|
||||
cm="addqueue -q $queue -s -m $memory --gpus 1 --gputype $gputype $pythoncm"
|
||||
if [ "$device" == "gpu" ]; then
|
||||
cm="addqueue -q $queue -s -m $memory --gpus 1 --gputype $gputype $pythoncm"
|
||||
else
|
||||
cm="addqueue -s -q $queue -n 1 -m $memory $pythoncm"
|
||||
fi
|
||||
echo "Submitting:"
|
||||
echo $cm
|
||||
eval $cm
|
||||
fi
|
||||
|
||||
echo
|
||||
sleep 0.001
|
||||
|
||||
done
|
||||
done
|
||||
done
|
||||
|
|
Loading…
Reference in a new issue