418 lines
15 KiB
Python
418 lines
15 KiB
Python
import aquila_borg as borg
|
||
import numpy as np
|
||
import numbers
|
||
import jaxlib
|
||
import jax.numpy as jnp
|
||
import jax
|
||
import configparser
|
||
|
||
# Output stream management: all messages go through the BORG console.
cons = borg.console()


def myprint(x):
    """
    Print ``x`` on the standard stream of the BORG console.

    Args:
        :x: Object to print. Strings are printed verbatim; anything else is
            printed via its ``repr``.
    """
    text = x if isinstance(x, str) else repr(x)
    cons.print_std(text)
|
||
|
||
def get_cosmopar(ini_file):
    """
    Extract cosmological parameters from an ini file.

    Starts from ``borg.cosmo.CosmologicalParameters().default()``, overrides
    the fields listed in the ``[cosmology]`` section of ``ini_file``, then
    calls ``compute_As`` so that ``A_s`` is consistent with ``sigma8``.

    Args:
        :ini_file (str): Path to the ini file

    Returns:
        :cpar (borg.cosmo.CosmologicalParameters): Cosmological parameters

    Raises:
        FileNotFoundError: if ``ini_file`` could not be read.
        KeyError: if the ``[cosmology]`` section or one of its keys is missing.
    """
    config = configparser.ConfigParser()
    # ConfigParser.read silently skips files it cannot open and returns the
    # list of files it did parse; fail here with a clear message rather than
    # with a KeyError on the missing section further down.
    if not config.read(ini_file):
        raise FileNotFoundError(f"Could not read ini file: {ini_file}")
    section = config['cosmology']

    cpar = borg.cosmo.CosmologicalParameters()
    cpar.default()

    # (CosmologicalParameters attribute, ini key); only h is stored under a
    # different key ('h100'). Order matches the original assignments.
    fields = (
        ('fnl', 'fnl'),
        ('omega_k', 'omega_k'),
        ('omega_m', 'omega_m'),
        ('omega_b', 'omega_b'),
        ('omega_q', 'omega_q'),
        ('h', 'h100'),
        ('sigma8', 'sigma8'),
        ('n_s', 'n_s'),
        ('w', 'w'),
        ('wprime', 'wprime'),
    )
    for attr, key in fields:
        setattr(cpar, attr, float(section[key]))

    cpar = compute_As(cpar)

    return cpar
|
||
|
||
|
||
def compute_As(cpar):
    """
    Compute As given values of sigma8.

    Runs BORG-CLASS once with ``sigma8`` disabled and a reference ``A_s``,
    then rescales ``A_s`` by ``(sigma8_target / sigma8_reference)**2`` so that
    the result reproduces the requested ``sigma8``. ``cpar`` is modified in
    place and also returned.

    Args:
        :cpar (borg.cosmo.CosmologicalParameters): Cosmological parameters with wrong As

    Returns:
        :cpar (borg.cosmo.CosmologicalParameters): Cosmological parameters with updated As

    Raises:
        ImportError: if BORG was built without CLASS support.
        Any exception raised by CLASS is propagated after ``cpar.sigma8`` and
        ``cpar.A_s`` have been restored to their input values.
    """
    # requires BORG-CLASS
    if not hasattr(borg.cosmo, 'ClassCosmo'):
        raise ImportError(
            "BORG-CLASS is required to compute As, but is not installed.")

    # Plain Python floats: jnp.copy would yield a jax array that is float32
    # unless jax_enable_x64 is set, silently truncating sigma8 and A_s.
    sigma8_true = float(cpar.sigma8)
    A_s_input = cpar.A_s
    A_s_ref = 2.3e-9

    # sigma8 = 0 so that CLASS normalises with A_s instead of sigma8.
    cpar.sigma8 = 0
    cpar.A_s = A_s_ref
    k_max, k_per_decade = 10, 100
    extra_class = {'YHe': '0.24'}

    try:
        cosmo = borg.cosmo.ClassCosmo(cpar, k_per_decade, k_max, extra=extra_class)
        cosmo.computeSigma8()
        sigma8_ref = float(cosmo.getCosmology()['sigma_8'])
    except Exception:
        # Do not leave the caller's parameters in the temporary state.
        cpar.A_s = A_s_input
        raise
    finally:
        cpar.sigma8 = sigma8_true

    cpar.A_s = (sigma8_true / sigma8_ref)**2 * A_s_ref

    myprint(f'Updated cosmology: {cpar}')

    return cpar
