"""Sanity check of the `harmonic` evidence estimator on a known Gaussian.

Mock samples are drawn from a standard multivariate normal and paired with
its *normalised* log-density, so the exact evidence of the supplied density
is Z = 1 (ln Z = 0). The printed estimate of -ln Z should therefore be close
to zero.
"""
from argparse import ArgumentParser


def parse_args():
    """Parse the command-line arguments of this script."""
    parser = ArgumentParser()
    parser.add_argument("--device", type=str, default="cpu",
                        help="Device to use.")
    parser.add_argument("--seed", type=int, default=None,
                        help="Random seed for the mock samples "
                             "(default: unseeded).")
    return parser.parse_args()


ARGS = parse_args()

# This must be done before we import JAX etc.
from numpyro import set_platform  # noqa

set_platform(ARGS.device)  # noqa

# Kept so JAX is imported only after the platform has been selected.
from jax import numpy as jnp  # noqa
import numpy as np  # noqa

import csiborgtools  # noqa
from scipy.stats import multivariate_normal  # noqa


def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num):
    """Compute evidence using the `harmonic` package.

    Parameters
    ----------
    samples : dict
        Mapping of parameter name to its samples; converted to an array with
        `csiborgtools.dict_samples_to_array`.
    log_posterior : array-like
        Log-posterior value of each sample (flattened over all chains).
    nchains_harmonic : int
        Number of chains the samples are split into. Must be positive and
        divide the number of samples exactly.
    epoch_num : int
        Passed to `csiborgtools.harmonic_evidence` as `epochs_num`.

    Returns
    -------
    Whatever `csiborgtools.harmonic_evidence` returns; this script unpacks it
    as `(neg_lnZ, neg_lnZ_error)`.

    Raises
    ------
    ValueError
        If `nchains_harmonic` is not positive or does not divide the number
        of samples.
    """
    nsamples_total = np.size(log_posterior)
    # Fail early with a clear message instead of an opaque reshape error.
    if nchains_harmonic <= 0 or nsamples_total % nchains_harmonic != 0:
        raise ValueError(
            f"Cannot split {nsamples_total} samples into "
            f"{nchains_harmonic} chains of equal length.")

    data, names = csiborgtools.dict_samples_to_array(samples)
    data = data.reshape(nchains_harmonic, -1, len(names))
    log_posterior = log_posterior.reshape(nchains_harmonic, -1)
    return csiborgtools.harmonic_evidence(
        data, log_posterior, return_flow_samples=False, epochs_num=epoch_num)


ndim = 150
nsamples = 50_000
nchains_split = 10

# Plain float64 NumPy arrays: sampling and the log-density are both done by
# NumPy/SciPy, and JAX arrays would default to float32 unless x64 is enabled.
loc = np.zeros(ndim)
cov = np.eye(ndim)

# `default_rng(None)` is the same as the previous unseeded `default_rng()`.
gen = np.random.default_rng(ARGS.seed)
X = gen.multivariate_normal(loc, cov, size=nsamples)
samples = {f"x_{i}": X[:, i] for i in range(ndim)}
logprob = multivariate_normal(loc, cov).logpdf(X)

neg_lnZ_harmonic, neg_lnZ_harmonic_error = get_harmonic_evidence(
    samples, logprob, nchains_split, epoch_num=30)

# `logprob` is a normalised density, so the exact value of -ln Z is 0.
print(f"neg_lnZ_harmonic: {neg_lnZ_harmonic} +/- {neg_lnZ_harmonic_error}")