Combine PV samples (#139)

* Update imports

* Update submission script

* Update script

* Add simulataenous sampling of many catalogues

* Update nb
This commit is contained in:
Richard Stiskalek 2024-07-30 17:02:48 +01:00 committed by GitHub
parent 9756175943
commit 3b46f17ead
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 356 additions and 147 deletions

View File

@ -12,8 +12,8 @@
# You should have received a copy of the GNU General Public License along # You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc., # with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from .flow_model import (DataLoader, radial_velocity_los, dist2redshift, # noqa from .flow_model import (DataLoader, PV_LogLikelihood, PV_validation_model, # noqa
dist2distmodulus, predict_zobs, project_Vext, # noqa dist2distmodulus, dist2redshift, distmodulus2dist, # noqa
PV_validation_model, get_model, distmodulus2dist, # noqa get_model, Observed2CosmologicalRedshift, # noqa
Observed2CosmologicalRedshift, # noqa predict_zobs, project_Vext, # noqa
stack_pzosmo_over_realizations) # noqa radial_velocity_los, stack_pzosmo_over_realizations) # noqa

View File

@ -24,15 +24,14 @@ References
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import numpy as np import numpy as np
import numpyro
from astropy.cosmology import FlatLambdaCDM, z_at_value
from astropy import units as u from astropy import units as u
from astropy.cosmology import FlatLambdaCDM, z_at_value
from h5py import File from h5py import File
from jax import jit from jax import jit
from jax import numpy as jnp from jax import numpy as jnp
from jax import vmap from jax import vmap
from jax.scipy.special import logsumexp from jax.scipy.special import logsumexp
from numpyro import sample from numpyro import deterministic, factor, sample
from numpyro.distributions import Normal, Uniform from numpyro.distributions import Normal, Uniform
from quadax import simpson from quadax import simpson
from scipy.interpolate import interp1d from scipy.interpolate import interp1d
@ -558,15 +557,23 @@ def e2_distmod_SN(e2_mB, e2_x1, e2_c, alpha_cal, beta_cal, e_mu_intrinsic):
def sample_SN(e_mu_min, e_mu_max, mag_cal_mean, mag_cal_std, alpha_cal_mean, def sample_SN(e_mu_min, e_mu_max, mag_cal_mean, mag_cal_std, alpha_cal_mean,
alpha_cal_std, beta_cal_mean, beta_cal_std): alpha_cal_std, beta_cal_mean, beta_cal_std, alpha_min, alpha_max,
sample_alpha, name):
"""Sample SNIe Tripp parameters.""" """Sample SNIe Tripp parameters."""
e_mu = sample("e_mu", Uniform(e_mu_min, e_mu_max)) e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
mag_cal = sample("mag_cal", Normal(mag_cal_mean, mag_cal_std)) mag_cal = sample(f"mag_cal_{name}", Normal(mag_cal_mean, mag_cal_std))
alpha_cal = sample("alpha_cal", Normal(alpha_cal_mean, alpha_cal_std)) alpha_cal = sample(
f"alpha_cal_{name}", Normal(alpha_cal_mean, alpha_cal_std))
beta_cal = sample(f"beta_cal_{name}", Normal(beta_cal_mean, beta_cal_std))
alpha = sample(f"alpha_{name}",
Uniform(alpha_min, alpha_max)) if sample_alpha else 1.0
beta_cal = sample("beta_cal", Normal(beta_cal_mean, beta_cal_std)) return {"e_mu": e_mu,
"mag_cal": mag_cal,
return e_mu, mag_cal, alpha_cal, beta_cal "alpha_cal": alpha_cal,
"beta_cal": beta_cal,
"alpha": alpha
}
############################################################################### ###############################################################################
@ -583,34 +590,42 @@ def e2_distmod_TFR(e2_mag, e2_eta, b, e_mu_intrinsic):
return e2_mag + b**2 * e2_eta + e_mu_intrinsic**2 return e2_mag + b**2 * e2_eta + e_mu_intrinsic**2
def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std): def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std, alpha_min,
alpha_max, sample_alpha, name):
"""Sample Tully-Fisher calibration parameters.""" """Sample Tully-Fisher calibration parameters."""
e_mu = sample("e_mu", Uniform(e_mu_min, e_mu_max)) e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
a = sample("a", Normal(a_mean, a_std)) a = sample(f"a_{name}", Normal(a_mean, a_std))
b = sample("b", Normal(b_mean, b_std)) b = sample(f"b_{name}", Normal(b_mean, b_std))
alpha = sample(f"alpha_{name}",
return e_mu, a, b Uniform(alpha_min, alpha_max)) if sample_alpha else 1.0
return {"e_mu": e_mu,
"a": a,
"b": b,
"alpha": alpha
}
############################################################################### ###############################################################################
# Calibration parameters sampling # # Calibration parameters sampling #
############################################################################### ###############################################################################
def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min,
alpha_min, alpha_max, beta_min, beta_max, sigma_v_min, beta_max, sigma_v_min, sigma_v_max, sample_Vmono,
sigma_v_max, sample_Vmono, sample_alpha, sample_beta, sample_beta, sample_sigma_v_ext):
sample_sigma_v_ext):
"""Sample the flow calibration.""" """Sample the flow calibration."""
Vext = sample("Vext", Uniform(Vext_min, Vext_max).expand([3])) Vext = sample("Vext", Uniform(Vext_min, Vext_max).expand([3]))
sigma_v = sample("sigma_v", Uniform(sigma_v_min, sigma_v_max)) sigma_v = sample("sigma_v", Uniform(sigma_v_min, sigma_v_max))
alpha = sample("alpha", Uniform(alpha_min, alpha_max)) if sample_alpha else 1.0 # noqa
beta = sample("beta", Uniform(beta_min, beta_max)) if sample_beta else 1.0 # noqa beta = sample("beta", Uniform(beta_min, beta_max)) if sample_beta else 1.0 # noqa
Vmono = sample("Vmono", Uniform(Vmono_min, Vmono_max)) if sample_Vmono else 0.0 # noqa Vmono = sample("Vmono", Uniform(Vmono_min, Vmono_max)) if sample_Vmono else 0.0 # noqa
sigma_v_ext = sample("sigma_v_ext", Uniform(sigma_v_min, sigma_v_max)) if sample_sigma_v_ext else sigma_v # noqa sigma_v_ext = sample("sigma_v_ext", Uniform(sigma_v_min, sigma_v_max)) if sample_sigma_v_ext else sigma_v # noqa
return Vext, Vmono, sigma_v, sigma_v_ext, alpha, beta return {"Vext": Vext,
"Vmono": Vmono,
"sigma_v": sigma_v,
"sigma_v_ext": sigma_v_ext,
"beta": beta}
############################################################################### ###############################################################################
@ -635,9 +650,9 @@ def find_extrap_mask(rmax, rdist):
return extrap_mask, extrap_weights return extrap_mask, extrap_weights
class PV_validation_model(BaseFlowValidationModel): class PV_LogLikelihood(BaseFlowValidationModel):
""" """
Peculiar velocity validation model. Peculiar velocity validation model log-likelihood.
Parameters Parameters
---------- ----------
@ -663,7 +678,7 @@ class PV_validation_model(BaseFlowValidationModel):
""" """
def __init__(self, los_density, los_velocity, rmax, RA, dec, z_obs, def __init__(self, los_density, los_velocity, rmax, RA, dec, z_obs,
e_zobs, calibration_params, r_xrange, Omega_m, kind): e_zobs, calibration_params, r_xrange, Omega_m, kind, name):
if e_zobs is not None: if e_zobs is not None:
e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2) e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2)
else: else:
@ -681,6 +696,7 @@ class PV_validation_model(BaseFlowValidationModel):
self._set_radial_spacing(r_xrange, Omega_m) self._set_radial_spacing(r_xrange, Omega_m)
self.kind = kind self.kind = kind
self.name = name
self.Omega_m = Omega_m self.Omega_m = Omega_m
self.norm = - self.ndata * jnp.log(self.num_sims) self.norm = - self.ndata * jnp.log(self.num_sims)
@ -688,29 +704,37 @@ class PV_validation_model(BaseFlowValidationModel):
self.extrap_mask = jnp.asarray(extrap_mask) self.extrap_mask = jnp.asarray(extrap_mask)
self.extrap_weights = jnp.asarray(extrap_weights) self.extrap_weights = jnp.asarray(extrap_weights)
def __call__(self, calibration_hyperparams, distmod_hyperparams, def __call__(self, field_calibration_params, distmod_params,
store_ll_all=False): sample_sigma_v_ext):
"""NumPyro PV validation model.""" """PV validation model log-likelihood."""
Vext, Vmono, sigma_v, sigma_v_ext, alpha, beta = sample_calibration(**calibration_hyperparams) # noqa
# Turn e2_cz to be of shape (nsims, ndata, nxrange) and apply # Turn e2_cz to be of shape (nsims, ndata, nxrange) and apply
# sigma_v_ext where applicable # sigma_v_ext where applicable
sigma_v = field_calibration_params["sigma_v"]
sigma_v_ext = field_calibration_params["sigma_v_ext"]
e2_cz = jnp.full_like(self.extrap_mask, sigma_v**2, dtype=jnp.float32) e2_cz = jnp.full_like(self.extrap_mask, sigma_v**2, dtype=jnp.float32)
if calibration_hyperparams["sample_sigma_v_ext"]: if sample_sigma_v_ext:
e2_cz = e2_cz.at[self.extrap_mask].set(sigma_v_ext**2) e2_cz = e2_cz.at[self.extrap_mask].set(sigma_v_ext**2)
# Now add the observational errors # Now add the observational errors
e2_cz += self.e2_cz_obs[None, :, None] e2_cz += self.e2_cz_obs[None, :, None]
Vext = field_calibration_params["Vext"]
Vmono = field_calibration_params["Vmono"]
Vext_rad = project_Vext(Vext[0], Vext[1], Vext[2], self.RA, self.dec) Vext_rad = project_Vext(Vext[0], Vext[1], Vext[2], self.RA, self.dec)
e_mu = distmod_params["e_mu"]
alpha = distmod_params["alpha"]
if self.kind == "SN": if self.kind == "SN":
e_mu, mag_cal, alpha_cal, beta_cal = sample_SN(**distmod_hyperparams) # noqa mag_cal = distmod_params["mag_cal"]
alpha_cal = distmod_params["alpha_cal"]
beta_cal = distmod_params["beta_cal"]
mu = distmod_SN( mu = distmod_SN(
self.mB, self.x1, self.c, mag_cal, alpha_cal, beta_cal) self.mB, self.x1, self.c, mag_cal, alpha_cal, beta_cal)
squared_e_mu = e2_distmod_SN( squared_e_mu = e2_distmod_SN(
self.e2_mB, self.e2_x1, self.e2_c, alpha_cal, beta_cal, e_mu) self.e2_mB, self.e2_x1, self.e2_c, alpha_cal, beta_cal, e_mu)
elif self.kind == "TFR": elif self.kind == "TFR":
e_mu, a, b = sample_TFR(**distmod_hyperparams) a = distmod_params["a"]
b = distmod_params["b"]
mu = distmod_TFR(self.mag, self.eta, a, b) mu = distmod_TFR(self.mag, self.eta, a, b)
squared_e_mu = e2_distmod_TFR(self.e2_mag, self.e2_eta, b, e_mu) squared_e_mu = e2_distmod_TFR(self.e2_mag, self.e2_eta, b, e_mu)
else: else:
@ -725,20 +749,51 @@ class PV_validation_model(BaseFlowValidationModel):
# Calculate z_obs at each distance. Shape is (n_sims, ndata, nxrange) # Calculate z_obs at each distance. Shape is (n_sims, ndata, nxrange)
# The weights are related to the extrapolation of the velocity field. # The weights are related to the extrapolation of the velocity field.
vrad = beta * self.los_velocity vrad = field_calibration_params["beta"] * self.los_velocity
vrad += (Vext_rad[None, :, None] + Vmono) * self.extrap_weights vrad += (Vext_rad[None, :, None] + Vmono) * self.extrap_weights
zobs = (1 + self.z_xrange[None, None, :]) * (1 + vrad / SPEED_OF_LIGHT) - 1 # noqa zobs = (1 + self.z_xrange[None, None, :]) * (1 + vrad / SPEED_OF_LIGHT) - 1 # noqa
ptilde *= calculate_likelihood_zobs(self.z_obs, zobs, e2_cz) ptilde *= calculate_likelihood_zobs(self.z_obs, zobs, e2_cz)
# ptilde *= calculate_likelihood_zobs(self.z_obs, zobs, sigma_v)
ll = jnp.log(simpson(ptilde, dx=self.dr, axis=-1)) - jnp.log(pnorm) ll = jnp.log(simpson(ptilde, dx=self.dr, axis=-1)) - jnp.log(pnorm)
if store_ll_all: return jnp.sum(logsumexp(ll, axis=0)) + self.norm
numpyro.deterministic("ll_all", ll)
ll = jnp.sum(logsumexp(ll, axis=0)) + self.norm
numpyro.deterministic("ll_values", ll) def PV_validation_model(models, distmod_hyperparams_per_model,
numpyro.factor("ll", ll) field_calibration_hyperparams):
"""
Peculiar velocity validation NumPyro model.
Parameters
----------
models : list of `PV_LogLikelihood`
List of PV validation log-likelihoods for each catalogue.
distmod_hyperparams_per_model : list of dict
Distance modulus hyperparameters for each model/catalogue.
field_calibration_hyperparams : dict
Field calibration hyperparameters.
"""
field_calibration_params = sample_calibration(
**field_calibration_hyperparams)
sample_sigma_v_ext = field_calibration_hyperparams["sample_sigma_v_ext"]
ll = 0.0
for n in range(len(models)):
model = models[n]
distmod_hyperparams = distmod_hyperparams_per_model[n]
if model.kind == "TFR":
distmod_params = sample_TFR(**distmod_hyperparams, name=model.name)
elif model.kind == "SN":
distmod_params = sample_SN(**distmod_hyperparams, name=model.name)
else:
raise ValueError(f"Unknown kind: `{model.kind}`.")
ll += model(
field_calibration_params, distmod_params, sample_sigma_v_ext)
deterministic("ll_values", ll)
factor("ll", ll)
############################################################################### ###############################################################################
@ -780,10 +835,10 @@ def get_model(loader, zcmb_max=None, verbose=True):
"e_mB": e_mB[mask], "e_x1": e_x1[mask], "e_mB": e_mB[mask], "e_x1": e_x1[mask],
"e_c": e_c[mask]} "e_c": e_c[mask]}
model = PV_validation_model( model = PV_LogLikelihood(
los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask], los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
RA[mask], dec[mask], zCMB[mask], e_zCMB, calibration_params, RA[mask], dec[mask], zCMB[mask], e_zCMB, calibration_params,
loader.rdist, loader._Omega_m, "SN") loader.rdist, loader._Omega_m, "SN", name=kind)
elif "Pantheon+" in kind: elif "Pantheon+" in kind:
keys = ["RA", "DEC", "zCMB", "mB", "x1", "c", "biasCor_m_b", "mBERR", keys = ["RA", "DEC", "zCMB", "mB", "x1", "c", "biasCor_m_b", "mBERR",
"x1ERR", "cERR", "biasCorErr_m_b", "zCMB_SN", "zCMB_Group", "x1ERR", "cERR", "biasCorErr_m_b", "zCMB_SN", "zCMB_Group",
@ -808,10 +863,10 @@ def get_model(loader, zcmb_max=None, verbose=True):
calibration_params = {"mB": mB[mask], "x1": x1[mask], "c": c[mask], calibration_params = {"mB": mB[mask], "x1": x1[mask], "c": c[mask],
"e_mB": e_mB[mask], "e_x1": e_x1[mask], "e_mB": e_mB[mask], "e_x1": e_x1[mask],
"e_c": e_c[mask]} "e_c": e_c[mask]}
model = PV_validation_model( model = PV_LogLikelihood(
los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask], los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
RA[mask], dec[mask], zCMB[mask], e_zCMB[mask], calibration_params, RA[mask], dec[mask], zCMB[mask], e_zCMB[mask], calibration_params,
loader.rdist, loader._Omega_m, "SN") loader.rdist, loader._Omega_m, "SN", name=kind)
elif kind in ["SFI_gals", "2MTF", "SFI_gals_masked"]: elif kind in ["SFI_gals", "2MTF", "SFI_gals_masked"]:
keys = ["RA", "DEC", "z_CMB", "mag", "eta", "e_mag", "e_eta"] keys = ["RA", "DEC", "z_CMB", "mag", "eta", "e_mag", "e_eta"]
RA, dec, zCMB, mag, eta, e_mag, e_eta = (loader.cat[k] for k in keys) RA, dec, zCMB, mag, eta, e_mag, e_eta = (loader.cat[k] for k in keys)
@ -824,14 +879,14 @@ def get_model(loader, zcmb_max=None, verbose=True):
calibration_params = {"mag": mag[mask], "eta": eta[mask], calibration_params = {"mag": mag[mask], "eta": eta[mask],
"e_mag": e_mag[mask], "e_eta": e_eta[mask]} "e_mag": e_mag[mask], "e_eta": e_eta[mask]}
model = PV_validation_model( model = PV_LogLikelihood(
los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask], los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
RA[mask], dec[mask], zCMB[mask], None, calibration_params, RA[mask], dec[mask], zCMB[mask], None, calibration_params,
loader.rdist, loader._Omega_m, "TFR") loader.rdist, loader._Omega_m, "TFR", name=kind)
else: else:
raise ValueError(f"Catalogue `{kind}` not recognized.") raise ValueError(f"Catalogue `{kind}` not recognized.")
fprint(f"selected {np.sum(mask)}/{len(mask)} galaxies.") fprint(f"selected {np.sum(mask)}/{len(mask)} galaxies in catalogue `{kind}`") # noqa
return model return model
@ -848,9 +903,12 @@ def _posterior_element(r, beta, Vext_radial, los_velocity, Omega_m, zobs,
`Observed2CosmologicalRedshift`. `Observed2CosmologicalRedshift`.
""" """
zobs_pred = predict_zobs(r, beta, Vext_radial, los_velocity, Omega_m) zobs_pred = predict_zobs(r, beta, Vext_radial, los_velocity, Omega_m)
# Likelihood term # Likelihood term
dcz = SPEED_OF_LIGHT * (zobs - zobs_pred) dcz = SPEED_OF_LIGHT * (zobs - zobs_pred)
posterior = jnp.exp(-0.5 * dcz**2 / sigma_v**2) / jnp.sqrt(2 * jnp.pi * sigma_v**2) # noqa posterior = jnp.exp(-0.5 * dcz**2 / sigma_v**2)
posterior /= jnp.sqrt(2 * jnp.pi * sigma_v**2)
# Prior term # Prior term
posterior *= dVdOmega * los_density**alpha posterior *= dVdOmega * los_density**alpha
@ -864,7 +922,8 @@ class BaseObserved2CosmologicalRedshift(ABC):
for i, key in enumerate(calibration_samples.keys()): for i, key in enumerate(calibration_samples.keys()):
x = calibration_samples[key] x = calibration_samples[key]
if not isinstance(x, (np.ndarray, jnp.ndarray)): if not isinstance(x, (np.ndarray, jnp.ndarray)):
raise ValueError(f"Calibration sample `{key}` must be an array.") # noqa raise ValueError(
f"Calibration sample `{key}` must be an array.")
if x.ndim != 1 and key != "Vext": if x.ndim != 1 and key != "Vext":
raise ValueError(f"Calibration samples `{key}` must be 1D.") raise ValueError(f"Calibration samples `{key}` must be 1D.")
@ -873,14 +932,19 @@ class BaseObserved2CosmologicalRedshift(ABC):
ncalibratrion = len(x) ncalibratrion = len(x)
if len(x) != ncalibratrion: if len(x) != ncalibratrion:
raise ValueError("Calibration samples do not have the same length.") # noqa raise ValueError(
"Calibration samples do not have the same length.")
calibration_samples[key] = jnp.asarray(x) calibration_samples[key] = jnp.asarray(x)
if "alpha" not in calibration_samples: if "alpha" not in calibration_samples:
print("No `alpha` calibration sample found. Setting it to 1.",
flush=True)
calibration_samples["alpha"] = jnp.ones(ncalibratrion) calibration_samples["alpha"] = jnp.ones(ncalibratrion)
if "beta" not in calibration_samples: if "beta" not in calibration_samples:
print("No `beta` calibration sample found. Setting it to 1.",
flush=True)
calibration_samples["beta"] = jnp.ones(ncalibratrion) calibration_samples["beta"] = jnp.ones(ncalibratrion)
# Get the stepsize, we need it to be constant for Simpson's rule. # Get the stepsize, we need it to be constant for Simpson's rule.
@ -898,7 +962,8 @@ class BaseObserved2CosmologicalRedshift(ABC):
def get_calibration_samples(self, key): def get_calibration_samples(self, key):
"""Get calibration samples for a given key.""" """Get calibration samples for a given key."""
if key not in self._calibration_samples: if key not in self._calibration_samples:
raise ValueError(f"Key `{key}` not found in calibration samples. Available keys are: `{self.calibration_keys}`.") # noqa raise ValueError(f"Key `{key}` not found in calibration samples. "
f"Available keys are: `{self.calibration_keys}`.")
return self._calibration_samples[key] return self._calibration_samples[key]