|
||
|
||
|
||
class MyLikelihood(borg.likelihood.BaseLikelihood):
    """
    HADES likelihood class.

    Gaussian likelihood comparing the final density field (from ``fwd``) and
    the velocity field (from ``fwd_vel``) with observed fields, using the
    scatters ``sigma_dens`` / ``sigma_vel`` read from the ``[mock]`` section
    of the ini file.
    """

    def __init__(self, fwd: borg.forward.BaseForwardModel,
                 fwd_vel: borg.forward.BaseForwardModel,
                 ini_fname: str):
        """
        Args:
            - fwd (borg.forward.BaseForwardModel): Forward model producing the final density.
            - fwd_vel (borg.forward.BaseForwardModel): Model queried via getVelocityField().
            - ini_fname (str): Ini file with [system], [mock] and [cosmology] sections.
        """
        self.fwd = fwd
        self.fwd_vel = fwd_vel

        # Read the ini file
        self.ini_fname = ini_fname
        self.config = configparser.ConfigParser()
        self.config.read(ini_fname)
        self.N = [int(self.config['system'][f'N{i}']) for i in range(3)]  # Number of grid points per side
        self.L = [float(self.config['system'][f'L{i}']) for i in range(3)]  # Box side length in Mpc/h

        self.sigma_dens = float(self.config['mock']['sigma_dens'])  # Density scatter
        self.sigma_vel = float(self.config['mock']['sigma_vel'])  # Velocity scatter

        myprint(f"Likelihood initialized with {self.N} grid points and box size {self.L} Mpc/h")
        super().__init__(fwd, self.N, self.L)

        # Set up cosmological parameters
        cpar = get_cosmopar(ini_fname)
        self.updateCosmology(cpar)

        # Gradient of dens2like with respect to (density, velocity)
        self.grad_like = jax.grad(self.dens2like, argnums=(0, 1))

    def updateCosmology(self, cosmo: borg.cosmo.CosmologicalParameters) -> None:
        """
        Recompute A_s for ``cosmo`` (via compute_As) and push it to the forward model.

        Args:
            - cosmo (borg.cosmo.CosmologicalParameters): New cosmology; modified in place.
        """
        cpar = compute_As(cosmo)
        self.fwd.setCosmoParams(cpar)

    def updateMetaParameters(self, state: borg.likelihood.MarkovState) -> None:
        """
        Update the meta parameters of the sampler (not sampled) from the MarkovState.

        Args:
            - state (borg.likelihood.MarkovState): The state object to be used in the likelihood.
        """
        # Same operation as updateCosmology, applied to the state's cosmology.
        self.updateCosmology(state['cosmology'])

    def initializeLikelihood(self, state: borg.likelihood.MarkovState) -> None:
        """
        Initialize the likelihood function.

        Registers the density and the three velocity-component arrays in the
        Markov state, each with the forward model's output grid size.

        Args:
            - state (borg.likelihood.MarkovState): The state object to be used in the likelihood.
        """
        myprint("Init likelihood")
        N_out = self.fwd.getOutputBoxModel().N
        state.newArray3d("BORG_final_density", *N_out, True)
        state.newArray3d("BORG_final_velocity_x", *N_out, True)
        state.newArray3d("BORG_final_velocity_y", *N_out, True)
        state.newArray3d("BORG_final_velocity_z", *N_out, True)

        # Could load real data
        # We'll generate mock data which has its own function

    def generateMockData(self, s_hat: np.ndarray, state: borg.likelihood.MarkovState) -> None:
        """
        Generates mock data by simulating the forward model with the given white noise.

        Stores the noiseless fields in ``true_dens`` / ``true_vel`` and the
        noisy observations (Gaussian scatter ``sigma_dens`` / ``sigma_vel``)
        in ``obs_dens`` / ``obs_vel``, then evaluates the likelihood once and
        commits the auxiliary fields to ``state``.

        Args:
            - s_hat (np.ndarray): The input (initial) white noise field.
            - state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
        """
        myprint('Making mock from BORG')

        # Get density field from the initial conditions
        # Could replace with any (better) simulation here
        # This version is self-consistent
        dens = np.zeros(self.fwd.getOutputBoxModel().N)
        myprint('Running forward model')
        myprint(self.fwd.getCosmoParams())
        self.fwd.forwardModel_v2(s_hat)
        self.fwd.getDensityFinal(dens)
        state["BORG_final_density"][:] = dens
        self.true_dens = dens.copy()

        # Get velocity field
        vel = self.fwd_vel.getVelocityField()
        self.true_vel = vel.copy()

        # Add some scatter
        myprint('Adding scatter')
        self.obs_dens = self.true_dens + np.random.randn(*self.true_dens.shape) * self.sigma_dens
        self.obs_vel = self.true_vel + np.random.randn(*self.true_vel.shape) * self.sigma_vel

        # Compute the likelihood and print it
        myprint('From mock')
        self.saved_s_hat = s_hat.copy()
        self.logLikelihoodComplex(s_hat, False)
        self.commitAuxiliaryFields(state)
        myprint('Done')

    def dens2like(self, output_density: np.ndarray, output_velocity: np.ndarray) -> float:
        """
        Compute the likelihood from the density and velocity fields.

        Returns 0.5 * sum((field - obs)**2) / sigma**2 summed over density and
        velocity, i.e. the Gaussian negative log-likelihood up to a constant.
        Written with jax.numpy so that jax.grad can differentiate it.

        NOTE(review): unless jax_enable_x64 is enabled (not done in this file),
        jnp evaluates this in float32, which limits the precision of the sum on
        large grids. Confirm this is acceptable for the sampler.

        Args:
            - output_density (np.ndarray): The density field to be used in the likelihood.
            - output_velocity (np.ndarray): The velocity field to be used in the likelihood.

        Returns:
            - float: The likelihood value (a 0-d jax array).
        """
        # This is a simple Gaussian likelihood
        # Could be replaced with any other likelihood
        diff = output_density - self.obs_dens
        diff_vel = output_velocity - self.obs_vel
        like = 0.5 * jnp.sum(diff**2) / (self.sigma_dens**2)
        like += 0.5 * jnp.sum(diff_vel**2) / (self.sigma_vel**2)

        return like

    def logLikelihoodComplex(self, s_hat: np.ndarray, gradientIsNext: bool):
        """
        Run the forward models on ``s_hat`` and evaluate dens2like.

        The resulting fields are cached in ``self.delta`` / ``self.vel`` for
        commitAuxiliaryFields.

        Args:
            - s_hat (np.ndarray): Complex white-noise field.
            - gradientIsNext (bool): Unused here; the gradient reruns the forward model.

        Returns:
            - float: The likelihood value.
        """
        # Get the density field from the forward model
        dens = np.zeros(self.fwd.getOutputBoxModel().N)
        self.fwd.forwardModel_v2(s_hat)
        self.fwd.getDensityFinal(dens)

        # Get the velocity field from the forward model
        vel = self.fwd_vel.getVelocityField()

        # Convert the 0-d jax array to a Python float. This replaces the old
        # isinstance check on the private, version-dependent
        # jaxlib.xla_extension.ArrayImpl type.
        L = float(self.dens2like(dens, vel))
        myprint(f"var(s_hat): {np.var(s_hat)}, Call to logLike: {L}")

        self.delta = dens.copy()
        self.vel = vel.copy()

        return L

    def gradientLikelihoodComplex(self, s_hat: np.ndarray):
        """
        Gradient of the likelihood with respect to ``s_hat``.

        Differentiates dens2like with respect to the density and velocity
        fields using jax. The velocity adjoint is pushed through ``fwd_vel``
        first, then the density adjoint through ``fwd``. The accumulated
        adjoint is read back from ``fwd`` and then cleared.

        Args:
            - s_hat (np.ndarray): Complex white-noise field.

        Returns:
            - np.ndarray: complex128 gradient with the shape of ``s_hat``.
        """
        # Run BORG density field. Use the forward model's output grid, as the
        # other methods do, rather than the ini-file N, so shapes always match.
        output_density = np.zeros(self.fwd.getOutputBoxModel().N)
        self.fwd.forwardModel_v2(s_hat)
        self.fwd.getDensityFinal(output_density)

        # Run BORG velocity field
        vel = self.fwd_vel.getVelocityField()

        # d like / d dens, d like / d vel
        dens_gradient, vel_gradient = self.grad_like(output_density, vel)

        # Now get d like / d s_hat
        dens_gradient = np.array(dens_gradient, dtype=np.float64)
        vel_gradient = np.array(vel_gradient, dtype=np.float64)

        self.fwd_vel.computeAdjointModel(vel_gradient)
        self.fwd.adjointModel_v2(dens_gradient)
        mygrad_hat = np.zeros(s_hat.shape, dtype=np.complex128)
        self.fwd.getAdjointModel(mygrad_hat)
        self.fwd.clearAdjointGradient()

        return mygrad_hat

    def commitAuxiliaryFields(self, state: borg.likelihood.MarkovState) -> None:
        """
        Commits the final density and velocity fields to the Markov state.

        Uses the fields cached by the last logLikelihoodComplex call; the
        velocity is indexed by component (0, 1, 2 -> x, y, z).

        Args:
            - state (borg.likelihood.MarkovState): The state object receiving the fields.
        """
        self.updateCosmology(self.fwd.getCosmoParams())
        # (The previous dens2like call here discarded its result and was removed.)
        state["BORG_final_density"][:] = self.delta.copy()
        state["BORG_final_velocity_x"][:] = self.vel[0].copy()
        state["BORG_final_velocity_y"][:] = self.vel[1].copy()
        state["BORG_final_velocity_z"][:] = self.vel[2].copy()
