Combine PV samples (#139)

* Update imports

* Update submission script

* Update script

* Add simulataenous sampling of many catalogues

* Update nb
This commit is contained in:
Richard Stiskalek 2024-07-30 17:02:48 +01:00 committed by GitHub
parent 9756175943
commit 3b46f17ead
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 356 additions and 147 deletions

View file

@ -33,7 +33,7 @@ def parse_args():
parser.add_argument("--simname", type=str, required=True,
help="Simulation name.")
parser.add_argument("--catalogue", type=str, required=True,
help="PV catalogue.")
help="PV catalogues.")
parser.add_argument("--ksmooth", type=int, default=1,
help="Smoothing index.")
parser.add_argument("--ksim", type=none_or_int, default=None,
@ -42,7 +42,12 @@ def parse_args():
help="Number of devices to request.")
parser.add_argument("--device", type=str, default="cpu",
help="Device to use.")
return parser.parse_args()
args = parser.parse_args()
# Convert the catalogue to a list of catalogues
args.catalogue = args.catalogue.split(",")
return args
ARGS = parse_args()
@ -58,8 +63,7 @@ from os.path import join
import csiborgtools # noqa
import jax # noqa
from h5py import File # noqa
from mpi4py import MPI # noqa
from numpyro.infer import MCMC, NUTS, Predictive, init_to_median # noqa
from numpyro.infer import MCMC, NUTS, init_to_median # noqa
def print_variables(names, variables):
@ -68,8 +72,8 @@ def print_variables(names, variables):
print(flush=True)
def get_model(paths, get_model_kwargs, verbose=True):
"""Load the data and create the NumPyro model."""
def get_models(get_model_kwargs, verbose=True):
"""Load the data and create the NumPyro models."""
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
folder = "/mnt/extraspace/rstiskalek/catalogs/"
@ -86,55 +90,45 @@ def get_model(paths, get_model_kwargs, verbose=True):
print(f"{'Num. realisations:':<20} {len(nsims)}")
print(flush=True)
if ARGS.catalogue == "A2":
fpath = join(folder, "A2.h5")
elif ARGS.catalogue in ["LOSS", "Foundation", "Pantheon+", "SFI_gals",
"2MTF", "SFI_groups", "SFI_gals_masked",
"Pantheon+_groups", "Pantheon+_groups_zSN",
"Pantheon+_zSN"]:
fpath = join(folder, "PV_compilation.hdf5")
else:
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
# Get models
models = [None] * len(ARGS.catalogue)
for i, cat in enumerate(ARGS.catalogue):
if cat == "A2":
fpath = join(folder, "A2.h5")
elif cat in ["LOSS", "Foundation", "Pantheon+", "SFI_gals",
"2MTF", "SFI_groups", "SFI_gals_masked",
"Pantheon+_groups", "Pantheon+_groups_zSN",
"Pantheon+_zSN"]:
fpath = join(folder, "PV_compilation.hdf5")
else:
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
loader = csiborgtools.flow.DataLoader(ARGS.simname, nsim_iterator,
ARGS.catalogue, fpath, paths,
ksmooth=ARGS.ksmooth)
loader = csiborgtools.flow.DataLoader(ARGS.simname, nsim_iterator,
cat, fpath, paths,
ksmooth=ARGS.ksmooth)
models[i] = csiborgtools.flow.get_model(loader, **get_model_kwargs)
print(f"\n{'Num. radial steps':<20} {len(loader.rdist)}\n", flush=True)
return csiborgtools.flow.get_model(loader, **get_model_kwargs)
return models
def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num):
"""Compute evidence using the `harmonic` package."""
data, names = csiborgtools.dict_samples_to_array(samples)
data = data.reshape(nchains_harmonic, -1, len(names))
log_posterior = log_posterior.reshape(10, -1)
log_posterior = log_posterior.reshape(nchains_harmonic, -1)
return csiborgtools.harmonic_evidence(
data, log_posterior, return_flow_samples=False, epochs_num=epoch_num)
def get_simulation_weights(samples, model, model_kwargs):
"""Get the weights per posterior samples for each simulation."""
predictive = Predictive(model, samples)
ll_all = predictive(
jax.random.PRNGKey(1), store_ll_all=True, **model_kwargs)["ll_all"]
# Multiply the likelihood of galaxies
ll_per_simulation = jax.numpy.sum(ll_all, axis=-1)
# Normalization by summing the likelihood over simulations
norm = jax.scipy.special.logsumexp(ll_per_simulation, axis=-1)
return ll_per_simulation - norm[:, None]
def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
calculate_evidence, nchains_harmonic, epoch_num, kwargs_print):
"""Run the NumPyro model and save output to a file."""
try:
ndata = model.ndata
ndata = sum(model.ndata for model in model_kwargs["models"])
except AttributeError as e:
raise AttributeError("The model must have an attribute `ndata` "
raise AttributeError("The models must have an attribute `ndata` "
"indicating the number of data points.") from e
nuts_kernel = NUTS(model, init_strategy=init_to_median(num_samples=1000))
@ -143,7 +137,6 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
mcmc.run(rng_key, extra_fields=("potential_energy",), **model_kwargs)
samples = mcmc.get_samples()
simulation_weights = get_simulation_weights(samples, model, model_kwargs)
log_posterior = -mcmc.get_extra_fields()["potential_energy"]
log_likelihood = samples.pop("ll_values")
@ -165,7 +158,7 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
neg_ln_evidence = jax.numpy.nan
neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)
fname = f"samples_{ARGS.simname}_{ARGS.catalogue}_ksmooth{ARGS.ksmooth}.hdf5" # noqa
fname = f"samples_{ARGS.simname}_{'+'.join(ARGS.catalogue)}_ksmooth{ARGS.ksmooth}.hdf5" # noqa
if ARGS.ksim is not None:
fname = fname.replace(".hdf5", f"_nsim{ARGS.ksim}.hdf5")
@ -183,7 +176,6 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
# Write log likelihood and posterior
f.create_dataset("log_likelihood", data=log_likelihood)
f.create_dataset("log_posterior", data=log_posterior)
f.create_dataset("simulation_weights", data=simulation_weights)
# Write goodness of fit
grp = f.create_group("gof")
@ -215,6 +207,29 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
# Command line interface #
###############################################################################
def get_distmod_hyperparams(catalogue):
alpha_min = -1.0
alpha_max = 3.0
sample_alpha = True
if catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]: # noqa
return {"e_mu_min": 0.001, "e_mu_max": 1.0,
"mag_cal_mean": -18.25, "mag_cal_std": 2.0,
"alpha_cal_mean": 0.148, "alpha_cal_std": 1.0,
"beta_cal_mean": 3.112, "beta_cal_std": 2.0,
"alpha_min": alpha_min, "alpha_max": alpha_max,
"sample_alpha": sample_alpha
}
elif catalogue in ["SFI_gals", "2MTF"]:
return {"e_mu_min": 0.001, "e_mu_max": 1.0,
"a_mean": -21., "a_std": 5.0,
"b_mean": -5.95, "b_std": 3.0,
"alpha_min": alpha_min, "alpha_max": alpha_max,
"sample_alpha": sample_alpha
}
else:
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
if __name__ == "__main__":
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
@ -227,14 +242,15 @@ if __name__ == "__main__":
###########################################################################
nsteps = 5000
nburn = 1000
zcmb_max = 0.06
nburn = 1500
zcmb_max = 0.05
calculate_evidence = False
nchains_harmonic = 10
num_epochs = 30
if nsteps % nchains_harmonic != 0:
raise ValueError("The number of steps must be divisible by the number of chains.") # noqa
raise ValueError(
"The number of steps must be divisible by the number of chains.")
main_params = {"nsteps": nsteps, "nburn": nburn, "zcmb_max": zcmb_max,
"calculate_evidence": calculate_evidence,
@ -244,42 +260,36 @@ if __name__ == "__main__":
calibration_hyperparams = {"Vext_min": -1000, "Vext_max": 1000,
"Vmono_min": -1000, "Vmono_max": 1000,
"alpha_min": -1.0, "alpha_max": 3.0,
"beta_min": -1.0, "beta_max": 3.0,
"sigma_v_min": 1.0, "sigma_v_max": 750.,
"sample_Vmono": False,
"sample_alpha": True,
"sample_beta": True,
"sample_sigma_v_ext": False,
}
print_variables(
calibration_hyperparams.keys(), calibration_hyperparams.values())
if ARGS.catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]: # noqa
distmod_hyperparams = {"e_mu_min": 0.001, "e_mu_max": 1.0,
"mag_cal_mean": -18.25, "mag_cal_std": 2.0,
"alpha_cal_mean": 0.148, "alpha_cal_std": 1.0,
"beta_cal_mean": 3.112, "beta_cal_std": 2.0,
}
elif ARGS.catalogue in ["SFI_gals", "2MTF"]:
distmod_hyperparams = {"e_mu_min": 0.001, "e_mu_max": 1.0,
"a_mean": -21., "a_std": 5.0,
"b_mean": -5.95, "b_std": 3.0,
}
else:
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
distmod_hyperparams_per_catalogue = []
for cat in ARGS.catalogue:
x = get_distmod_hyperparams(cat)
print(f"\n{cat} hyperparameters:")
print_variables(x.keys(), x.values())
distmod_hyperparams_per_catalogue.append(x)
print_variables(
distmod_hyperparams.keys(), distmod_hyperparams.values())
kwargs_print = (main_params, calibration_hyperparams, distmod_hyperparams)
kwargs_print = (main_params, calibration_hyperparams,
*distmod_hyperparams_per_catalogue)
###########################################################################
model_kwargs = {"calibration_hyperparams": calibration_hyperparams,
"distmod_hyperparams": distmod_hyperparams}
get_model_kwargs = {"zcmb_max": zcmb_max}
models = get_models(get_model_kwargs, )
model_kwargs = {
"models": models,
"field_calibration_hyperparams": calibration_hyperparams,
"distmod_hyperparams_per_model": distmod_hyperparams_per_catalogue,
}
model = csiborgtools.flow.PV_validation_model
model = get_model(paths, get_model_kwargs, )
run_model(model, nsteps, nburn, model_kwargs, out_folder,
calibration_hyperparams["sample_beta"], calculate_evidence,
nchains_harmonic, num_epochs, kwargs_print)

