mirror of
https://github.com/Richard-Sti/csiborgtools.git
synced 2024-12-22 17:08:03 +00:00
Combine PV samples (#139)
* Update imports * Update submission script * Update script * Add simulataenous sampling of many catalogues * Update nb
This commit is contained in:
parent
9756175943
commit
3b46f17ead
5 changed files with 356 additions and 147 deletions
|
@ -12,8 +12,8 @@
|
||||||
# You should have received a copy of the GNU General Public License along
|
# You should have received a copy of the GNU General Public License along
|
||||||
# with this program; if not, write to the Free Software Foundation, Inc.,
|
# with this program; if not, write to the Free Software Foundation, Inc.,
|
||||||
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
||||||
from .flow_model import (DataLoader, radial_velocity_los, dist2redshift, # noqa
|
from .flow_model import (DataLoader, PV_LogLikelihood, PV_validation_model, # noqa
|
||||||
dist2distmodulus, predict_zobs, project_Vext, # noqa
|
dist2distmodulus, dist2redshift, distmodulus2dist, # noqa
|
||||||
PV_validation_model, get_model, distmodulus2dist, # noqa
|
get_model, Observed2CosmologicalRedshift, # noqa
|
||||||
Observed2CosmologicalRedshift, # noqa
|
predict_zobs, project_Vext, # noqa
|
||||||
stack_pzosmo_over_realizations) # noqa
|
radial_velocity_los, stack_pzosmo_over_realizations) # noqa
|
||||||
|
|
|
@ -24,15 +24,14 @@ References
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpyro
|
|
||||||
from astropy.cosmology import FlatLambdaCDM, z_at_value
|
|
||||||
from astropy import units as u
|
from astropy import units as u
|
||||||
|
from astropy.cosmology import FlatLambdaCDM, z_at_value
|
||||||
from h5py import File
|
from h5py import File
|
||||||
from jax import jit
|
from jax import jit
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
from jax import vmap
|
from jax import vmap
|
||||||
from jax.scipy.special import logsumexp
|
from jax.scipy.special import logsumexp
|
||||||
from numpyro import sample
|
from numpyro import deterministic, factor, sample
|
||||||
from numpyro.distributions import Normal, Uniform
|
from numpyro.distributions import Normal, Uniform
|
||||||
from quadax import simpson
|
from quadax import simpson
|
||||||
from scipy.interpolate import interp1d
|
from scipy.interpolate import interp1d
|
||||||
|
@ -558,15 +557,23 @@ def e2_distmod_SN(e2_mB, e2_x1, e2_c, alpha_cal, beta_cal, e_mu_intrinsic):
|
||||||
|
|
||||||
|
|
||||||
def sample_SN(e_mu_min, e_mu_max, mag_cal_mean, mag_cal_std, alpha_cal_mean,
|
def sample_SN(e_mu_min, e_mu_max, mag_cal_mean, mag_cal_std, alpha_cal_mean,
|
||||||
alpha_cal_std, beta_cal_mean, beta_cal_std):
|
alpha_cal_std, beta_cal_mean, beta_cal_std, alpha_min, alpha_max,
|
||||||
|
sample_alpha, name):
|
||||||
"""Sample SNIe Tripp parameters."""
|
"""Sample SNIe Tripp parameters."""
|
||||||
e_mu = sample("e_mu", Uniform(e_mu_min, e_mu_max))
|
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
|
||||||
mag_cal = sample("mag_cal", Normal(mag_cal_mean, mag_cal_std))
|
mag_cal = sample(f"mag_cal_{name}", Normal(mag_cal_mean, mag_cal_std))
|
||||||
alpha_cal = sample("alpha_cal", Normal(alpha_cal_mean, alpha_cal_std))
|
alpha_cal = sample(
|
||||||
|
f"alpha_cal_{name}", Normal(alpha_cal_mean, alpha_cal_std))
|
||||||
|
beta_cal = sample(f"beta_cal_{name}", Normal(beta_cal_mean, beta_cal_std))
|
||||||
|
alpha = sample(f"alpha_{name}",
|
||||||
|
Uniform(alpha_min, alpha_max)) if sample_alpha else 1.0
|
||||||
|
|
||||||
beta_cal = sample("beta_cal", Normal(beta_cal_mean, beta_cal_std))
|
return {"e_mu": e_mu,
|
||||||
|
"mag_cal": mag_cal,
|
||||||
return e_mu, mag_cal, alpha_cal, beta_cal
|
"alpha_cal": alpha_cal,
|
||||||
|
"beta_cal": beta_cal,
|
||||||
|
"alpha": alpha
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
@ -583,34 +590,42 @@ def e2_distmod_TFR(e2_mag, e2_eta, b, e_mu_intrinsic):
|
||||||
return e2_mag + b**2 * e2_eta + e_mu_intrinsic**2
|
return e2_mag + b**2 * e2_eta + e_mu_intrinsic**2
|
||||||
|
|
||||||
|
|
||||||
def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std):
|
def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std, alpha_min,
|
||||||
|
alpha_max, sample_alpha, name):
|
||||||
"""Sample Tully-Fisher calibration parameters."""
|
"""Sample Tully-Fisher calibration parameters."""
|
||||||
e_mu = sample("e_mu", Uniform(e_mu_min, e_mu_max))
|
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
|
||||||
a = sample("a", Normal(a_mean, a_std))
|
a = sample(f"a_{name}", Normal(a_mean, a_std))
|
||||||
b = sample("b", Normal(b_mean, b_std))
|
b = sample(f"b_{name}", Normal(b_mean, b_std))
|
||||||
|
alpha = sample(f"alpha_{name}",
|
||||||
return e_mu, a, b
|
Uniform(alpha_min, alpha_max)) if sample_alpha else 1.0
|
||||||
|
|
||||||
|
return {"e_mu": e_mu,
|
||||||
|
"a": a,
|
||||||
|
"b": b,
|
||||||
|
"alpha": alpha
|
||||||
|
}
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
# Calibration parameters sampling #
|
# Calibration parameters sampling #
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
|
||||||
|
|
||||||
def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max,
|
def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min,
|
||||||
alpha_min, alpha_max, beta_min, beta_max, sigma_v_min,
|
beta_max, sigma_v_min, sigma_v_max, sample_Vmono,
|
||||||
sigma_v_max, sample_Vmono, sample_alpha, sample_beta,
|
sample_beta, sample_sigma_v_ext):
|
||||||
sample_sigma_v_ext):
|
|
||||||
"""Sample the flow calibration."""
|
"""Sample the flow calibration."""
|
||||||
Vext = sample("Vext", Uniform(Vext_min, Vext_max).expand([3]))
|
Vext = sample("Vext", Uniform(Vext_min, Vext_max).expand([3]))
|
||||||
sigma_v = sample("sigma_v", Uniform(sigma_v_min, sigma_v_max))
|
sigma_v = sample("sigma_v", Uniform(sigma_v_min, sigma_v_max))
|
||||||
|
|
||||||
alpha = sample("alpha", Uniform(alpha_min, alpha_max)) if sample_alpha else 1.0 # noqa
|
|
||||||
beta = sample("beta", Uniform(beta_min, beta_max)) if sample_beta else 1.0 # noqa
|
beta = sample("beta", Uniform(beta_min, beta_max)) if sample_beta else 1.0 # noqa
|
||||||
Vmono = sample("Vmono", Uniform(Vmono_min, Vmono_max)) if sample_Vmono else 0.0 # noqa
|
Vmono = sample("Vmono", Uniform(Vmono_min, Vmono_max)) if sample_Vmono else 0.0 # noqa
|
||||||
sigma_v_ext = sample("sigma_v_ext", Uniform(sigma_v_min, sigma_v_max)) if sample_sigma_v_ext else sigma_v # noqa
|
sigma_v_ext = sample("sigma_v_ext", Uniform(sigma_v_min, sigma_v_max)) if sample_sigma_v_ext else sigma_v # noqa
|
||||||
|
|
||||||
return Vext, Vmono, sigma_v, sigma_v_ext, alpha, beta
|
return {"Vext": Vext,
|
||||||
|
"Vmono": Vmono,
|
||||||
|
"sigma_v": sigma_v,
|
||||||
|
"sigma_v_ext": sigma_v_ext,
|
||||||
|
"beta": beta}
|
||||||
|
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
@ -635,9 +650,9 @@ def find_extrap_mask(rmax, rdist):
|
||||||
return extrap_mask, extrap_weights
|
return extrap_mask, extrap_weights
|
||||||
|
|
||||||
|
|
||||||
class PV_validation_model(BaseFlowValidationModel):
|
class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
"""
|
"""
|
||||||
Peculiar velocity validation model.
|
Peculiar velocity validation model log-likelihood.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
@ -663,7 +678,7 @@ class PV_validation_model(BaseFlowValidationModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, los_density, los_velocity, rmax, RA, dec, z_obs,
|
def __init__(self, los_density, los_velocity, rmax, RA, dec, z_obs,
|
||||||
e_zobs, calibration_params, r_xrange, Omega_m, kind):
|
e_zobs, calibration_params, r_xrange, Omega_m, kind, name):
|
||||||
if e_zobs is not None:
|
if e_zobs is not None:
|
||||||
e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2)
|
e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2)
|
||||||
else:
|
else:
|
||||||
|
@ -681,6 +696,7 @@ class PV_validation_model(BaseFlowValidationModel):
|
||||||
self._set_radial_spacing(r_xrange, Omega_m)
|
self._set_radial_spacing(r_xrange, Omega_m)
|
||||||
|
|
||||||
self.kind = kind
|
self.kind = kind
|
||||||
|
self.name = name
|
||||||
self.Omega_m = Omega_m
|
self.Omega_m = Omega_m
|
||||||
self.norm = - self.ndata * jnp.log(self.num_sims)
|
self.norm = - self.ndata * jnp.log(self.num_sims)
|
||||||
|
|
||||||
|
@ -688,29 +704,37 @@ class PV_validation_model(BaseFlowValidationModel):
|
||||||
self.extrap_mask = jnp.asarray(extrap_mask)
|
self.extrap_mask = jnp.asarray(extrap_mask)
|
||||||
self.extrap_weights = jnp.asarray(extrap_weights)
|
self.extrap_weights = jnp.asarray(extrap_weights)
|
||||||
|
|
||||||
def __call__(self, calibration_hyperparams, distmod_hyperparams,
|
def __call__(self, field_calibration_params, distmod_params,
|
||||||
store_ll_all=False):
|
sample_sigma_v_ext):
|
||||||
"""NumPyro PV validation model."""
|
"""PV validation model log-likelihood."""
|
||||||
Vext, Vmono, sigma_v, sigma_v_ext, alpha, beta = sample_calibration(**calibration_hyperparams) # noqa
|
|
||||||
# Turn e2_cz to be of shape (nsims, ndata, nxrange) and apply
|
# Turn e2_cz to be of shape (nsims, ndata, nxrange) and apply
|
||||||
# sigma_v_ext where applicable
|
# sigma_v_ext where applicable
|
||||||
|
sigma_v = field_calibration_params["sigma_v"]
|
||||||
|
sigma_v_ext = field_calibration_params["sigma_v_ext"]
|
||||||
e2_cz = jnp.full_like(self.extrap_mask, sigma_v**2, dtype=jnp.float32)
|
e2_cz = jnp.full_like(self.extrap_mask, sigma_v**2, dtype=jnp.float32)
|
||||||
if calibration_hyperparams["sample_sigma_v_ext"]:
|
if sample_sigma_v_ext:
|
||||||
e2_cz = e2_cz.at[self.extrap_mask].set(sigma_v_ext**2)
|
e2_cz = e2_cz.at[self.extrap_mask].set(sigma_v_ext**2)
|
||||||
|
|
||||||
# Now add the observational errors
|
# Now add the observational errors
|
||||||
e2_cz += self.e2_cz_obs[None, :, None]
|
e2_cz += self.e2_cz_obs[None, :, None]
|
||||||
|
|
||||||
|
Vext = field_calibration_params["Vext"]
|
||||||
|
Vmono = field_calibration_params["Vmono"]
|
||||||
Vext_rad = project_Vext(Vext[0], Vext[1], Vext[2], self.RA, self.dec)
|
Vext_rad = project_Vext(Vext[0], Vext[1], Vext[2], self.RA, self.dec)
|
||||||
|
|
||||||
|
e_mu = distmod_params["e_mu"]
|
||||||
|
alpha = distmod_params["alpha"]
|
||||||
if self.kind == "SN":
|
if self.kind == "SN":
|
||||||
e_mu, mag_cal, alpha_cal, beta_cal = sample_SN(**distmod_hyperparams) # noqa
|
mag_cal = distmod_params["mag_cal"]
|
||||||
|
alpha_cal = distmod_params["alpha_cal"]
|
||||||
|
beta_cal = distmod_params["beta_cal"]
|
||||||
mu = distmod_SN(
|
mu = distmod_SN(
|
||||||
self.mB, self.x1, self.c, mag_cal, alpha_cal, beta_cal)
|
self.mB, self.x1, self.c, mag_cal, alpha_cal, beta_cal)
|
||||||
squared_e_mu = e2_distmod_SN(
|
squared_e_mu = e2_distmod_SN(
|
||||||
self.e2_mB, self.e2_x1, self.e2_c, alpha_cal, beta_cal, e_mu)
|
self.e2_mB, self.e2_x1, self.e2_c, alpha_cal, beta_cal, e_mu)
|
||||||
elif self.kind == "TFR":
|
elif self.kind == "TFR":
|
||||||
e_mu, a, b = sample_TFR(**distmod_hyperparams)
|
a = distmod_params["a"]
|
||||||
|
b = distmod_params["b"]
|
||||||
mu = distmod_TFR(self.mag, self.eta, a, b)
|
mu = distmod_TFR(self.mag, self.eta, a, b)
|
||||||
squared_e_mu = e2_distmod_TFR(self.e2_mag, self.e2_eta, b, e_mu)
|
squared_e_mu = e2_distmod_TFR(self.e2_mag, self.e2_eta, b, e_mu)
|
||||||
else:
|
else:
|
||||||
|
@ -725,20 +749,51 @@ class PV_validation_model(BaseFlowValidationModel):
|
||||||
|
|
||||||
# Calculate z_obs at each distance. Shape is (n_sims, ndata, nxrange)
|
# Calculate z_obs at each distance. Shape is (n_sims, ndata, nxrange)
|
||||||
# The weights are related to the extrapolation of the velocity field.
|
# The weights are related to the extrapolation of the velocity field.
|
||||||
vrad = beta * self.los_velocity
|
vrad = field_calibration_params["beta"] * self.los_velocity
|
||||||
vrad += (Vext_rad[None, :, None] + Vmono) * self.extrap_weights
|
vrad += (Vext_rad[None, :, None] + Vmono) * self.extrap_weights
|
||||||
zobs = (1 + self.z_xrange[None, None, :]) * (1 + vrad / SPEED_OF_LIGHT) - 1 # noqa
|
zobs = (1 + self.z_xrange[None, None, :]) * (1 + vrad / SPEED_OF_LIGHT) - 1 # noqa
|
||||||
|
|
||||||
ptilde *= calculate_likelihood_zobs(self.z_obs, zobs, e2_cz)
|
ptilde *= calculate_likelihood_zobs(self.z_obs, zobs, e2_cz)
|
||||||
# ptilde *= calculate_likelihood_zobs(self.z_obs, zobs, sigma_v)
|
|
||||||
ll = jnp.log(simpson(ptilde, dx=self.dr, axis=-1)) - jnp.log(pnorm)
|
ll = jnp.log(simpson(ptilde, dx=self.dr, axis=-1)) - jnp.log(pnorm)
|
||||||
|
|
||||||
if store_ll_all:
|
return jnp.sum(logsumexp(ll, axis=0)) + self.norm
|
||||||
numpyro.deterministic("ll_all", ll)
|
|
||||||
|
|
||||||
ll = jnp.sum(logsumexp(ll, axis=0)) + self.norm
|
|
||||||
numpyro.deterministic("ll_values", ll)
|
def PV_validation_model(models, distmod_hyperparams_per_model,
|
||||||
numpyro.factor("ll", ll)
|
field_calibration_hyperparams):
|
||||||
|
"""
|
||||||
|
Peculiar velocity validation NumPyro model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
models : list of `PV_LogLikelihood`
|
||||||
|
List of PV validation log-likelihoods for each catalogue.
|
||||||
|
distmod_hyperparams_per_model : list of dict
|
||||||
|
Distance modulus hyperparameters for each model/catalogue.
|
||||||
|
field_calibration_hyperparams : dict
|
||||||
|
Field calibration hyperparameters.
|
||||||
|
"""
|
||||||
|
field_calibration_params = sample_calibration(
|
||||||
|
**field_calibration_hyperparams)
|
||||||
|
sample_sigma_v_ext = field_calibration_hyperparams["sample_sigma_v_ext"]
|
||||||
|
|
||||||
|
ll = 0.0
|
||||||
|
for n in range(len(models)):
|
||||||
|
model = models[n]
|
||||||
|
distmod_hyperparams = distmod_hyperparams_per_model[n]
|
||||||
|
|
||||||
|
if model.kind == "TFR":
|
||||||
|
distmod_params = sample_TFR(**distmod_hyperparams, name=model.name)
|
||||||
|
elif model.kind == "SN":
|
||||||
|
distmod_params = sample_SN(**distmod_hyperparams, name=model.name)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown kind: `{model.kind}`.")
|
||||||
|
|
||||||
|
ll += model(
|
||||||
|
field_calibration_params, distmod_params, sample_sigma_v_ext)
|
||||||
|
|
||||||
|
deterministic("ll_values", ll)
|
||||||
|
factor("ll", ll)
|
||||||
|
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
@ -780,10 +835,10 @@ def get_model(loader, zcmb_max=None, verbose=True):
|
||||||
"e_mB": e_mB[mask], "e_x1": e_x1[mask],
|
"e_mB": e_mB[mask], "e_x1": e_x1[mask],
|
||||||
"e_c": e_c[mask]}
|
"e_c": e_c[mask]}
|
||||||
|
|
||||||
model = PV_validation_model(
|
model = PV_LogLikelihood(
|
||||||
los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
|
los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
|
||||||
RA[mask], dec[mask], zCMB[mask], e_zCMB, calibration_params,
|
RA[mask], dec[mask], zCMB[mask], e_zCMB, calibration_params,
|
||||||
loader.rdist, loader._Omega_m, "SN")
|
loader.rdist, loader._Omega_m, "SN", name=kind)
|
||||||
elif "Pantheon+" in kind:
|
elif "Pantheon+" in kind:
|
||||||
keys = ["RA", "DEC", "zCMB", "mB", "x1", "c", "biasCor_m_b", "mBERR",
|
keys = ["RA", "DEC", "zCMB", "mB", "x1", "c", "biasCor_m_b", "mBERR",
|
||||||
"x1ERR", "cERR", "biasCorErr_m_b", "zCMB_SN", "zCMB_Group",
|
"x1ERR", "cERR", "biasCorErr_m_b", "zCMB_SN", "zCMB_Group",
|
||||||
|
@ -808,10 +863,10 @@ def get_model(loader, zcmb_max=None, verbose=True):
|
||||||
calibration_params = {"mB": mB[mask], "x1": x1[mask], "c": c[mask],
|
calibration_params = {"mB": mB[mask], "x1": x1[mask], "c": c[mask],
|
||||||
"e_mB": e_mB[mask], "e_x1": e_x1[mask],
|
"e_mB": e_mB[mask], "e_x1": e_x1[mask],
|
||||||
"e_c": e_c[mask]}
|
"e_c": e_c[mask]}
|
||||||
model = PV_validation_model(
|
model = PV_LogLikelihood(
|
||||||
los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
|
los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
|
||||||
RA[mask], dec[mask], zCMB[mask], e_zCMB[mask], calibration_params,
|
RA[mask], dec[mask], zCMB[mask], e_zCMB[mask], calibration_params,
|
||||||
loader.rdist, loader._Omega_m, "SN")
|
loader.rdist, loader._Omega_m, "SN", name=kind)
|
||||||
elif kind in ["SFI_gals", "2MTF", "SFI_gals_masked"]:
|
elif kind in ["SFI_gals", "2MTF", "SFI_gals_masked"]:
|
||||||
keys = ["RA", "DEC", "z_CMB", "mag", "eta", "e_mag", "e_eta"]
|
keys = ["RA", "DEC", "z_CMB", "mag", "eta", "e_mag", "e_eta"]
|
||||||
RA, dec, zCMB, mag, eta, e_mag, e_eta = (loader.cat[k] for k in keys)
|
RA, dec, zCMB, mag, eta, e_mag, e_eta = (loader.cat[k] for k in keys)
|
||||||
|
@ -824,14 +879,14 @@ def get_model(loader, zcmb_max=None, verbose=True):
|
||||||
|
|
||||||
calibration_params = {"mag": mag[mask], "eta": eta[mask],
|
calibration_params = {"mag": mag[mask], "eta": eta[mask],
|
||||||
"e_mag": e_mag[mask], "e_eta": e_eta[mask]}
|
"e_mag": e_mag[mask], "e_eta": e_eta[mask]}
|
||||||
model = PV_validation_model(
|
model = PV_LogLikelihood(
|
||||||
los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
|
los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
|
||||||
RA[mask], dec[mask], zCMB[mask], None, calibration_params,
|
RA[mask], dec[mask], zCMB[mask], None, calibration_params,
|
||||||
loader.rdist, loader._Omega_m, "TFR")
|
loader.rdist, loader._Omega_m, "TFR", name=kind)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Catalogue `{kind}` not recognized.")
|
raise ValueError(f"Catalogue `{kind}` not recognized.")
|
||||||
|
|
||||||
fprint(f"selected {np.sum(mask)}/{len(mask)} galaxies.")
|
fprint(f"selected {np.sum(mask)}/{len(mask)} galaxies in catalogue `{kind}`") # noqa
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@ -848,9 +903,12 @@ def _posterior_element(r, beta, Vext_radial, los_velocity, Omega_m, zobs,
|
||||||
`Observed2CosmologicalRedshift`.
|
`Observed2CosmologicalRedshift`.
|
||||||
"""
|
"""
|
||||||
zobs_pred = predict_zobs(r, beta, Vext_radial, los_velocity, Omega_m)
|
zobs_pred = predict_zobs(r, beta, Vext_radial, los_velocity, Omega_m)
|
||||||
|
|
||||||
# Likelihood term
|
# Likelihood term
|
||||||
dcz = SPEED_OF_LIGHT * (zobs - zobs_pred)
|
dcz = SPEED_OF_LIGHT * (zobs - zobs_pred)
|
||||||
posterior = jnp.exp(-0.5 * dcz**2 / sigma_v**2) / jnp.sqrt(2 * jnp.pi * sigma_v**2) # noqa
|
posterior = jnp.exp(-0.5 * dcz**2 / sigma_v**2)
|
||||||
|
posterior /= jnp.sqrt(2 * jnp.pi * sigma_v**2)
|
||||||
|
|
||||||
# Prior term
|
# Prior term
|
||||||
posterior *= dVdOmega * los_density**alpha
|
posterior *= dVdOmega * los_density**alpha
|
||||||
|
|
||||||
|
@ -864,7 +922,8 @@ class BaseObserved2CosmologicalRedshift(ABC):
|
||||||
for i, key in enumerate(calibration_samples.keys()):
|
for i, key in enumerate(calibration_samples.keys()):
|
||||||
x = calibration_samples[key]
|
x = calibration_samples[key]
|
||||||
if not isinstance(x, (np.ndarray, jnp.ndarray)):
|
if not isinstance(x, (np.ndarray, jnp.ndarray)):
|
||||||
raise ValueError(f"Calibration sample `{key}` must be an array.") # noqa
|
raise ValueError(
|
||||||
|
f"Calibration sample `{key}` must be an array.")
|
||||||
|
|
||||||
if x.ndim != 1 and key != "Vext":
|
if x.ndim != 1 and key != "Vext":
|
||||||
raise ValueError(f"Calibration samples `{key}` must be 1D.")
|
raise ValueError(f"Calibration samples `{key}` must be 1D.")
|
||||||
|
@ -873,14 +932,19 @@ class BaseObserved2CosmologicalRedshift(ABC):
|
||||||
ncalibratrion = len(x)
|
ncalibratrion = len(x)
|
||||||
|
|
||||||
if len(x) != ncalibratrion:
|
if len(x) != ncalibratrion:
|
||||||
raise ValueError("Calibration samples do not have the same length.") # noqa
|
raise ValueError(
|
||||||
|
"Calibration samples do not have the same length.")
|
||||||
|
|
||||||
calibration_samples[key] = jnp.asarray(x)
|
calibration_samples[key] = jnp.asarray(x)
|
||||||
|
|
||||||
if "alpha" not in calibration_samples:
|
if "alpha" not in calibration_samples:
|
||||||
|
print("No `alpha` calibration sample found. Setting it to 1.",
|
||||||
|
flush=True)
|
||||||
calibration_samples["alpha"] = jnp.ones(ncalibratrion)
|
calibration_samples["alpha"] = jnp.ones(ncalibratrion)
|
||||||
|
|
||||||
if "beta" not in calibration_samples:
|
if "beta" not in calibration_samples:
|
||||||
|
print("No `beta` calibration sample found. Setting it to 1.",
|
||||||
|
flush=True)
|
||||||
calibration_samples["beta"] = jnp.ones(ncalibratrion)
|
calibration_samples["beta"] = jnp.ones(ncalibratrion)
|
||||||
|
|
||||||
# Get the stepsize, we need it to be constant for Simpson's rule.
|
# Get the stepsize, we need it to be constant for Simpson's rule.
|
||||||
|
@ -898,7 +962,8 @@ class BaseObserved2CosmologicalRedshift(ABC):
|
||||||
def get_calibration_samples(self, key):
|
def get_calibration_samples(self, key):
|
||||||
"""Get calibration samples for a given key."""
|
"""Get calibration samples for a given key."""
|
||||||
if key not in self._calibration_samples:
|
if key not in self._calibration_samples:
|
||||||
raise ValueError(f"Key `{key}` not found in calibration samples. Available keys are: `{self.calibration_keys}`.") # noqa
|
raise ValueError(f"Key `{key}` not found in calibration samples. "
|
||||||
|
f"Available keys are: `{self.calibration_keys}`.")
|
||||||
|
|
||||||
return self._calibration_samples[key]
|
return self._calibration_samples[key]
|
||||||
|
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -33,7 +33,7 @@ def parse_args():
|
||||||
parser.add_argument("--simname", type=str, required=True,
|
parser.add_argument("--simname", type=str, required=True,
|
||||||
help="Simulation name.")
|
help="Simulation name.")
|
||||||
parser.add_argument("--catalogue", type=str, required=True,
|
parser.add_argument("--catalogue", type=str, required=True,
|
||||||
help="PV catalogue.")
|
help="PV catalogues.")
|
||||||
parser.add_argument("--ksmooth", type=int, default=1,
|
parser.add_argument("--ksmooth", type=int, default=1,
|
||||||
help="Smoothing index.")
|
help="Smoothing index.")
|
||||||
parser.add_argument("--ksim", type=none_or_int, default=None,
|
parser.add_argument("--ksim", type=none_or_int, default=None,
|
||||||
|
@ -42,7 +42,12 @@ def parse_args():
|
||||||
help="Number of devices to request.")
|
help="Number of devices to request.")
|
||||||
parser.add_argument("--device", type=str, default="cpu",
|
parser.add_argument("--device", type=str, default="cpu",
|
||||||
help="Device to use.")
|
help="Device to use.")
|
||||||
return parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Convert the catalogue to a list of catalogues
|
||||||
|
args.catalogue = args.catalogue.split(",")
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
ARGS = parse_args()
|
ARGS = parse_args()
|
||||||
|
@ -58,8 +63,7 @@ from os.path import join
|
||||||
import csiborgtools # noqa
|
import csiborgtools # noqa
|
||||||
import jax # noqa
|
import jax # noqa
|
||||||
from h5py import File # noqa
|
from h5py import File # noqa
|
||||||
from mpi4py import MPI # noqa
|
from numpyro.infer import MCMC, NUTS, init_to_median # noqa
|
||||||
from numpyro.infer import MCMC, NUTS, Predictive, init_to_median # noqa
|
|
||||||
|
|
||||||
|
|
||||||
def print_variables(names, variables):
|
def print_variables(names, variables):
|
||||||
|
@ -68,8 +72,8 @@ def print_variables(names, variables):
|
||||||
print(flush=True)
|
print(flush=True)
|
||||||
|
|
||||||
|
|
||||||
def get_model(paths, get_model_kwargs, verbose=True):
|
def get_models(get_model_kwargs, verbose=True):
|
||||||
"""Load the data and create the NumPyro model."""
|
"""Load the data and create the NumPyro models."""
|
||||||
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
||||||
folder = "/mnt/extraspace/rstiskalek/catalogs/"
|
folder = "/mnt/extraspace/rstiskalek/catalogs/"
|
||||||
|
|
||||||
|
@ -86,9 +90,12 @@ def get_model(paths, get_model_kwargs, verbose=True):
|
||||||
print(f"{'Num. realisations:':<20} {len(nsims)}")
|
print(f"{'Num. realisations:':<20} {len(nsims)}")
|
||||||
print(flush=True)
|
print(flush=True)
|
||||||
|
|
||||||
if ARGS.catalogue == "A2":
|
# Get models
|
||||||
|
models = [None] * len(ARGS.catalogue)
|
||||||
|
for i, cat in enumerate(ARGS.catalogue):
|
||||||
|
if cat == "A2":
|
||||||
fpath = join(folder, "A2.h5")
|
fpath = join(folder, "A2.h5")
|
||||||
elif ARGS.catalogue in ["LOSS", "Foundation", "Pantheon+", "SFI_gals",
|
elif cat in ["LOSS", "Foundation", "Pantheon+", "SFI_gals",
|
||||||
"2MTF", "SFI_groups", "SFI_gals_masked",
|
"2MTF", "SFI_groups", "SFI_gals_masked",
|
||||||
"Pantheon+_groups", "Pantheon+_groups_zSN",
|
"Pantheon+_groups", "Pantheon+_groups_zSN",
|
||||||
"Pantheon+_zSN"]:
|
"Pantheon+_zSN"]:
|
||||||
|
@ -97,44 +104,31 @@ def get_model(paths, get_model_kwargs, verbose=True):
|
||||||
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
|
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
|
||||||
|
|
||||||
loader = csiborgtools.flow.DataLoader(ARGS.simname, nsim_iterator,
|
loader = csiborgtools.flow.DataLoader(ARGS.simname, nsim_iterator,
|
||||||
ARGS.catalogue, fpath, paths,
|
cat, fpath, paths,
|
||||||
ksmooth=ARGS.ksmooth)
|
ksmooth=ARGS.ksmooth)
|
||||||
|
models[i] = csiborgtools.flow.get_model(loader, **get_model_kwargs)
|
||||||
|
|
||||||
print(f"\n{'Num. radial steps':<20} {len(loader.rdist)}\n", flush=True)
|
print(f"\n{'Num. radial steps':<20} {len(loader.rdist)}\n", flush=True)
|
||||||
|
return models
|
||||||
return csiborgtools.flow.get_model(loader, **get_model_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num):
|
def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num):
|
||||||
"""Compute evidence using the `harmonic` package."""
|
"""Compute evidence using the `harmonic` package."""
|
||||||
data, names = csiborgtools.dict_samples_to_array(samples)
|
data, names = csiborgtools.dict_samples_to_array(samples)
|
||||||
data = data.reshape(nchains_harmonic, -1, len(names))
|
data = data.reshape(nchains_harmonic, -1, len(names))
|
||||||
log_posterior = log_posterior.reshape(10, -1)
|
log_posterior = log_posterior.reshape(nchains_harmonic, -1)
|
||||||
|
|
||||||
return csiborgtools.harmonic_evidence(
|
return csiborgtools.harmonic_evidence(
|
||||||
data, log_posterior, return_flow_samples=False, epochs_num=epoch_num)
|
data, log_posterior, return_flow_samples=False, epochs_num=epoch_num)
|
||||||
|
|
||||||
|
|
||||||
def get_simulation_weights(samples, model, model_kwargs):
|
|
||||||
"""Get the weights per posterior samples for each simulation."""
|
|
||||||
predictive = Predictive(model, samples)
|
|
||||||
ll_all = predictive(
|
|
||||||
jax.random.PRNGKey(1), store_ll_all=True, **model_kwargs)["ll_all"]
|
|
||||||
|
|
||||||
# Multiply the likelihood of galaxies
|
|
||||||
ll_per_simulation = jax.numpy.sum(ll_all, axis=-1)
|
|
||||||
# Normalization by summing the likelihood over simulations
|
|
||||||
norm = jax.scipy.special.logsumexp(ll_per_simulation, axis=-1)
|
|
||||||
return ll_per_simulation - norm[:, None]
|
|
||||||
|
|
||||||
|
|
||||||
def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
||||||
calculate_evidence, nchains_harmonic, epoch_num, kwargs_print):
|
calculate_evidence, nchains_harmonic, epoch_num, kwargs_print):
|
||||||
"""Run the NumPyro model and save output to a file."""
|
"""Run the NumPyro model and save output to a file."""
|
||||||
try:
|
try:
|
||||||
ndata = model.ndata
|
ndata = sum(model.ndata for model in model_kwargs["models"])
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
raise AttributeError("The model must have an attribute `ndata` "
|
raise AttributeError("The models must have an attribute `ndata` "
|
||||||
"indicating the number of data points.") from e
|
"indicating the number of data points.") from e
|
||||||
|
|
||||||
nuts_kernel = NUTS(model, init_strategy=init_to_median(num_samples=1000))
|
nuts_kernel = NUTS(model, init_strategy=init_to_median(num_samples=1000))
|
||||||
|
@ -143,7 +137,6 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
||||||
|
|
||||||
mcmc.run(rng_key, extra_fields=("potential_energy",), **model_kwargs)
|
mcmc.run(rng_key, extra_fields=("potential_energy",), **model_kwargs)
|
||||||
samples = mcmc.get_samples()
|
samples = mcmc.get_samples()
|
||||||
simulation_weights = get_simulation_weights(samples, model, model_kwargs)
|
|
||||||
|
|
||||||
log_posterior = -mcmc.get_extra_fields()["potential_energy"]
|
log_posterior = -mcmc.get_extra_fields()["potential_energy"]
|
||||||
log_likelihood = samples.pop("ll_values")
|
log_likelihood = samples.pop("ll_values")
|
||||||
|
@ -165,7 +158,7 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
||||||
neg_ln_evidence = jax.numpy.nan
|
neg_ln_evidence = jax.numpy.nan
|
||||||
neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)
|
neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)
|
||||||
|
|
||||||
fname = f"samples_{ARGS.simname}_{ARGS.catalogue}_ksmooth{ARGS.ksmooth}.hdf5" # noqa
|
fname = f"samples_{ARGS.simname}_{'+'.join(ARGS.catalogue)}_ksmooth{ARGS.ksmooth}.hdf5" # noqa
|
||||||
if ARGS.ksim is not None:
|
if ARGS.ksim is not None:
|
||||||
fname = fname.replace(".hdf5", f"_nsim{ARGS.ksim}.hdf5")
|
fname = fname.replace(".hdf5", f"_nsim{ARGS.ksim}.hdf5")
|
||||||
|
|
||||||
|
@ -183,7 +176,6 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
||||||
# Write log likelihood and posterior
|
# Write log likelihood and posterior
|
||||||
f.create_dataset("log_likelihood", data=log_likelihood)
|
f.create_dataset("log_likelihood", data=log_likelihood)
|
||||||
f.create_dataset("log_posterior", data=log_posterior)
|
f.create_dataset("log_posterior", data=log_posterior)
|
||||||
f.create_dataset("simulation_weights", data=simulation_weights)
|
|
||||||
|
|
||||||
# Write goodness of fit
|
# Write goodness of fit
|
||||||
grp = f.create_group("gof")
|
grp = f.create_group("gof")
|
||||||
|
@ -215,6 +207,29 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
||||||
# Command line interface #
|
# Command line interface #
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
|
||||||
|
def get_distmod_hyperparams(catalogue):
|
||||||
|
alpha_min = -1.0
|
||||||
|
alpha_max = 3.0
|
||||||
|
sample_alpha = True
|
||||||
|
|
||||||
|
if catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]: # noqa
|
||||||
|
return {"e_mu_min": 0.001, "e_mu_max": 1.0,
|
||||||
|
"mag_cal_mean": -18.25, "mag_cal_std": 2.0,
|
||||||
|
"alpha_cal_mean": 0.148, "alpha_cal_std": 1.0,
|
||||||
|
"beta_cal_mean": 3.112, "beta_cal_std": 2.0,
|
||||||
|
"alpha_min": alpha_min, "alpha_max": alpha_max,
|
||||||
|
"sample_alpha": sample_alpha
|
||||||
|
}
|
||||||
|
elif catalogue in ["SFI_gals", "2MTF"]:
|
||||||
|
return {"e_mu_min": 0.001, "e_mu_max": 1.0,
|
||||||
|
"a_mean": -21., "a_std": 5.0,
|
||||||
|
"b_mean": -5.95, "b_std": 3.0,
|
||||||
|
"alpha_min": alpha_min, "alpha_max": alpha_max,
|
||||||
|
"sample_alpha": sample_alpha
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
||||||
|
@ -227,14 +242,15 @@ if __name__ == "__main__":
|
||||||
###########################################################################
|
###########################################################################
|
||||||
|
|
||||||
nsteps = 5000
|
nsteps = 5000
|
||||||
nburn = 1000
|
nburn = 1500
|
||||||
zcmb_max = 0.06
|
zcmb_max = 0.05
|
||||||
calculate_evidence = False
|
calculate_evidence = False
|
||||||
nchains_harmonic = 10
|
nchains_harmonic = 10
|
||||||
num_epochs = 30
|
num_epochs = 30
|
||||||
|
|
||||||
if nsteps % nchains_harmonic != 0:
|
if nsteps % nchains_harmonic != 0:
|
||||||
raise ValueError("The number of steps must be divisible by the number of chains.") # noqa
|
raise ValueError(
|
||||||
|
"The number of steps must be divisible by the number of chains.")
|
||||||
|
|
||||||
main_params = {"nsteps": nsteps, "nburn": nburn, "zcmb_max": zcmb_max,
|
main_params = {"nsteps": nsteps, "nburn": nburn, "zcmb_max": zcmb_max,
|
||||||
"calculate_evidence": calculate_evidence,
|
"calculate_evidence": calculate_evidence,
|
||||||
|
@ -244,42 +260,36 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
calibration_hyperparams = {"Vext_min": -1000, "Vext_max": 1000,
|
calibration_hyperparams = {"Vext_min": -1000, "Vext_max": 1000,
|
||||||
"Vmono_min": -1000, "Vmono_max": 1000,
|
"Vmono_min": -1000, "Vmono_max": 1000,
|
||||||
"alpha_min": -1.0, "alpha_max": 3.0,
|
|
||||||
"beta_min": -1.0, "beta_max": 3.0,
|
"beta_min": -1.0, "beta_max": 3.0,
|
||||||
"sigma_v_min": 1.0, "sigma_v_max": 750.,
|
"sigma_v_min": 1.0, "sigma_v_max": 750.,
|
||||||
"sample_Vmono": False,
|
"sample_Vmono": False,
|
||||||
"sample_alpha": True,
|
|
||||||
"sample_beta": True,
|
"sample_beta": True,
|
||||||
"sample_sigma_v_ext": False,
|
"sample_sigma_v_ext": False,
|
||||||
}
|
}
|
||||||
print_variables(
|
print_variables(
|
||||||
calibration_hyperparams.keys(), calibration_hyperparams.values())
|
calibration_hyperparams.keys(), calibration_hyperparams.values())
|
||||||
|
|
||||||
if ARGS.catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]: # noqa
|
distmod_hyperparams_per_catalogue = []
|
||||||
distmod_hyperparams = {"e_mu_min": 0.001, "e_mu_max": 1.0,
|
for cat in ARGS.catalogue:
|
||||||
"mag_cal_mean": -18.25, "mag_cal_std": 2.0,
|
x = get_distmod_hyperparams(cat)
|
||||||
"alpha_cal_mean": 0.148, "alpha_cal_std": 1.0,
|
print(f"\n{cat} hyperparameters:")
|
||||||
"beta_cal_mean": 3.112, "beta_cal_std": 2.0,
|
print_variables(x.keys(), x.values())
|
||||||
}
|
distmod_hyperparams_per_catalogue.append(x)
|
||||||
elif ARGS.catalogue in ["SFI_gals", "2MTF"]:
|
|
||||||
distmod_hyperparams = {"e_mu_min": 0.001, "e_mu_max": 1.0,
|
|
||||||
"a_mean": -21., "a_std": 5.0,
|
|
||||||
"b_mean": -5.95, "b_std": 3.0,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
|
|
||||||
|
|
||||||
print_variables(
|
kwargs_print = (main_params, calibration_hyperparams,
|
||||||
distmod_hyperparams.keys(), distmod_hyperparams.values())
|
*distmod_hyperparams_per_catalogue)
|
||||||
|
|
||||||
kwargs_print = (main_params, calibration_hyperparams, distmod_hyperparams)
|
|
||||||
###########################################################################
|
###########################################################################
|
||||||
|
|
||||||
model_kwargs = {"calibration_hyperparams": calibration_hyperparams,
|
|
||||||
"distmod_hyperparams": distmod_hyperparams}
|
|
||||||
get_model_kwargs = {"zcmb_max": zcmb_max}
|
get_model_kwargs = {"zcmb_max": zcmb_max}
|
||||||
|
models = get_models(get_model_kwargs, )
|
||||||
|
model_kwargs = {
|
||||||
|
"models": models,
|
||||||
|
"field_calibration_hyperparams": calibration_hyperparams,
|
||||||
|
"distmod_hyperparams_per_model": distmod_hyperparams_per_catalogue,
|
||||||
|
}
|
||||||
|
|
||||||
|
model = csiborgtools.flow.PV_validation_model
|
||||||
|
|
||||||
model = get_model(paths, get_model_kwargs, )
|
|
||||||
run_model(model, nsteps, nburn, model_kwargs, out_folder,
|
run_model(model, nsteps, nburn, model_kwargs, out_folder,
|
||||||
calibration_hyperparams["sample_beta"], calculate_evidence,
|
calibration_hyperparams["sample_beta"], calculate_evidence,
|
||||||
nchains_harmonic, num_epochs, kwargs_print)
|
nchains_harmonic, num_epochs, kwargs_print)
|
||||||
|
|
|
@ -1,44 +1,66 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
memory=8
|
memory=7
|
||||||
on_login=${1}
|
on_login=${1}
|
||||||
|
queue=${2}
|
||||||
ndevice=1
|
ndevice=1
|
||||||
|
|
||||||
device="gpu"
|
|
||||||
queue="gpulong"
|
|
||||||
gputype="rtx2080with12gb"
|
|
||||||
env="/mnt/users/rstiskalek/csiborgtools/venv_gpu_csiborgtools/bin/python"
|
|
||||||
file="flow_validation.py"
|
file="flow_validation.py"
|
||||||
ksmooth=0
|
ksmooth=0
|
||||||
|
|
||||||
|
|
||||||
if [ "$on_login" != "1" ] && [ "$on_login" != "0" ]; then
|
if [ "$on_login" != "1" ] && [ "$on_login" != "0" ]
|
||||||
echo "Invalid input: 'on_login' (1). Please provide 1 or 0."
|
then
|
||||||
|
echo "'on_login' (1) must be either 0 or 1."
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
# Submit a job for each combination of simname, catalogue, ksim
|
if [ "$queue" != "redwood" ] && [ "$queue" != "berg" ] && [ "$queue" != "cmb" ] && [ "$queue" != "gpulong" ] && [ "$queue" != "cmbgpu" ]; then
|
||||||
|
echo "Invalid queue: $queue (2). Please provide one of 'redwood', 'berg', 'cmb', 'gpulong', 'cmbgpu'."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
if [ "$queue" == "gpulong" ]
|
||||||
|
then
|
||||||
|
device="gpu"
|
||||||
|
gputype="rtx2080with12gb"
|
||||||
|
# gputype="rtx3070with8gb"
|
||||||
|
env="/mnt/users/rstiskalek/csiborgtools/venv_gpu_csiborgtools/bin/python"
|
||||||
|
elif [ "$queue" == "cmbgpu" ]
|
||||||
|
then
|
||||||
|
device="gpu"
|
||||||
|
gputype="rtx3090with24gb"
|
||||||
|
env="/mnt/users/rstiskalek/csiborgtools/venv_gpu_csiborgtools/bin/python"
|
||||||
|
else
|
||||||
|
device="cpu"
|
||||||
|
env="/mnt/users/rstiskalek/csiborgtools/venv_csiborg/bin/python"
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
# for simname in "Lilow2024" "CF4" "CF4gp" "csiborg1" "csiborg2_main" "csiborg2X"; do
|
# for simname in "Lilow2024" "CF4" "CF4gp" "csiborg1" "csiborg2_main" "csiborg2X"; do
|
||||||
for simname in "Carrick2015"; do
|
for simname in "Carrick2015"; do
|
||||||
# for simname in "csiborg1" "csiborg2_main" "csiborg2X"; do
|
for catalogue in "LOSS,2MTF,SFI_gals"; do
|
||||||
for catalogue in "Foundation"; do
|
|
||||||
# for catalogue in "2MTF"; do
|
|
||||||
# for ksim in 0 1 2; do
|
|
||||||
# for ksim in 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20; do
|
# for ksim in 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20; do
|
||||||
for ksim in "none"; do
|
for ksim in "none"; do
|
||||||
pythoncm="$env $file --catalogue $catalogue --simname $simname --ksim $ksim --ksmooth $ksmooth --ndevice $ndevice --device $device"
|
pythoncm="$env $file --catalogue $catalogue --simname $simname --ksim $ksim --ksmooth $ksmooth --ndevice $ndevice --device $device"
|
||||||
|
|
||||||
if [ $on_login -eq 1 ]; then
|
if [ "$on_login" == "1" ]; then
|
||||||
echo $pythoncm
|
echo $pythoncm
|
||||||
$pythoncm
|
eval $pythoncm
|
||||||
else
|
else
|
||||||
|
if [ "$device" == "gpu" ]; then
|
||||||
cm="addqueue -q $queue -s -m $memory --gpus 1 --gputype $gputype $pythoncm"
|
cm="addqueue -q $queue -s -m $memory --gpus 1 --gputype $gputype $pythoncm"
|
||||||
|
else
|
||||||
|
cm="addqueue -s -q $queue -n 1 -m $memory $pythoncm"
|
||||||
|
fi
|
||||||
echo "Submitting:"
|
echo "Submitting:"
|
||||||
echo $cm
|
echo $cm
|
||||||
eval $cm
|
eval $cm
|
||||||
fi
|
fi
|
||||||
|
|
||||||
echo
|
echo
|
||||||
sleep 0.001
|
sleep 0.001
|
||||||
|
|
||||||
done
|
done
|
||||||
done
|
done
|
||||||
done
|
done
|
||||||
|
|
Loading…
Reference in a new issue