|
||
|
||
|
||
@borg.registerGravityBuilder
def build_gravity_model(state: borg.likelihood.MarkovState, box: borg.forward.BoxModel, ini_fname=None) -> borg.forward.BaseForwardModel:
    """
    Builds the gravity model and returns the forward model chain.

    Chain: HermiticEnforcer -> M_PRIMORDIAL_AS -> M_TRANSFER_CLASS -> gravity
    solver chosen by ``[gravity] which_model`` ('lpt', '2lpt', 'pm', 'cola').
    The chain and the velocity model chosen by ``[velocity] which_model``
    ('linear', 'cic', 'sic') are also stored in the module globals ``chain``
    and ``fwd_vel``.

    Args:
        - state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
        - box (borg.forward.BoxModel): The input box model.
        - ini_fname (str, default=None): The location of the ini file. If None, use borg.getIniConfigurationFilename()

    Returns:
        borg.forward.BaseForwardModel: The forward model.

    Raises:
        ValueError: if the gravity or velocity model name is not recognised.
    """
    global chain, fwd_vel
    myprint("Building gravity model")

    if ini_fname is None:
        ini_fname = borg.getIniConfigurationFilename()
    config = configparser.ConfigParser()
    config.read(ini_fname)

    # READ FROM INI FILE
    gravity_cfg = config['gravity']
    which_model = gravity_cfg['which_model']
    ai = float(gravity_cfg['ai'])  # Initial scale factor
    af = float(gravity_cfg['af'])  # Final scale factor
    supersampling = int(gravity_cfg['supersampling'])
    forcesampling = int(gravity_cfg['forcesampling'])
    nsteps = int(gravity_cfg['nsteps'])  # Number of steps in the PM solver

    chain = borg.forward.ChainForwardModel(box)

    # Make sure that the initial conditions are real in position space
    chain.addModel(borg.forward.models.HermiticEnforcer(box))

    # CLASS transfer function
    chain @= borg.forward.model_lib.M_PRIMORDIAL_AS(box)
    transfer_class = borg.forward.model_lib.M_TRANSFER_CLASS(box, opts=dict(a_transfer=1.0))
    # Setting YHe explicitly helps avoid CLASS primordial-physics errors for unusual cosmologies
    transfer_class.setModelParams({"extra_class_arguments": {'YHe': '0.24'}})
    chain @= transfer_class

    # Gravity model: options shared by every solver
    base_opts = dict(a_initial=af,
                     a_final=af,
                     do_rsd=False,
                     supersampling=supersampling,
                     lightcone=False,
                     part_factor=1.01)
    if which_model == 'lpt':
        mod = borg.forward.model_lib.M_LPT_CIC(box, opts=base_opts)
    elif which_model == '2lpt':
        mod = borg.forward.model_lib.M_2LPT_CIC(box, opts=base_opts)
    elif which_model in ('pm', 'cola'):
        # PM and COLA use the same solver; tcola switches COLA on.
        pm_opts = dict(base_opts,
                       pm_start_z=1 / ai - 1,
                       forcesampling=forcesampling,
                       pm_nsteps=nsteps,  # Number of steps in the PM solver
                       tcola=(which_model == 'cola'))
        mod = borg.forward.model_lib.M_PM_CIC(box, opts=pm_opts)
    else:
        raise ValueError(f"Unknown model {which_model}")

    mod.accumulateAdjoint(True)
    chain @= mod

    # Cosmological parameters: read from the same ini file as everything else
    # (previously this ignored ini_fname and always used the global ini).
    cpar = get_cosmopar(ini_fname)
    myprint(f'Setting cosmo params {cpar}')
    chain.setCosmoParams(cpar)

    # Set the forward model for velocities
    vel_model = config['velocity']['which_model']
    if vel_model == 'linear':
        fwd_vel = borg.forward.velocity.LinearModel(box, mod, af)
    elif vel_model == 'cic':
        rsmooth = float(config['velocity']['rsmooth'])
        fwd_vel = borg.forward.velocity.CICModel(box, mod, rsmooth)
    elif vel_model == 'sic':
        fwd_vel = borg.forward.velocity.SICModel(box, mod)
    else:
        raise ValueError(f"Unknown model {vel_model}")

    return chain
