mirror of
https://github.com/Richard-Sti/csiborgtools.git
synced 2024-12-23 09:48:01 +00:00
2b938c112c
* Add more comments * Add flow paths * Simplify paths * Update default arguments * Update paths * Update param names * Update some of scripts for reading files * Add the Mike method option * Update plotting * Update fnames * Simplify things * Make more default options * Add print * Update * Downsample CF4 * Update numpyro selection * Add selection fitting nb * Add coeffs * Update script * Add nb * Add label * Increase number of steps * Update default params * Add more labels * Improve file name * Update nb * Fix little bug * Remove import * Update scales * Update labels * Add script * Update script * Add more * Add more labels * Add script * Add submit * Update spacing * Update submit scripts * Update script * Update defaults * Update defaults * Update nb * Update test * Update imports * Add script * Add support for Indranil void * Add a dipole * Update nb * Update submit * Update Om0 * Add final * Update default params * Fix bug * Add option to fix to LG frame * Add Vext label * Add Vext label * Update script * Rm fixed LG * rm LG stuff * Update script * Update bulk flow plotting * Update nb * Add no field option * Update defaults * Update nb * Update script * Update nb * Update nb * Add names to plots * Update nb * Update plot * Add more latex names * Update default * Update nb * Update np * Add plane slicing * Add nb with slices * Update nb * Update script * Update nb * Update nb
56 lines
1.7 KiB
Python
56 lines
1.7 KiB
Python
from argparse import ArgumentParser
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of this script.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. When `None` (the default) the arguments
        are taken from `sys.argv[1:]`, which is the standard `argparse`
        behaviour.

    Returns
    -------
    argparse.Namespace
        Namespace with a single attribute, `device` (str, default "cpu").
    """
    parser = ArgumentParser()
    parser.add_argument("--device", type=str, default="cpu",
                        help="Device to use.")
    return parser.parse_args(argv)
|
|
|
|
|
|
# Parse the command line first: the requested device is needed below, before
# any of the JAX-based imports.
ARGS = parse_args()

# This must be done before we import JAX etc., which is why the remaining
# imports are deliberately placed after this call (hence the `noqa` markers).
from numpyro import set_platform                                        # noqa

set_platform(ARGS.device)                                               # noqa

from jax import numpy as jnp                                            # noqa
import numpy as np                                                      # noqa
import csiborgtools                                                     # noqa
from scipy.stats import multivariate_normal                             # noqa
|
|
|
|
|
|
def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num):
    """Compute evidence using the `harmonic` package.

    The samples are converted to an array with
    `csiborgtools.dict_samples_to_array`, split into `nchains_harmonic`
    chains and handed to `csiborgtools.harmonic_evidence`.

    Parameters
    ----------
    samples : dict
        Posterior samples keyed by parameter name; passed unchanged to
        `csiborgtools.dict_samples_to_array`.
    log_posterior : array
        Log-posterior value of each sample. Must provide `reshape`.
    nchains_harmonic : int
        Number of chains the samples are split into. The reshape below
        requires the number of samples to be divisible by this value.
    epoch_num : int
        Forwarded to `csiborgtools.harmonic_evidence` as `epochs_num`.

    Returns
    -------
    Whatever `csiborgtools.harmonic_evidence` returns with
    `return_flow_samples=False`.
    NOTE(review): the caller in this script unpacks it as an (estimate,
    error) pair -- confirm against the `csiborgtools` implementation.
    """
    data, names = csiborgtools.dict_samples_to_array(samples)

    # Split the flat sample set into `nchains_harmonic` chains; the last axis
    # holds one entry per parameter name.
    data = data.reshape(nchains_harmonic, -1, len(names))
    log_posterior = log_posterior.reshape(nchains_harmonic, -1)

    return csiborgtools.harmonic_evidence(
        data, log_posterior, return_flow_samples=False, epochs_num=epoch_num)
|
|
|
|
|
|
# Test problem: draw samples from a standard multivariate normal and evaluate
# its normalised log-density at those samples.
ndim = 150
nsamples = 50_000
# Number of chains the samples are split into; must divide `nsamples`.
nchains_split = 10
loc = jnp.zeros(ndim)
cov = jnp.eye(ndim)

# Unseeded generator: the drawn samples differ from run to run.
gen = np.random.default_rng()
X = gen.multivariate_normal(loc, cov, size=nsamples)
# One entry per dimension, keyed "x_0", "x_1", ...
samples = {f"x_{i}": X[:, i] for i in range(ndim)}
# The density is normalised, so it integrates to one (ln Z = 0).
# NOTE(review): the printed estimate is presumably expected to be consistent
# with zero -- confirm the sign/units returned by `harmonic_evidence`.
logprob = multivariate_normal(loc, cov).logpdf(X)

neg_lnZ_harmonic, neg_lnZ_harmonic_error = get_harmonic_evidence(
    samples, logprob, nchains_split, epoch_num=30)
print(f"neg_lnZ_harmonic: {neg_lnZ_harmonic} +/- {neg_lnZ_harmonic_error}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|