mirror of
https://github.com/Richard-Sti/csiborgtools_public.git
synced 2025-07-03 21:11:11 +00:00
Add more about evidence and selection to flow (#142)
* Add Laplace evidence * Numerically stable laplace evidence * Minor edits to Laplace * Remove rmax * Rm old things * Rm comments * Add script * Add super toy selection * Add super toy selection * Update script
This commit is contained in:
parent
d13246a394
commit
3d1e1c0ae3
8 changed files with 243 additions and 57 deletions
60
scripts/flow/test_harmonic.py
Normal file
60
scripts/flow/test_harmonic.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
from argparse import ArgumentParser, ArgumentTypeError
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--device", type=str, default="cpu",
|
||||
help="Device to use.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
ARGS = parse_args()
|
||||
# This must be done before we import JAX etc.
|
||||
from numpyro import set_host_device_count, set_platform # noqa
|
||||
|
||||
set_platform(ARGS.device) # noqa
|
||||
|
||||
from jax import numpy as jnp # noqa
|
||||
import numpy as np # noqa
|
||||
import csiborgtools # noqa
|
||||
from scipy.stats import multivariate_normal # noqa
|
||||
|
||||
|
||||
def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num):
|
||||
"""Compute evidence using the `harmonic` package."""
|
||||
data, names = csiborgtools.dict_samples_to_array(samples)
|
||||
data = data.reshape(nchains_harmonic, -1, len(names))
|
||||
log_posterior = log_posterior.reshape(nchains_harmonic, -1)
|
||||
|
||||
return csiborgtools.harmonic_evidence(
|
||||
data, log_posterior, return_flow_samples=False, epochs_num=epoch_num)
|
||||
|
||||
|
||||
ndim = 250
|
||||
nsamples = 100_000
|
||||
nchains_split = 10
|
||||
loc = jnp.zeros(ndim)
|
||||
cov = jnp.eye(ndim)
|
||||
|
||||
|
||||
gen = np.random.default_rng()
|
||||
X = gen.multivariate_normal(loc, cov, size=nsamples)
|
||||
samples = {f"x_{i}": X[:, i] for i in range(ndim)}
|
||||
logprob = multivariate_normal(loc, cov).logpdf(X)
|
||||
|
||||
neg_lnZ_laplace, neg_lnZ_laplace_error = csiborgtools.laplace_evidence(
|
||||
samples, logprob, nchains_split)
|
||||
print(f"neg_lnZ_laplace: {neg_lnZ_laplace} +/- {neg_lnZ_laplace_error}")
|
||||
|
||||
|
||||
neg_lnZ_harmonic, neg_lnZ_harmonic_error = get_harmonic_evidence(
|
||||
samples, logprob, nchains_split, epoch_num=30)
|
||||
print(f"neg_lnZ_harmonic: {neg_lnZ_harmonic} +/- {neg_lnZ_harmonic_error}")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
Loading…
Add table
Add a link
Reference in a new issue