from argparse import ArgumentParser


def parse_args():
    """Parse the command-line options of this script.

    Returns
    -------
    argparse.Namespace
        Parsed options; the only attribute is `device` (a string which
        defaults to "cpu").
    """
    cli = ArgumentParser()

    # Settings of the single supported option, `--device`.
    device_option = {
        "type": str,
        "default": "cpu",
        "help": "Device to use.",
    }
    cli.add_argument("--device", **device_option)

    namespace = cli.parse_args()
    return namespace


# Parse the command line first: the value of `--device` is needed to select
# the platform, and that has to happen ahead of the JAX import further down.
ARGS = parse_args()
# This must be done before we import JAX etc.
from numpyro import set_platform                                                # noqa

# Hand the requested device to numpyro before JAX is imported below; do not
# move this call after the remaining imports.
set_platform(ARGS.device)                                                       # noqa

from jax import numpy as jnp                                                    # noqa
import numpy as np                                                              # noqa
import csiborgtools                                                             # noqa
from scipy.stats import multivariate_normal                                     # noqa


def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num):
    """
    Compute the evidence with the `harmonic` package, going through the
    `csiborgtools` wrappers.

    The samples are converted to an array with
    `csiborgtools.dict_samples_to_array`, after which both the samples and the
    log-posterior values are reshaped so that their leading axis has length
    `nchains_harmonic`.

    Parameters
    ----------
    samples
        Samples in the format accepted by
        `csiborgtools.dict_samples_to_array` (presumably a dictionary mapping
        parameter names to arrays of samples -- confirm against that helper).
    log_posterior
        Array of log-posterior values, one per sample; it must support
        `.reshape`.
    nchains_harmonic
        Number of chains the samples are split into. The number of samples
        has to be divisible by it, otherwise the reshape fails.
    epoch_num
        Forwarded to `csiborgtools.harmonic_evidence` as `epochs_num`.

    Returns
    -------
    The output of `csiborgtools.harmonic_evidence` called with
    `return_flow_samples=False`.
    """
    flat_samples, param_names = csiborgtools.dict_samples_to_array(samples)
    nparams = len(param_names)

    # Split the flat arrays into `nchains_harmonic` chains; the length of
    # each chain is inferred from the total number of samples.
    chain_samples = flat_samples.reshape(nchains_harmonic, -1, nparams)
    chain_log_posterior = log_posterior.reshape(nchains_harmonic, -1)

    evidence = csiborgtools.harmonic_evidence(
        chain_samples,
        chain_log_posterior,
        return_flow_samples=False,
        epochs_num=epoch_num,
    )
    return evidence


# Set-up: dimensionality of the Gaussian, number of draws, and the number of
# chains the draws are split into inside `get_harmonic_evidence`.
ndim = 150
nsamples = 50_000
nchains_split = 10

# Zero mean and identity covariance.
loc = jnp.zeros(ndim)
cov = jnp.eye(ndim)


# Draw the samples and store every coordinate under its own key "x_<i>".
gen = np.random.default_rng()
X = gen.multivariate_normal(loc, cov, size=nsamples)
samples = {f"x_{i}": column for i, column in enumerate(X.T)}

# Log-density of the very same Gaussian evaluated at each draw.
gaussian = multivariate_normal(loc, cov)
logprob = gaussian.logpdf(X)


neg_lnZ_harmonic, neg_lnZ_harmonic_error = get_harmonic_evidence(
    samples=samples,
    log_posterior=logprob,
    nchains_harmonic=nchains_split,
    epoch_num=30,
)
print(f"neg_lnZ_harmonic: {neg_lnZ_harmonic} +/- {neg_lnZ_harmonic_error}")