File diff suppressed because one or more lines are too long

View File

@ -33,7 +33,7 @@ def parse_args():
parser.add_argument("--simname", type=str, required=True, parser.add_argument("--simname", type=str, required=True,
help="Simulation name.") help="Simulation name.")
parser.add_argument("--catalogue", type=str, required=True, parser.add_argument("--catalogue", type=str, required=True,
help="PV catalogue.") help="PV catalogues.")
parser.add_argument("--ksmooth", type=int, default=1, parser.add_argument("--ksmooth", type=int, default=1,
help="Smoothing index.") help="Smoothing index.")
parser.add_argument("--ksim", type=none_or_int, default=None, parser.add_argument("--ksim", type=none_or_int, default=None,
@ -42,7 +42,12 @@ def parse_args():
help="Number of devices to request.") help="Number of devices to request.")
parser.add_argument("--device", type=str, default="cpu", parser.add_argument("--device", type=str, default="cpu",
help="Device to use.") help="Device to use.")
return parser.parse_args() args = parser.parse_args()
# Convert the catalogue to a list of catalogues
args.catalogue = args.catalogue.split(",")
return args
ARGS = parse_args() ARGS = parse_args()
@ -58,8 +63,7 @@ from os.path import join
import csiborgtools # noqa import csiborgtools # noqa
import jax # noqa import jax # noqa
from h5py import File # noqa from h5py import File # noqa
from mpi4py import MPI # noqa from numpyro.infer import MCMC, NUTS, init_to_median # noqa
from numpyro.infer import MCMC, NUTS, Predictive, init_to_median # noqa
def print_variables(names, variables): def print_variables(names, variables):
@ -68,8 +72,8 @@ def print_variables(names, variables):
print(flush=True) print(flush=True)
def get_model(paths, get_model_kwargs, verbose=True): def get_models(get_model_kwargs, verbose=True):
"""Load the data and create the NumPyro model.""" """Load the data and create the NumPyro models."""
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring) paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
folder = "/mnt/extraspace/rstiskalek/catalogs/" folder = "/mnt/extraspace/rstiskalek/catalogs/"
@ -86,55 +90,45 @@ def get_model(paths, get_model_kwargs, verbose=True):
print(f"{'Num. realisations:':<20} {len(nsims)}") print(f"{'Num. realisations:':<20} {len(nsims)}")
print(flush=True) print(flush=True)
if ARGS.catalogue == "A2": # Get models
fpath = join(folder, "A2.h5") models = [None] * len(ARGS.catalogue)
elif ARGS.catalogue in ["LOSS", "Foundation", "Pantheon+", "SFI_gals", for i, cat in enumerate(ARGS.catalogue):
"2MTF", "SFI_groups", "SFI_gals_masked", if cat == "A2":
"Pantheon+_groups", "Pantheon+_groups_zSN", fpath = join(folder, "A2.h5")
"Pantheon+_zSN"]: elif cat in ["LOSS", "Foundation", "Pantheon+", "SFI_gals",
fpath = join(folder, "PV_compilation.hdf5") "2MTF", "SFI_groups", "SFI_gals_masked",
else: "Pantheon+_groups", "Pantheon+_groups_zSN",
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.") "Pantheon+_zSN"]:
fpath = join(folder, "PV_compilation.hdf5")
else:
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
loader = csiborgtools.flow.DataLoader(ARGS.simname, nsim_iterator, loader = csiborgtools.flow.DataLoader(ARGS.simname, nsim_iterator,
ARGS.catalogue, fpath, paths, cat, fpath, paths,
ksmooth=ARGS.ksmooth) ksmooth=ARGS.ksmooth)
models[i] = csiborgtools.flow.get_model(loader, **get_model_kwargs)
print(f"\n{'Num. radial steps':<20} {len(loader.rdist)}\n", flush=True) print(f"\n{'Num. radial steps':<20} {len(loader.rdist)}\n", flush=True)
return models
return csiborgtools.flow.get_model(loader, **get_model_kwargs)
def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num): def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num):
"""Compute evidence using the `harmonic` package.""" """Compute evidence using the `harmonic` package."""
data, names = csiborgtools.dict_samples_to_array(samples) data, names = csiborgtools.dict_samples_to_array(samples)
data = data.reshape(nchains_harmonic, -1, len(names)) data = data.reshape(nchains_harmonic, -1, len(names))
log_posterior = log_posterior.reshape(10, -1) log_posterior = log_posterior.reshape(nchains_harmonic, -1)
return csiborgtools.harmonic_evidence( return csiborgtools.harmonic_evidence(
data, log_posterior, return_flow_samples=False, epochs_num=epoch_num) data, log_posterior, return_flow_samples=False, epochs_num=epoch_num)
def get_simulation_weights(samples, model, model_kwargs):
"""Get the weights per posterior samples for each simulation."""
predictive = Predictive(model, samples)
ll_all = predictive(
jax.random.PRNGKey(1), store_ll_all=True, **model_kwargs)["ll_all"]
# Multiply the likelihood of galaxies
ll_per_simulation = jax.numpy.sum(ll_all, axis=-1)
# Normalization by summing the likelihood over simulations
norm = jax.scipy.special.logsumexp(ll_per_simulation, axis=-1)
return ll_per_simulation - norm[:, None]
def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta, def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
calculate_evidence, nchains_harmonic, epoch_num, kwargs_print): calculate_evidence, nchains_harmonic, epoch_num, kwargs_print):
"""Run the NumPyro model and save output to a file.""" """Run the NumPyro model and save output to a file."""
try: try:
ndata = model.ndata ndata = sum(model.ndata for model in model_kwargs["models"])
except AttributeError as e: except AttributeError as e:
raise AttributeError("The model must have an attribute `ndata` " raise AttributeError("The models must have an attribute `ndata` "
"indicating the number of data points.") from e "indicating the number of data points.") from e
nuts_kernel = NUTS(model, init_strategy=init_to_median(num_samples=1000)) nuts_kernel = NUTS(model, init_strategy=init_to_median(num_samples=1000))
@ -143,7 +137,6 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
mcmc.run(rng_key, extra_fields=("potential_energy",), **model_kwargs) mcmc.run(rng_key, extra_fields=("potential_energy",), **model_kwargs)
samples = mcmc.get_samples() samples = mcmc.get_samples()
simulation_weights = get_simulation_weights(samples, model, model_kwargs)
log_posterior = -mcmc.get_extra_fields()["potential_energy"] log_posterior = -mcmc.get_extra_fields()["potential_energy"]
log_likelihood = samples.pop("ll_values") log_likelihood = samples.pop("ll_values")
@ -165,7 +158,7 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
neg_ln_evidence = jax.numpy.nan neg_ln_evidence = jax.numpy.nan
neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan) neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)
fname = f"samples_{ARGS.simname}_{ARGS.catalogue}_ksmooth{ARGS.ksmooth}.hdf5" # noqa fname = f"samples_{ARGS.simname}_{'+'.join(ARGS.catalogue)}_ksmooth{ARGS.ksmooth}.hdf5" # noqa
if ARGS.ksim is not None: if ARGS.ksim is not None:
fname = fname.replace(".hdf5", f"_nsim{ARGS.ksim}.hdf5") fname = fname.replace(".hdf5", f"_nsim{ARGS.ksim}.hdf5")
@ -183,7 +176,6 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
# Write log likelihood and posterior # Write log likelihood and posterior
f.create_dataset("log_likelihood", data=log_likelihood) f.create_dataset("log_likelihood", data=log_likelihood)
f.create_dataset("log_posterior", data=log_posterior) f.create_dataset("log_posterior", data=log_posterior)
f.create_dataset("simulation_weights", data=simulation_weights)
# Write goodness of fit # Write goodness of fit
grp = f.create_group("gof") grp = f.create_group("gof")
@ -215,6 +207,29 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
# Command line interface # # Command line interface #
############################################################################### ###############################################################################
def get_distmod_hyperparams(catalogue):
alpha_min = -1.0
alpha_max = 3.0
sample_alpha = True
if catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]: # noqa
return {"e_mu_min": 0.001, "e_mu_max": 1.0,
"mag_cal_mean": -18.25, "mag_cal_std": 2.0,
"alpha_cal_mean": 0.148, "alpha_cal_std": 1.0,
"beta_cal_mean": 3.112, "beta_cal_std": 2.0,
"alpha_min": alpha_min, "alpha_max": alpha_max,
"sample_alpha": sample_alpha
}
elif catalogue in ["SFI_gals", "2MTF"]:
return {"e_mu_min": 0.001, "e_mu_max": 1.0,
"a_mean": -21., "a_std": 5.0,
"b_mean": -5.95, "b_std": 3.0,
"alpha_min": alpha_min, "alpha_max": alpha_max,
"sample_alpha": sample_alpha
}
else:
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
if __name__ == "__main__": if __name__ == "__main__":
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring) paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
@ -227,14 +242,15 @@ if __name__ == "__main__":
########################################################################### ###########################################################################
nsteps = 5000 nsteps = 5000
nburn = 1000 nburn = 1500
zcmb_max = 0.06 zcmb_max = 0.05
calculate_evidence = False calculate_evidence = False
nchains_harmonic = 10 nchains_harmonic = 10
num_epochs = 30 num_epochs = 30
if nsteps % nchains_harmonic != 0: if nsteps % nchains_harmonic != 0:
raise ValueError("The number of steps must be divisible by the number of chains.") # noqa raise ValueError(
"The number of steps must be divisible by the number of chains.")
main_params = {"nsteps": nsteps, "nburn": nburn, "zcmb_max": zcmb_max, main_params = {"nsteps": nsteps, "nburn": nburn, "zcmb_max": zcmb_max,
"calculate_evidence": calculate_evidence, "calculate_evidence": calculate_evidence,
@ -244,42 +260,36 @@ if __name__ == "__main__":
calibration_hyperparams = {"Vext_min": -1000, "Vext_max": 1000, calibration_hyperparams = {"Vext_min": -1000, "Vext_max": 1000,
"Vmono_min": -1000, "Vmono_max": 1000, "Vmono_min": -1000, "Vmono_max": 1000,
"alpha_min": -1.0, "alpha_max": 3.0,
"beta_min": -1.0, "beta_max": 3.0, "beta_min": -1.0, "beta_max": 3.0,
"sigma_v_min": 1.0, "sigma_v_max": 750., "sigma_v_min": 1.0, "sigma_v_max": 750.,
"sample_Vmono": False, "sample_Vmono": False,
"sample_alpha": True,
"sample_beta": True, "sample_beta": True,
"sample_sigma_v_ext": False, "sample_sigma_v_ext": False,
} }
print_variables( print_variables(
calibration_hyperparams.keys(), calibration_hyperparams.values()) calibration_hyperparams.keys(), calibration_hyperparams.values())
if ARGS.catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]: # noqa distmod_hyperparams_per_catalogue = []
distmod_hyperparams = {"e_mu_min": 0.001, "e_mu_max": 1.0, for cat in ARGS.catalogue:
"mag_cal_mean": -18.25, "mag_cal_std": 2.0, x = get_distmod_hyperparams(cat)
"alpha_cal_mean": 0.148, "alpha_cal_std": 1.0, print(f"\n{cat} hyperparameters:")
"beta_cal_mean": 3.112, "beta_cal_std": 2.0, print_variables(x.keys(), x.values())
} distmod_hyperparams_per_catalogue.append(x)
elif ARGS.catalogue in ["SFI_gals", "2MTF"]:
distmod_hyperparams = {"e_mu_min": 0.001, "e_mu_max": 1.0,
"a_mean": -21., "a_std": 5.0,
"b_mean": -5.95, "b_std": 3.0,
}
else:
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
print_variables( kwargs_print = (main_params, calibration_hyperparams,
distmod_hyperparams.keys(), distmod_hyperparams.values()) *distmod_hyperparams_per_catalogue)
kwargs_print = (main_params, calibration_hyperparams, distmod_hyperparams)
########################################################################### ###########################################################################
model_kwargs = {"calibration_hyperparams": calibration_hyperparams,
"distmod_hyperparams": distmod_hyperparams}
get_model_kwargs = {"zcmb_max": zcmb_max} get_model_kwargs = {"zcmb_max": zcmb_max}
models = get_models(get_model_kwargs, )
model_kwargs = {
"models": models,
"field_calibration_hyperparams": calibration_hyperparams,
"distmod_hyperparams_per_model": distmod_hyperparams_per_catalogue,
}
model = csiborgtools.flow.PV_validation_model
model = get_model(paths, get_model_kwargs, )
run_model(model, nsteps, nburn, model_kwargs, out_folder, run_model(model, nsteps, nburn, model_kwargs, out_folder,
calibration_hyperparams["sample_beta"], calculate_evidence, calibration_hyperparams["sample_beta"], calculate_evidence,
nchains_harmonic, num_epochs, kwargs_print) nchains_harmonic, num_epochs, kwargs_print)

View File

@ -1,44 +1,66 @@
#!/bin/bash #!/bin/bash
memory=8 memory=7
on_login=${1} on_login=${1}
queue=${2}
ndevice=1 ndevice=1
device="gpu"
queue="gpulong"
gputype="rtx2080with12gb"
env="/mnt/users/rstiskalek/csiborgtools/venv_gpu_csiborgtools/bin/python"
file="flow_validation.py" file="flow_validation.py"
ksmooth=0 ksmooth=0
if [ "$on_login" != "1" ] && [ "$on_login" != "0" ]; then if [ "$on_login" != "1" ] && [ "$on_login" != "0" ]
echo "Invalid input: 'on_login' (1). Please provide 1 or 0." then
echo "'on_login' (1) must be either 0 or 1."
exit 1
fi
if [ "$queue" != "redwood" ] && [ "$queue" != "berg" ] && [ "$queue" != "cmb" ] && [ "$queue" != "gpulong" ] && [ "$queue" != "cmbgpu" ]; then
echo "Invalid queue: $queue (2). Please provide one of 'redwood', 'berg', 'cmb', 'gpulong', 'cmbgpu'."
exit 1 exit 1
fi fi
# Submit a job for each combination of simname, catalogue, ksim if [ "$queue" == "gpulong" ]
then
device="gpu"
gputype="rtx2080with12gb"
# gputype="rtx3070with8gb"
env="/mnt/users/rstiskalek/csiborgtools/venv_gpu_csiborgtools/bin/python"
elif [ "$queue" == "cmbgpu" ]
then
device="gpu"
gputype="rtx3090with24gb"
env="/mnt/users/rstiskalek/csiborgtools/venv_gpu_csiborgtools/bin/python"
else
device="cpu"
env="/mnt/users/rstiskalek/csiborgtools/venv_csiborg/bin/python"
fi
# for simname in "Lilow2024" "CF4" "CF4gp" "csiborg1" "csiborg2_main" "csiborg2X"; do # for simname in "Lilow2024" "CF4" "CF4gp" "csiborg1" "csiborg2_main" "csiborg2X"; do
for simname in "Carrick2015"; do for simname in "Carrick2015"; do
# for simname in "csiborg1" "csiborg2_main" "csiborg2X"; do for catalogue in "LOSS,2MTF,SFI_gals"; do
for catalogue in "Foundation"; do
# for catalogue in "2MTF"; do
# for ksim in 0 1 2; do
# for ksim in 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20; do # for ksim in 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20; do
for ksim in "none"; do for ksim in "none"; do
pythoncm="$env $file --catalogue $catalogue --simname $simname --ksim $ksim --ksmooth $ksmooth --ndevice $ndevice --device $device" pythoncm="$env $file --catalogue $catalogue --simname $simname --ksim $ksim --ksmooth $ksmooth --ndevice $ndevice --device $device"
if [ $on_login -eq 1 ]; then if [ "$on_login" == "1" ]; then
echo $pythoncm echo $pythoncm
$pythoncm eval $pythoncm
else else
cm="addqueue -q $queue -s -m $memory --gpus 1 --gputype $gputype $pythoncm" if [ "$device" == "gpu" ]; then
cm="addqueue -q $queue -s -m $memory --gpus 1 --gputype $gputype $pythoncm"
else
cm="addqueue -s -q $queue -n 1 -m $memory $pythoncm"
fi
echo "Submitting:" echo "Submitting:"
echo $cm echo $cm
eval $cm eval $cm
fi fi
echo echo
sleep 0.001 sleep 0.001
done done
done done
done done