|
||
|
||
@borg.registerSamplerBuilder
def build_sampler(state: borg.likelihood.MarkovState, info: borg.likelihood.LikelihoodInfo, loop: borg.samplers.MainLoop):
    """
    Register additional samplers with the main loop.

    Only the initial conditions are sampled at the moment, so no extra
    samplers are created and an empty list is returned. Cosmology or
    model-parameter samplers would be added here.

    Args:
        - state (borg.likelihood.MarkovState): The Markov state object.
        - info (borg.likelihood.LikelihoodInfo): The likelihood info object.
        - loop (borg.samplers.MainLoop): The main sampling loop.

    Returns:
        list: The additional samplers (currently always empty).
    """
    extra_samplers = []
    return extra_samplers
|
||
|
||
|
||
@borg.registerLikelihoodBuilder
def build_likelihood(state: borg.likelihood.MarkovState, info: borg.likelihood.LikelihoodInfo) -> borg.likelihood.BaseLikelihood:
    """
    Create the likelihood object and return it.

    Uses the module globals ``chain`` and ``fwd_vel`` (set by
    build_gravity_model) together with the global BORG ini file, and stores
    the new instance in the module global ``likelihood``.

    Args:
        - state (borg.likelihood.MarkovState): The Markov state object.
        - info (borg.likelihood.LikelihoodInfo): The likelihood info object.

    Returns:
        borg.likelihood.BaseLikelihood: The likelihood object.
    """
    global likelihood
    myprint("Building likelihood")
    likelihood = MyLikelihood(chain, fwd_vel, borg.getIniConfigurationFilename())
    return likelihood
|