View file

@ -1,44 +1,66 @@
#!/bin/bash
memory=8
memory=7
on_login=${1}
queue=${2}
ndevice=1
device="gpu"
queue="gpulong"
gputype="rtx2080with12gb"
env="/mnt/users/rstiskalek/csiborgtools/venv_gpu_csiborgtools/bin/python"
file="flow_validation.py"
ksmooth=0
if [ "$on_login" != "1" ] && [ "$on_login" != "0" ]; then
echo "Invalid input: 'on_login' (1). Please provide 1 or 0."
if [ "$on_login" != "1" ] && [ "$on_login" != "0" ]
then
echo "'on_login' (1) must be either 0 or 1."
exit 1
fi
if [ "$queue" != "redwood" ] && [ "$queue" != "berg" ] && [ "$queue" != "cmb" ] && [ "$queue" != "gpulong" ] && [ "$queue" != "cmbgpu" ]; then
echo "Invalid queue: $queue (2). Please provide one of 'redwood', 'berg', 'cmb', 'gpulong', 'cmbgpu'."
exit 1
fi
# Submit a job for each combination of simname, catalogue, ksim
if [ "$queue" == "gpulong" ]
then
device="gpu"
gputype="rtx2080with12gb"
# gputype="rtx3070with8gb"
env="/mnt/users/rstiskalek/csiborgtools/venv_gpu_csiborgtools/bin/python"
elif [ "$queue" == "cmbgpu" ]
then
device="gpu"
gputype="rtx3090with24gb"
env="/mnt/users/rstiskalek/csiborgtools/venv_gpu_csiborgtools/bin/python"
else
device="cpu"
env="/mnt/users/rstiskalek/csiborgtools/venv_csiborg/bin/python"
fi
# for simname in "Lilow2024" "CF4" "CF4gp" "csiborg1" "csiborg2_main" "csiborg2X"; do
for simname in "Carrick2015"; do
# for simname in "csiborg1" "csiborg2_main" "csiborg2X"; do
for catalogue in "Foundation"; do
# for catalogue in "2MTF"; do
# for ksim in 0 1 2; do
for catalogue in "LOSS,2MTF,SFI_gals"; do
# for ksim in 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20; do
for ksim in "none"; do
pythoncm="$env $file --catalogue $catalogue --simname $simname --ksim $ksim --ksmooth $ksmooth --ndevice $ndevice --device $device"
if [ $on_login -eq 1 ]; then
if [ "$on_login" == "1" ]; then
echo $pythoncm
$pythoncm
eval $pythoncm
else
cm="addqueue -q $queue -s -m $memory --gpus 1 --gputype $gputype $pythoncm"
if [ "$device" == "gpu" ]; then
cm="addqueue -q $queue -s -m $memory --gpus 1 --gputype $gputype $pythoncm"
else
cm="addqueue -s -q $queue -n 1 -m $memory $pythoncm"
fi
echo "Submitting:"
echo $cm
eval $cm
fi
echo
sleep 0.001
done
done
done