Mirror of https://github.com/Richard-Sti/csiborgtools_public.git (synced 2025-05-14 06:31:11 +00:00)
LSS projected basics (#140)
* Move files
* Move files
* Add galactic to RA/dec
* Update sky maps
* Add projected fields
* Remove old import
* Quick update
* Add IO
* Add imports
* Update imports
* Add basic file
Parent: 3b46f17ead
Commit: d578c71b83
36 changed files with 365 additions and 231 deletions
scripts/flow/flow_validation.py (new file, 295 lines)

# Copyright (C) 2024 Richard Stiskalek
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
"""
Script to run the PV validation model on various catalogues and simulations.

The script is not MPI-parallelised; instead, it is best run on a GPU.
"""
from argparse import ArgumentParser, ArgumentTypeError


def none_or_int(value):
    if value.lower() == "none":
        return None
    try:
        return int(value)
    except ValueError:
        raise ArgumentTypeError(f"Invalid value: {value}. Must be an integer or 'none'.")  # noqa


def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--simname", type=str, required=True,
                        help="Simulation name.")
    parser.add_argument("--catalogue", type=str, required=True,
                        help="Comma-separated list of PV catalogues.")
    parser.add_argument("--ksmooth", type=int, default=1,
                        help="Smoothing index.")
    parser.add_argument("--ksim", type=none_or_int, default=None,
                        help="IC iteration number. If 'none', all IC realisations are used.")  # noqa
    parser.add_argument("--ndevice", type=int, default=1,
                        help="Number of devices to request.")
    parser.add_argument("--device", type=str, default="cpu",
                        help="Device to use.")
    args = parser.parse_args()

    # Convert the catalogue to a list of catalogues.
    args.catalogue = args.catalogue.split(",")

    return args


ARGS = parse_args()
# This must be done before we import JAX etc.
from numpyro import set_host_device_count, set_platform  # noqa

set_platform(ARGS.device)  # noqa
set_host_device_count(ARGS.ndevice)  # noqa

import sys  # noqa
from os.path import join  # noqa

import csiborgtools  # noqa
import jax  # noqa
from h5py import File  # noqa
from numpyro.infer import MCMC, NUTS, init_to_median  # noqa


def print_variables(names, variables):
    for name, variable in zip(names, variables):
        print(f"{name:<20} {variable}", flush=True)
    print(flush=True)


def get_models(get_model_kwargs, verbose=True):
    """Load the data and create the NumPyro models."""
    paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
    folder = "/mnt/extraspace/rstiskalek/catalogs/"

    nsims = paths.get_ics(ARGS.simname)
    if ARGS.ksim is None:
        nsim_iterator = list(range(len(nsims)))
    else:
        nsim_iterator = [ARGS.ksim]
        nsims = [nsims[ARGS.ksim]]

    if verbose:
        print(f"{'Simulation:':<20} {ARGS.simname}")
        print(f"{'Catalogue:':<20} {ARGS.catalogue}")
        print(f"{'Num. realisations:':<20} {len(nsims)}")
        print(flush=True)

    # Get models
    models = [None] * len(ARGS.catalogue)
    for i, cat in enumerate(ARGS.catalogue):
        if cat == "A2":
            fpath = join(folder, "A2.h5")
        elif cat in ["LOSS", "Foundation", "Pantheon+", "SFI_gals",
                     "2MTF", "SFI_groups", "SFI_gals_masked",
                     "Pantheon+_groups", "Pantheon+_groups_zSN",
                     "Pantheon+_zSN"]:
            fpath = join(folder, "PV_compilation.hdf5")
        else:
            raise ValueError(f"Unsupported catalogue: `{cat}`.")

        loader = csiborgtools.flow.DataLoader(ARGS.simname, nsim_iterator,
                                              cat, fpath, paths,
                                              ksmooth=ARGS.ksmooth)
        models[i] = csiborgtools.flow.get_model(loader, **get_model_kwargs)

    print(f"\n{'Num. radial steps':<20} {len(loader.rdist)}\n", flush=True)
    return models


def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num):
    """Compute evidence using the `harmonic` package."""
    data, names = csiborgtools.dict_samples_to_array(samples)
    data = data.reshape(nchains_harmonic, -1, len(names))
    log_posterior = log_posterior.reshape(nchains_harmonic, -1)

    return csiborgtools.harmonic_evidence(
        data, log_posterior, return_flow_samples=False, epochs_num=epoch_num)


def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
              calculate_evidence, nchains_harmonic, epoch_num, kwargs_print):
    """Run the NumPyro model and save output to a file."""
    try:
        ndata = sum(model.ndata for model in model_kwargs["models"])
    except AttributeError as e:
        raise AttributeError("The models must have an attribute `ndata` "
                             "indicating the number of data points.") from e

    nuts_kernel = NUTS(model, init_strategy=init_to_median(num_samples=1000))
    mcmc = MCMC(nuts_kernel, num_warmup=nburn, num_samples=nsteps)
    rng_key = jax.random.PRNGKey(42)

    mcmc.run(rng_key, extra_fields=("potential_energy",), **model_kwargs)
    samples = mcmc.get_samples()

    log_posterior = -mcmc.get_extra_fields()["potential_energy"]
    log_likelihood = samples.pop("ll_values", None)
    if log_likelihood is None:
        raise ValueError("The samples must contain the log likelihood values under the key `ll_values`.")  # noqa

    BIC, AIC = csiborgtools.BIC_AIC(samples, log_likelihood, ndata)
    print(f"{'BIC':<20} {BIC}")
    print(f"{'AIC':<20} {AIC}")
    mcmc.print_summary()

    if calculate_evidence:
        print("Calculating the evidence using `harmonic`.", flush=True)
        neg_ln_evidence, neg_ln_evidence_err = get_harmonic_evidence(
            samples, log_posterior, nchains_harmonic, epoch_num)
        print(f"{'-ln(Z)':<20} {neg_ln_evidence}")
        print(f"{'-ln(Z) error':<20} {neg_ln_evidence_err}")
    else:
        neg_ln_evidence = jax.numpy.nan
        neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)

    fname = f"samples_{ARGS.simname}_{'+'.join(ARGS.catalogue)}_ksmooth{ARGS.ksmooth}.hdf5"  # noqa
    if ARGS.ksim is not None:
        fname = fname.replace(".hdf5", f"_nsim{ARGS.ksim}.hdf5")

    if sample_beta:
        fname = fname.replace(".hdf5", "_sample_beta.hdf5")

    fname = join(out_folder, fname)
    print(f"Saving results to `{fname}`.")
    with File(fname, "w") as f:
        # Write samples
        grp = f.create_group("samples")
        for key, value in samples.items():
            grp.create_dataset(key, data=value)

        # Write log likelihood and posterior
        f.create_dataset("log_likelihood", data=log_likelihood)
        f.create_dataset("log_posterior", data=log_posterior)

        # Write goodness of fit
        grp = f.create_group("gof")
        grp.create_dataset("BIC", data=BIC)
        grp.create_dataset("AIC", data=AIC)
        grp.create_dataset("neg_lnZ", data=neg_ln_evidence)
        grp.create_dataset("neg_lnZ_err", data=neg_ln_evidence_err)

    fname_summary = fname.replace(".hdf5", ".txt")
    print(f"Saving summary to `{fname_summary}`.")
    with open(fname_summary, 'w') as f:
        original_stdout = sys.stdout
        sys.stdout = f

        print("User parameters:")
        for kwargs in kwargs_print:
            print_variables(kwargs.keys(), kwargs.values())

        print("HMC summary:")
        print(f"{'BIC':<20} {BIC}")
        print(f"{'AIC':<20} {AIC}")
        print(f"{'-ln(Z)':<20} {neg_ln_evidence}")
        print(f"{'-ln(Z) error':<20} {neg_ln_evidence_err}")
        mcmc.print_summary(exclude_deterministic=False)
        sys.stdout = original_stdout


###############################################################################
#                         Command line interface                             #
###############################################################################


def get_distmod_hyperparams(catalogue):
    alpha_min = -1.0
    alpha_max = 3.0
    sample_alpha = True

    if catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]:  # noqa
        return {"e_mu_min": 0.001, "e_mu_max": 1.0,
                "mag_cal_mean": -18.25, "mag_cal_std": 2.0,
                "alpha_cal_mean": 0.148, "alpha_cal_std": 1.0,
                "beta_cal_mean": 3.112, "beta_cal_std": 2.0,
                "alpha_min": alpha_min, "alpha_max": alpha_max,
                "sample_alpha": sample_alpha
                }
    elif catalogue in ["SFI_gals", "2MTF"]:
        return {"e_mu_min": 0.001, "e_mu_max": 1.0,
                "a_mean": -21., "a_std": 5.0,
                "b_mean": -5.95, "b_std": 3.0,
                "alpha_min": alpha_min, "alpha_max": alpha_max,
                "sample_alpha": sample_alpha
                }
    else:
        raise ValueError(f"Unsupported catalogue: `{catalogue}`.")


if __name__ == "__main__":
    paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
    out_folder = "/mnt/extraspace/rstiskalek/csiborg_postprocessing/peculiar_velocity"  # noqa
    print(f"{'Num. devices:':<20} {jax.device_count()}")
    print(f"{'Devices:':<20} {jax.devices()}")

    ###########################################################################
    #                       Fixed user parameters                             #
    ###########################################################################

    nsteps = 5000
    nburn = 1500
    zcmb_max = 0.05
    calculate_evidence = False
    nchains_harmonic = 10
    num_epochs = 30

    if nsteps % nchains_harmonic != 0:
        raise ValueError(
            "The number of steps must be divisible by the number of chains.")

    main_params = {"nsteps": nsteps, "nburn": nburn, "zcmb_max": zcmb_max,
                   "calculate_evidence": calculate_evidence,
                   "nchains_harmonic": nchains_harmonic,
                   "num_epochs": num_epochs}
    print_variables(main_params.keys(), main_params.values())

    calibration_hyperparams = {"Vext_min": -1000, "Vext_max": 1000,
                               "Vmono_min": -1000, "Vmono_max": 1000,
                               "beta_min": -1.0, "beta_max": 3.0,
                               "sigma_v_min": 1.0, "sigma_v_max": 750.,
                               "sample_Vmono": False,
                               "sample_beta": True,
                               "sample_sigma_v_ext": False,
                               }
    print_variables(
        calibration_hyperparams.keys(), calibration_hyperparams.values())

    distmod_hyperparams_per_catalogue = []
    for cat in ARGS.catalogue:
        x = get_distmod_hyperparams(cat)
        print(f"\n{cat} hyperparameters:")
        print_variables(x.keys(), x.values())
        distmod_hyperparams_per_catalogue.append(x)

    kwargs_print = (main_params, calibration_hyperparams,
                    *distmod_hyperparams_per_catalogue)
    ###########################################################################

    get_model_kwargs = {"zcmb_max": zcmb_max}
    models = get_models(get_model_kwargs)
    model_kwargs = {
        "models": models,
        "field_calibration_hyperparams": calibration_hyperparams,
        "distmod_hyperparams_per_model": distmod_hyperparams_per_catalogue,
    }

    model = csiborgtools.flow.PV_validation_model

    run_model(model, nsteps, nburn, model_kwargs, out_folder,
              calibration_hyperparams["sample_beta"], calculate_evidence,
              nchains_harmonic, num_epochs, kwargs_print)
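For reference, a typical invocation matching the argument parser above; the flag values mirror those used by the submission wrapper below and the combination shown is illustrative only:

    python flow_validation.py --simname Carrick2015 --catalogue LOSS,2MTF,SFI_gals --ksmooth 0 --ksim none --ndevice 1 --device gpu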
scripts/flow/flow_validation.sh (new executable file, 66 lines)

#!/bin/bash
memory=7
on_login=${1}
queue=${2}
ndevice=1
file="flow_validation.py"
ksmooth=0


if [ "$on_login" != "1" ] && [ "$on_login" != "0" ]; then
    echo "Error: 'on_login' (first argument) must be either 0 or 1."
    exit 1
fi


if [ "$queue" != "redwood" ] && [ "$queue" != "berg" ] && [ "$queue" != "cmb" ] && [ "$queue" != "gpulong" ] && [ "$queue" != "cmbgpu" ]; then
    echo "Invalid queue: '$queue' (second argument). Please provide one of 'redwood', 'berg', 'cmb', 'gpulong', 'cmbgpu'."
    exit 1
fi


if [ "$queue" == "gpulong" ]; then
    device="gpu"
    gputype="rtx2080with12gb"
    # gputype="rtx3070with8gb"
    env="/mnt/users/rstiskalek/csiborgtools/venv_gpu_csiborgtools/bin/python"
elif [ "$queue" == "cmbgpu" ]; then
    device="gpu"
    gputype="rtx3090with24gb"
    env="/mnt/users/rstiskalek/csiborgtools/venv_gpu_csiborgtools/bin/python"
else
    device="cpu"
    env="/mnt/users/rstiskalek/csiborgtools/venv_csiborg/bin/python"
fi


# for simname in "Lilow2024" "CF4" "CF4gp" "csiborg1" "csiborg2_main" "csiborg2X"; do
for simname in "Carrick2015"; do
    for catalogue in "LOSS,2MTF,SFI_gals"; do
        # for ksim in 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20; do
        for ksim in "none"; do
            pythoncm="$env $file --catalogue $catalogue --simname $simname --ksim $ksim --ksmooth $ksmooth --ndevice $ndevice --device $device"

            if [ "$on_login" == "1" ]; then
                echo "$pythoncm"
                eval "$pythoncm"
            else
                if [ "$device" == "gpu" ]; then
                    cm="addqueue -q $queue -s -m $memory --gpus 1 --gputype $gputype $pythoncm"
                else
                    cm="addqueue -s -q $queue -n 1 -m $memory $pythoncm"
                fi
                echo "Submitting:"
                echo "$cm"
                eval "$cm"
            fi

            echo
            sleep 0.001
        done
    done
done
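The wrapper takes two positional arguments, on_login (0 or 1) and the queue name, so typical calls look like the following (queue names taken from the check above):

    ./flow_validation.sh 1 berg       # run directly on the login node
    ./flow_validation.sh 0 gpulong    # submit to the gpulong GPU queue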
scripts/flow/post_upglade.py (new file, 174 lines)

# Copyright (C) 2024 Richard Stiskalek
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
"""
Script to calculate cosmological redshifts from observed redshifts, assuming
the Carrick+2015 peculiar velocity model. In the future this may be extended
to include other peculiar velocity models.
"""
from datetime import datetime
from os import remove
from os.path import join

import csiborgtools
import numpy as np
from csiborgtools import fprint
from h5py import File
from mpi4py import MPI
from taskmaster import work_delegation  # noqa
from tqdm import tqdm

SPEED_OF_LIGHT = 299792.458  # km / s


def t():
    return datetime.now().strftime("%H:%M:%S")


def load_calibration(catalogue, simname, ksmooth, sample_beta,
                     verbose=False):
    """Load the pre-computed calibration samples."""
    fname = f"/mnt/extraspace/rstiskalek/csiborg_postprocessing/peculiar_velocity/samples_{simname}_{catalogue}_ksmooth{ksmooth}.hdf5"  # noqa
    if sample_beta:
        fname = fname.replace(".hdf5", "_sample_beta.hdf5")

    keys = ["Vext", "sigma_v", "alpha", "beta"]
    calibration_samples = {}
    with File(fname, 'r') as f:
        for n, key in enumerate(keys):
            # In case alpha wasn't sampled, just set it to 1.
            if key == "alpha" and "alpha" not in f["samples"].keys():
                calibration_samples[key] = np.full_like(
                    calibration_samples["sigma_v"], 1.0)
                continue

            # NOTE: here the posterior samples are down-sampled.
            calibration_samples[key] = f[f"samples/{key}"][:][::10]

            if n == 0:
                num_samples_original = len(f[f"samples/{key}"])
                num_samples_final = len(calibration_samples[key])

    fprint(f"downsampling calibration samples from {num_samples_original} to {num_samples_final}.", verbose=verbose)  # noqa
    return calibration_samples


def main(loader, nsim, model, indxs, fdir, fname, num_split, verbose):
    out = np.full(
        len(indxs), np.nan,
        dtype=[("mean_zcosmo", float), ("std_zcosmo", float)])

    # Process each galaxy in this split.
    for i, n in enumerate(tqdm(indxs, desc=f"Split {num_split}",
                               disable=not verbose)):
        x, y = model.posterior_zcosmo(
            loader.cat["zcmb"][n], loader.cat["RA"][n], loader.cat["DEC"][n],
            loader.los_density[nsim, n], loader.los_radial_velocity[nsim, n],
            extra_sigma_v=loader.cat["e_zcmb"][n] * SPEED_OF_LIGHT,
            verbose=False)

        mu, std = model.posterior_mean_std(x, y)
        out["mean_zcosmo"][i], out["std_zcosmo"][i] = mu, std

    # Save the results of this rank.
    fname = join(fdir, f"{fname}_{num_split}.hdf5")
    with File(fname, 'w') as f:
        f.create_dataset("mean_zcosmo", data=out["mean_zcosmo"])
        f.create_dataset("std_zcosmo", data=out["std_zcosmo"])
        f.create_dataset("indxs", data=indxs)


###############################################################################
#                         Command line interface                             #
###############################################################################


if __name__ == "__main__":
    comm = MPI.COMM_WORLD
    rank, size = comm.Get_rank(), comm.Get_size()
    paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)

    # Calibration parameters
    simname = "Carrick2015"
    ksmooth = 0
    nsim = 0
    catalogue_calibration = "Pantheon+_zSN"

    # Galaxy sample parameters
    catalogue = "UPGLADE"
    fpath_data = "/mnt/users/rstiskalek/csiborgtools/data/upglade_all_z0p05_new_PROCESSED.h5"  # noqa

    # Number of splits for MPI
    nsplits = 1000

    # Folder to save the results
    fdir = "/mnt/extraspace/rstiskalek/csiborg_postprocessing/peculiar_velocity/UPGLADE"  # noqa
    fname = f"zcosmo_{catalogue}"

    # Load in the data, calibration samples and the model.
    loader = csiborgtools.flow.DataLoader(
        simname, nsim, catalogue, fpath_data, paths, ksmooth=ksmooth,
        verbose=rank == 0)
    calibration_samples = load_calibration(
        catalogue_calibration, simname, ksmooth, sample_beta=True,
        verbose=rank == 0)

    model = csiborgtools.flow.Observed2CosmologicalRedshift(
        calibration_samples, loader.rdist, loader._Omega_m)

    fprint(f"catalogue size is {loader.cat['zcmb'].size}.", verbose=rank == 0)
    fprint("loaded calibration samples and model.", verbose=rank == 0)

    # Decide how to split up the job.
    if rank == 0:
        indxs = np.arange(loader.cat["zcmb"].size)
        split_indxs = np.array_split(indxs, nsplits)
    else:
        indxs = None
        split_indxs = None
    indxs = comm.bcast(indxs, root=0)
    split_indxs = comm.bcast(split_indxs, root=0)

    # Process all splits with MPI; rank 0 delegates the jobs.
    def main_wrapper(n):
        main(loader, nsim, model, split_indxs[n], fdir, fname, n,
             verbose=size == 1)

    comm.Barrier()
    work_delegation(
        main_wrapper, list(range(nsplits)), comm, master_verbose=True)
    comm.Barrier()

    # Combine the results into a single file.
    if rank == 0:
        print("Combining results from all ranks.", flush=True)
        mean_zcosmo = np.full(loader.cat["zcmb"].size, np.nan)
        std_zcosmo = np.full_like(mean_zcosmo, np.nan)

        for n in range(nsplits):
            fname_current = join(fdir, f"{fname}_{n}.hdf5")
            with File(fname_current, 'r') as f:
                mask = f["indxs"][:]
                mean_zcosmo[mask] = f["mean_zcosmo"][:]
                std_zcosmo[mask] = f["std_zcosmo"][:]

            remove(fname_current)

        # Save the results.
        fname = join(fdir, f"{fname}.hdf5")
        print(f"Saving results to `{fname}`.")
        with File(fname, 'w') as f:
            f.create_dataset("mean_zcosmo", data=mean_zcosmo)
            f.create_dataset("std_zcosmo", data=std_zcosmo)
            f.create_dataset("indxs", data=indxs)
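A minimal sketch of reading back the combined output written by the rank-0 block above; the dataset names follow the `create_dataset` calls in the script, and the path is the hard-coded `fdir`/`fname` combination:

import numpy as np
from h5py import File

fname = "/mnt/extraspace/rstiskalek/csiborg_postprocessing/peculiar_velocity/UPGLADE/zcosmo_UPGLADE.hdf5"  # noqa
with File(fname, "r") as f:
    mean_zcosmo = f["mean_zcosmo"][:]
    std_zcosmo = f["std_zcosmo"][:]

# Any object whose split never ran stays NaN.
print(f"{np.isnan(mean_zcosmo).sum()} unprocessed objects.")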
scripts/flow/post_upglade.sh (new executable file, 31 lines)

#!/bin/bash
nthreads=${1}
on_login=${2}
memory=12
queue="redwood"
env="/mnt/zfsusers/rstiskalek/csiborgtools/venv_csiborg/bin/python"
file="post_upglade.py"


if [[ "$on_login" != "0" && "$on_login" != "1" ]]; then
    echo "Error: 'on_login' (second argument) must be either 0 or 1."
    exit 1
fi

if ! [[ "$nthreads" =~ ^[0-9]+$ ]] || [ "$nthreads" -le 0 ]; then
    echo "Error: 'nthreads' (first argument) must be an integer larger than 0."
    exit 1
fi


pythoncm="$env $file"
if [ "$on_login" -eq 1 ]; then
    echo "$pythoncm"
    $pythoncm
else
    cm="addqueue -q $queue -n $nthreads -m $memory $pythoncm"
    echo "Submitting:"
    echo "$cm"
    echo
    eval "$cm"
fi
scripts/flow/quijote_bulkflow.py (new file, 217 lines)

# Copyright (C) 2023 Richard Stiskalek
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
"""
A script to calculate the bulk flow in Quijote simulations from either
particles or FoF haloes, and to save the resulting smaller halo catalogues.

If `Rmin > 0`, the bulk flows computed from projected radial velocities are
wrong, but the 3D volume-average bulk flows are still correct.
"""
from datetime import datetime
from os.path import join

import csiborgtools
import numpy as np
from mpi4py import MPI
from taskmaster import work_delegation  # noqa
from warnings import catch_warnings, simplefilter
from h5py import File
from sklearn.neighbors import NearestNeighbors


###############################################################################
#                 Read in information about the simulation                   #
###############################################################################


def t():
    return datetime.now()


def get_data(nsim, verbose=True):
    if verbose:
        print(f"{t()}: reading particles of simulation `{nsim}`.")
    reader = csiborgtools.read.QuijoteSnapshot(nsim, 4, paths)
    part_pos = reader.coordinates().astype(np.float64)
    part_vel = reader.velocities().astype(np.float64)

    if verbose:
        print(f"{t()}: reading haloes of simulation `{nsim}`.")
    reader = csiborgtools.read.QuijoteCatalogue(nsim)
    halo_pos = reader.coordinates
    halo_vel = reader.velocities
    halo_mass = reader.totmass

    return part_pos, part_vel, halo_pos, halo_vel, halo_mass


def volume_bulk_flow(rdist, mass, vel, distances):
    """Mass-weighted mean velocity of the tracers enclosed at each distance."""
    out = csiborgtools.field.particles_enclosed_momentum(
        rdist, mass, vel, distances)
    with catch_warnings():
        simplefilter("ignore", category=RuntimeWarning)
        out /= csiborgtools.field.particles_enclosed_mass(
            rdist, mass, distances)[:, np.newaxis]

    return out


###############################################################################
#                     Main & command line interface                          #
###############################################################################


def main(nsim, folder, fname_basis, Rmin, Rmax, subtract_observer_velocity,
         verbose=True):
    boxsize = csiborgtools.simname2boxsize("quijote")
    observers = csiborgtools.read.fiducial_observers(boxsize, Rmax)
    distances = np.linspace(0, Rmax, 101)[1:]
    part_pos, part_vel, halo_pos, halo_vel, halo_mass = get_data(nsim, verbose)

    if verbose:
        print(f"{t()}: Fitting the particle and halo trees of simulation `{nsim}`.")  # noqa
    part_tree = NearestNeighbors().fit(part_pos)
    halo_tree = NearestNeighbors().fit(halo_pos)

    samples = {}
    bf_volume_part = np.full((len(observers), len(distances), 3), np.nan)
    bf_volume_halo = np.full_like(bf_volume_part, np.nan)
    bf_volume_halo_uniform = np.full_like(bf_volume_part, np.nan)
    bf_vrad_weighted_part = np.full_like(bf_volume_part, np.nan)
    bf_vrad_weighted_halo_uniform = np.full_like(bf_volume_part, np.nan)
    bf_vrad_weighted_halo = np.full_like(bf_volume_part, np.nan)
    obs_vel = np.full((len(observers), 3), np.nan)

    for i in range(len(observers)):
        print(f"{t()}: Calculating bulk flow for observer {i + 1} of simulation {nsim}.")  # noqa

        # Select particles within Rmax of the observer
        rdist_part, indxs = part_tree.radius_neighbors(
            np.asarray(observers[i]).reshape(1, -1), Rmax,
            return_distance=True, sort_results=True)
        rdist_part, indxs = rdist_part[0], indxs[0]

        # And keep only the ones that are above Rmin
        mask = rdist_part > Rmin
        rdist_part = rdist_part[mask]
        indxs = indxs[mask]

        part_pos_current = part_pos[indxs] - observers[i]
        part_vel_current = part_vel[indxs]
        # Quijote particle masses are all equal
        part_mass = np.ones_like(rdist_part)

        # Select haloes within Rmax of the observer
        rdist_halo, indxs = halo_tree.radius_neighbors(
            np.asarray(observers[i]).reshape(1, -1), Rmax,
            return_distance=True, sort_results=True)
        rdist_halo, indxs = rdist_halo[0], indxs[0]
        mask = rdist_halo > Rmin
        rdist_halo = rdist_halo[mask]
        indxs = indxs[mask]

        halo_pos_current = halo_pos[indxs] - observers[i]
        halo_vel_current = halo_vel[indxs]
        halo_mass_current = halo_mass[indxs]

        # Estimate the observer velocity as a Gaussian-weighted average of
        # the velocities of nearby particles.
        rscale = 2.0  # Mpc / h
        weights = np.exp(-0.5 * (rdist_part / rscale)**2)
        obs_vel_x = np.average(part_vel_current[:, 0], weights=weights)
        obs_vel_y = np.average(part_vel_current[:, 1], weights=weights)
        obs_vel_z = np.average(part_vel_current[:, 2], weights=weights)

        obs_vel[i, 0] = obs_vel_x
        obs_vel[i, 1] = obs_vel_y
        obs_vel[i, 2] = obs_vel_z

        # Subtract the observer velocity
        if subtract_observer_velocity:
            part_vel_current[:, 0] -= obs_vel_x
            part_vel_current[:, 1] -= obs_vel_y
            part_vel_current[:, 2] -= obs_vel_z

            halo_vel_current[:, 0] -= obs_vel_x
            halo_vel_current[:, 1] -= obs_vel_y
            halo_vel_current[:, 2] -= obs_vel_z

        # Calculate the volume-average bulk flows
        bf_volume_part[i, ...] = volume_bulk_flow(
            rdist_part, part_mass, part_vel_current, distances)
        bf_volume_halo[i, ...] = volume_bulk_flow(
            rdist_halo, halo_mass_current, halo_vel_current, distances)
        bf_volume_halo_uniform[i, ...] = volume_bulk_flow(
            rdist_halo, np.ones_like(halo_mass_current), halo_vel_current,
            distances)
        bf_vrad_weighted_part[i, ...] = csiborgtools.field.bulkflow_peery2018(
            rdist_part, part_mass, part_pos_current, part_vel_current,
            distances, weights="1/r^2", verbose=False)

        # Calculate the bulk flow from projected velocities w. 1/r^2 weights
        bf_vrad_weighted_halo_uniform[i, ...] = csiborgtools.field.bulkflow_peery2018(  # noqa
            rdist_halo, np.ones_like(halo_mass_current), halo_pos_current,
            halo_vel_current, distances, weights="1/r^2", verbose=False)
        bf_vrad_weighted_halo[i, ...] = csiborgtools.field.bulkflow_peery2018(
            rdist_halo, halo_mass_current, halo_pos_current,
            halo_vel_current, distances, weights="1/r^2", verbose=False)

        # Store the haloes around this observer
        samples[i] = {
            "halo_pos": halo_pos_current,
            "halo_vel": halo_vel_current,
            "halo_mass": halo_mass_current}

    # Finally, save the output.
    fname = join(folder, f"{fname_basis}_{nsim}.hdf5")
    if verbose:
        print(f"Saving to `{fname}`.")
    with File(fname, 'w') as f:
        f["distances"] = distances
        f["bf_volume_part"] = bf_volume_part
        f["bf_volume_halo"] = bf_volume_halo
        f["bf_vrad_weighted_part"] = bf_vrad_weighted_part
        f["bf_volume_halo_uniform"] = bf_volume_halo_uniform
        f["bf_vrad_weighted_halo_uniform"] = bf_vrad_weighted_halo_uniform
        f["bf_vrad_weighted_halo"] = bf_vrad_weighted_halo
        f["obs_vel"] = obs_vel

        for i in range(len(observers)):
            g = f.create_group(f"obs_{i}")
            g["halo_pos"] = samples[i]["halo_pos"]
            g["halo_vel"] = samples[i]["halo_vel"]
            g["halo_mass"] = samples[i]["halo_mass"]


if __name__ == "__main__":
    Rmin = 0
    Rmax = 150
    subtract_observer_velocity = True
    folder = "/mnt/extraspace/rstiskalek/quijote/BulkFlow_fiducial"
    fname_basis = "sBF_nsim" if subtract_observer_velocity else "BF_nsim"

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
    nsims = list(paths.get_ics("quijote"))

    def main_wrapper(nsim):
        main(nsim, folder, fname_basis, Rmin, Rmax, subtract_observer_velocity,
             verbose=rank == 0)

    if rank == 0:
        print(f"Running with {len(nsims)} Quijote simulations.")

    comm.Barrier()
    work_delegation(main_wrapper, nsims, comm, master_verbose=True)
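The quantity that `volume_bulk_flow` computes via the csiborgtools helpers is the enclosed, mass-weighted mean velocity, B(<R) = sum_{r_i < R} m_i v_i / sum_{r_i < R} m_i. A self-contained NumPy sketch of the same definition, for illustration only (the library helpers additionally handle edge cases and performance):

import numpy as np


def volume_bulk_flow_sketch(rdist, mass, vel, distances):
    """Mass-weighted mean velocity of tracers enclosed at each distance."""
    order = np.argsort(rdist)
    rdist, mass, vel = rdist[order], mass[order], vel[order]

    # Cumulative momentum and mass as a function of radius.
    cum_mom = np.cumsum(mass[:, None] * vel, axis=0)
    cum_mass = np.cumsum(mass)

    # Index of the outermost tracer enclosed within each distance;
    # assumes at least one tracer inside the smallest distance.
    k = np.searchsorted(rdist, distances) - 1
    return cum_mom[k] / cum_mass[k, None]


rng = np.random.default_rng(42)
rdist = 150 * rng.random(10000)            # Mpc / h
vel = rng.normal(0, 300, size=(10000, 3))  # km / s
mass = np.ones(10000)
print(volume_bulk_flow_sketch(rdist, mass, vel, np.linspace(25, 150, 6)))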
scripts/flow/quijote_bulkflow.sh (new executable file, 19 lines)

#!/bin/bash
nthreads=12
memory=24
on_login=0
queue="berg"
env="/mnt/zfsusers/rstiskalek/csiborgtools/venv_csiborg/bin/python"
file="quijote_bulkflow.py"


pythoncm="$env $file"
if [ "$on_login" -eq 1 ]; then
    echo "$pythoncm"
    $pythoncm
else
    cm="addqueue -q $queue -n $nthreads -m $memory $pythoncm"
    echo "Submitting:"
    echo "$cm"
    echo
    eval "$cm"
fi
scripts/flow/quijote_pecvel_covmat.py (new file, 126 lines)

# Copyright (C) 2024 Richard Stiskalek
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
"""
A script to calculate sky maps of the radial peculiar velocity in the Quijote
simulations, together with their angular power spectra (C_ell), for a set of
fiducial observers and shell radii.
"""
import csiborgtools
import healpy as hp
import numpy as np
from csiborgtools.field import evaluate_cartesian_cic
from h5py import File
from tqdm import tqdm


def load_field(nsim, MAS, grid, paths):
    """Load the precomputed velocity field from the Quijote simulations."""
    reader = csiborgtools.read.QuijoteField(nsim, paths)
    return reader.velocity_field(MAS, grid)


def skymap_coordinates(nside, R):
    """Generate 3D pixel positions at a given radius."""
    theta, phi = hp.pix2ang(nside, np.arange(hp.nside2npix(nside)))
    pos = R * np.vstack([np.sin(theta) * np.cos(phi),
                         np.sin(theta) * np.sin(phi),
                         np.cos(theta)]).T

    # Quijote expects float32, otherwise it will crash.
    return pos.astype(np.float32)


def make_radvel_skymap(velocity_field, pos, observer, boxsize):
    """
    Make a sky map of the radial velocity field at the given 3D positions,
    which correspond to the HEALPix pixels.
    """
    # Velocities on the shell
    Vx, Vy, Vz = [evaluate_cartesian_cic(velocity_field[i], pos=pos / boxsize,
                                         smooth_scales=None) for i in range(3)]

    # Observer velocity
    obs = np.asarray(observer).reshape(1, 3) / boxsize
    Vx_obs, Vy_obs, Vz_obs = [evaluate_cartesian_cic(
        velocity_field[i], pos=obs, smooth_scales=None)[0] for i in range(3)]

    # Subtract observer velocity
    Vx -= Vx_obs
    Vy -= Vy_obs
    Vz -= Vz_obs

    # Radial velocity
    norm_pos = pos - observer
    norm_pos /= np.linalg.norm(norm_pos, axis=1).reshape(-1, 1)
    Vrad = Vx * norm_pos[:, 0] + Vy * norm_pos[:, 1] + Vz * norm_pos[:, 2]

    return Vrad


def main(nsims, observers, nside, ell_max, radii, boxsize, MAS, grid, fname):
    """Calculate the sky maps and C_ell."""
    # 3D pixel positions of a shell at each radius, centred on the origin
    map_pos = [skymap_coordinates(nside, R) for R in radii]

    print(f"Writing to `{fname}`...")
    f = File(fname, 'w')
    f.create_dataset("ell", data=np.arange(ell_max + 1))
    f.create_dataset("radii", data=radii)
    f.attrs["num_simulations"] = len(nsims)
    f.attrs["num_observers"] = len(observers)
    f.attrs["num_radii"] = len(radii)
    f.attrs["npix_per_map"] = hp.nside2npix(nside)

    for nsim in tqdm(nsims, desc="Simulations"):
        grp_sim = f.create_group(f"nsim_{nsim}")
        velocity_field = load_field(nsim, MAS, grid, paths)

        for n in range(len(observers)):
            grp_observer = grp_sim.create_group(f"observer_{n}")

            for i in range(len(radii)):
                pos = map_pos[i] + observers[n]

                skymap = make_radvel_skymap(velocity_field, pos, observers[n],
                                            boxsize)
                C_ell = hp.sphtfunc.anafast(skymap, lmax=ell_max)

                grp_observer.create_dataset(f"skymap_{i}", data=skymap)
                grp_observer.create_dataset(f"C_ell_{i}", data=C_ell)

    print(f"Closing `{fname}`.")
    f.close()


if __name__ == "__main__":
    paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)

    MAS = "PCS"
    grid = 512
    nside = 256
    ell_max = 16
    boxsize = 1000
    Rmax = 200
    radii = np.linspace(100, 150, 5)
    fname = "/mnt/extraspace/rstiskalek/BBF/Quijote_Cell/C_ell_fiducial.h5"
    nsims = list(range(50))
    observers = csiborgtools.read.fiducial_observers(boxsize, Rmax)

    main(nsims, observers, nside, ell_max, radii, boxsize, MAS, grid, fname)
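A minimal sketch of reading the resulting file, following the group layout created in `main` above; the index appended to `skymap_` and `C_ell_` runs over the entries of the `radii` dataset:

from h5py import File

fname = "/mnt/extraspace/rstiskalek/BBF/Quijote_Cell/C_ell_fiducial.h5"
with File(fname, "r") as f:
    ell = f["ell"][:]
    radii = f["radii"][:]
    # C_ell of simulation 0, observer 0, at the innermost radius.
    C_ell = f["nsim_0/observer_0/C_ell_0"][:]

print(ell, radii, C_ell)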
scripts/flow/quijote_pecvel_covmat.sh (new executable file, 25 lines)

#!/bin/bash
nthreads=1
memory=16
on_login=${1}
queue="berg"
env="/mnt/zfsusers/rstiskalek/csiborgtools/venv_csiborg/bin/python"
file="quijote_pecvel_covmat.py"


if [ "$on_login" != "1" ] && [ "$on_login" != "0" ]; then
    echo "Error: 'on_login' (first argument) must be either 0 or 1."
    exit 1
fi

pythoncm="$env $file"
if [ "$on_login" -eq 1 ]; then
    echo "$pythoncm"
    $pythoncm
else
    cm="addqueue -q $queue -n $nthreads -m $memory $pythoncm"
    echo "Submitting:"
    echo "$cm"
    echo
    eval "$cm"
fi