Add hmc model sampler and fix initial As set to zero
4
.gitignore
vendored
|
@ -164,3 +164,7 @@ cython_debug/
|
||||||
*timing_stats_0.txt
|
*timing_stats_0.txt
|
||||||
*fft_wisdom
|
*fft_wisdom
|
||||||
tests/*.h5
|
tests/*.h5
|
||||||
|
*.o*
|
||||||
|
*timing_stats*
|
||||||
|
*fft_wisdom*
|
||||||
|
*allocation_stats*
|
||||||
|
|
|
@ -35,7 +35,6 @@ class NullForward(borg.forward.BaseForwardModel):
|
||||||
"""
|
"""
|
||||||
for k, v in params.items():
|
for k, v in params.items():
|
||||||
self.params[k] = v
|
self.params[k] = v
|
||||||
print(" ")
|
|
||||||
myprint(f'Updated model parameters: {self.params}')
|
myprint(f'Updated model parameters: {self.params}')
|
||||||
|
|
||||||
def getModelParam(self, model, keyname: str):
|
def getModelParam(self, model, keyname: str):
|
||||||
|
|
|
@ -5,14 +5,18 @@ import warnings
|
||||||
import aquila_borg as borg
|
import aquila_borg as borg
|
||||||
import symbolic_pofk.linear
|
import symbolic_pofk.linear
|
||||||
import jax
|
import jax
|
||||||
|
import jaxlib
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import ast
|
import ast
|
||||||
|
import numbers
|
||||||
|
|
||||||
import borg_velocity.utils as utils
|
import borg_velocity.utils as utils
|
||||||
from borg_velocity.utils import myprint
|
from borg_velocity.utils import myprint
|
||||||
import borg_velocity.forwards as forwards
|
import borg_velocity.forwards as forwards
|
||||||
import borg_velocity.mock_maker as mock_maker
|
import borg_velocity.mock_maker as mock_maker
|
||||||
import borg_velocity.projection as projection
|
import borg_velocity.projection as projection
|
||||||
|
from borg_velocity.samplers import HMCBiasSampler, derive_prior
|
||||||
|
import borg_velocity.samplers as samplers
|
||||||
|
|
||||||
class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
|
class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
|
||||||
"""
|
"""
|
||||||
|
@ -84,13 +88,13 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
|
||||||
self.fwd_vel = fwd_vel
|
self.fwd_vel = fwd_vel
|
||||||
|
|
||||||
# Initialise model parameters
|
# Initialise model parameters
|
||||||
model_params = {
|
self.model_params = {
|
||||||
**{f'mua{i}':self.muA[i] for i in range(self.nsamp)},
|
**{f'mua{i}':self.muA[i] for i in range(self.nsamp)},
|
||||||
**{f'alpha{i}':self.alpha[i] for i in range(self.nsamp)},
|
**{f'alpha{i}':self.alpha[i] for i in range(self.nsamp)},
|
||||||
**{f'bulk_flow_{d}':self.bulk_flow[i] for i, d in enumerate(['x', 'y', 'z'])},
|
**{f'bulk_flow_{d}':self.bulk_flow[i] for i, d in enumerate(['x', 'y', 'z'])},
|
||||||
'sig_v':self.sig_v
|
'sig_v':self.sig_v
|
||||||
}
|
}
|
||||||
self.fwd_param.setModelParams(model_params)
|
self.fwd_param.setModelParams(self.model_params)
|
||||||
|
|
||||||
# Initialise cosmological parameters
|
# Initialise cosmological parameters
|
||||||
cpar = utils.get_cosmopar(self.ini_file)
|
cpar = utils.get_cosmopar(self.ini_file)
|
||||||
|
@ -159,11 +163,6 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
|
||||||
self.fwd.setCosmoParams(cpar)
|
self.fwd.setCosmoParams(cpar)
|
||||||
self.fwd_param.setCosmoParams(cpar)
|
self.fwd_param.setCosmoParams(cpar)
|
||||||
|
|
||||||
# # Compute growth rate
|
|
||||||
# cosmology = borg.cosmo.Cosmology(cosmo)
|
|
||||||
# f = cosmology.gplus(self.af) # dD / da
|
|
||||||
# f *= self.af / cosmology.d_plus(self.af) # f = dlnD / dlna
|
|
||||||
# self.f = f
|
|
||||||
|
|
||||||
def generateMBData(self) -> None:
|
def generateMBData(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -214,6 +213,7 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
|
||||||
- state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
|
- state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
|
||||||
- make_plot (bool, default=True): Whether to make diagnostic plots for the mock data generation
|
- make_plot (bool, default=True): Whether to make diagnostic plots for the mock data generation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.run_type == 'data':
|
if self.run_type == 'data':
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
elif self.run_type == 'velmass':
|
elif self.run_type == 'velmass':
|
||||||
|
@ -226,7 +226,9 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
|
||||||
|
|
||||||
self.r_hMpc = [np.sqrt(np.sum(self.coord_meas[i] ** 2, axis=0)) for i in range(self.nsamp)]
|
self.r_hMpc = [np.sqrt(np.sum(self.coord_meas[i] ** 2, axis=0)) for i in range(self.nsamp)]
|
||||||
self.generateMBData()
|
self.generateMBData()
|
||||||
|
myprint('From mock')
|
||||||
|
self.saved_s_hat = s_hat.copy()
|
||||||
|
self.logLikelihoodComplex(s_hat, False)
|
||||||
|
|
||||||
def dens2like(self, output_density: np.ndarray, output_velocity: np.ndarray):
|
def dens2like(self, output_density: np.ndarray, output_velocity: np.ndarray):
|
||||||
"""
|
"""
|
||||||
|
@ -240,20 +242,20 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
lkl = 0
|
lkl = 0
|
||||||
sig_v = self.fwd_param.getModelParam('nullforward', 'sig_v')
|
|
||||||
|
sig_v = self.model_params['sig_v']
|
||||||
|
|
||||||
# Compute velocity field
|
# Compute velocity field
|
||||||
bulk_flow = jnp.array([self.fwd_param.getModelParam('nullforward', 'bulk_flow_x'),
|
bulk_flow = jnp.array([self.model_params['bulk_flow_x'],
|
||||||
self.fwd_param.getModelParam('nullforward', 'bulk_flow_y'),
|
self.model_params['bulk_flow_y'],
|
||||||
self.fwd_param.getModelParam('nullforward', 'bulk_flow_z')])
|
self.model_params['bulk_flow_z']])
|
||||||
v = output_velocity + self.bulk_flow.reshape((3, 1, 1, 1))
|
v = output_velocity + self.bulk_flow.reshape((3, 1, 1, 1))
|
||||||
|
|
||||||
omega_m = self.fwd.getCosmoParams().omega_m
|
omega_m = self.fwd.getCosmoParams().omega_m
|
||||||
|
|
||||||
for i in range(self.nsamp):
|
for i in range(self.nsamp):
|
||||||
print(f'\nSample {i}')
|
muA = self.model_params[f'mua{i}']
|
||||||
muA = self.fwd_param.getModelParam('nullforward', f'mua{i}')
|
alpha = self.model_params[f'alpha{i}']
|
||||||
alpha = self.fwd_param.getModelParam('nullforward', f'alpha{i}')
|
|
||||||
lkl += vel2like(
|
lkl += vel2like(
|
||||||
self.cz_obs[i],
|
self.cz_obs[i],
|
||||||
v,
|
v,
|
||||||
|
@ -277,13 +279,16 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
|
||||||
|
|
||||||
return lkl
|
return lkl
|
||||||
|
|
||||||
def logLikelihoodComplex(self, s_hat: np.ndarray, gradientIsNext: bool):
|
def logLikelihoodComplex(self, s_hat: np.ndarray, gradientIsNext: bool, skip_density: bool=False, update_from_model: bool=True):
|
||||||
"""
|
"""
|
||||||
Calculates the negative log-likelihood of the data.
|
Calculates the negative log-likelihood of the data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- s_hat (np.ndarray): The input white noise.
|
- s_hat (np.ndarray): The input white noise.
|
||||||
- gradientIsNext (bool): If True, prepares the forward model for gradient calculations.
|
- gradientIsNext (bool): If True, prepares the forward model for gradient calculations.
|
||||||
|
- skip_density (bool, default=False): If True, do not repeat the s_hat -> density, velocity computation
|
||||||
|
and use the stored result
|
||||||
|
- update_from_model (bool, default=True): If True, update self.model_params with self.fwd_param.getModelParam
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The negative log-likelihood value.
|
The negative log-likelihood value.
|
||||||
|
@ -292,20 +297,32 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
|
||||||
|
|
||||||
N = self.fwd.getBoxModel().N[0]
|
N = self.fwd.getBoxModel().N[0]
|
||||||
L = self.fwd.getOutputBoxModel().L[0]
|
L = self.fwd.getOutputBoxModel().L[0]
|
||||||
|
|
||||||
# Run BORG density field
|
|
||||||
output_density = np.zeros((N,N,N))
|
|
||||||
self.fwd.forwardModel_v2(s_hat)
|
|
||||||
self.fwd.getDensityFinal(output_density)
|
|
||||||
|
|
||||||
# Get velocity field
|
if update_from_model:
|
||||||
output_velocity = self.fwd_vel.getVelocityField()
|
for k in self.model_params.keys():
|
||||||
|
self.model_params[k] = self.fwd_param.getModelParam('nullforward', k)
|
||||||
|
self.updateCosmology(self.fwd.getCosmoParams()) # for sigma8 -> As
|
||||||
|
|
||||||
self.delta = output_density
|
if not skip_density:
|
||||||
self.vel = output_velocity
|
# Run BORG density field
|
||||||
|
output_density = np.zeros((N,N,N))
|
||||||
|
self.fwd.forwardModel_v2(s_hat)
|
||||||
|
self.fwd.getDensityFinal(output_density)
|
||||||
|
|
||||||
|
# Get velocity field
|
||||||
|
output_velocity = self.fwd_vel.getVelocityField()
|
||||||
|
else:
|
||||||
|
output_density = self.delta.copy()
|
||||||
|
output_velocity = self.vel.copy()
|
||||||
|
|
||||||
L = self.dens2like(output_density, output_velocity)
|
L = self.dens2like(output_density, output_velocity)
|
||||||
myprint(f"var(s_hat): {np.var(s_hat)}, Call to logLike: {L}")
|
if isinstance(L, numbers.Number) or isinstance(L, jaxlib.xla_extension.ArrayImpl):
|
||||||
|
myprint(f"var(s_hat): {np.var(s_hat)}, Call to logLike: {L}")
|
||||||
|
myprint(self.model_params)
|
||||||
|
myprint(self.fwd.getCosmoParams())
|
||||||
|
|
||||||
|
self.delta = output_density.copy()
|
||||||
|
self.vel = output_velocity.copy()
|
||||||
|
|
||||||
return L
|
return L
|
||||||
|
|
||||||
|
@ -364,10 +381,6 @@ def vel2like(cz_obs, v, MB_field, MB_pos, r, r_hMpc, sig_mu, sig_v, omega_m, muA
|
||||||
"""
|
"""
|
||||||
Jitted part of dens2like
|
Jitted part of dens2like
|
||||||
"""
|
"""
|
||||||
print('Dens', MB_field.min(), MB_field.max())
|
|
||||||
print('vel', v.min(), v.max())
|
|
||||||
|
|
||||||
quit()
|
|
||||||
|
|
||||||
tracer_vel = projection.interp_field(v,
|
tracer_vel = projection.interp_field(v,
|
||||||
MB_pos,
|
MB_pos,
|
||||||
|
@ -421,6 +434,43 @@ def vel2like(cz_obs, v, MB_field, MB_pos, r, r_hMpc, sig_mu, sig_v, omega_m, muA
|
||||||
lkl = - lkl_ind.sum()
|
lkl = - lkl_ind.sum()
|
||||||
|
|
||||||
return lkl
|
return lkl
|
||||||
|
|
||||||
|
|
||||||
|
@derive_prior
|
||||||
|
def transform_mua(x):
|
||||||
|
a, b = model_params_prior['mua']
|
||||||
|
return samplers.transform_uniform(x, a, b)
|
||||||
|
|
||||||
|
def inv_transform_mua(alpha):
|
||||||
|
a, b = model_params_prior['mua']
|
||||||
|
return samplers.inv_transform_uniform(alpha, a, b)
|
||||||
|
|
||||||
|
@derive_prior
|
||||||
|
def transform_alpha(x):
|
||||||
|
a, b = model_params_prior['alpha']
|
||||||
|
return samplers.transform_uniform(x, a, b)
|
||||||
|
|
||||||
|
def inv_transform_alpha(alpha):
|
||||||
|
a, b = model_params_prior['alpha']
|
||||||
|
return samplers.inv_transform_uniform(alpha, a, b)
|
||||||
|
|
||||||
|
@derive_prior
|
||||||
|
def transform_sig_v(x):
|
||||||
|
a, b = model_params_prior['sig_v']
|
||||||
|
return samplers.transform_uniform(x, a, b)
|
||||||
|
|
||||||
|
def inv_transform_sig_v(alpha):
|
||||||
|
a, b = model_params_prior['sig_v']
|
||||||
|
return samplers.inv_transform_uniform(alpha, a, b)
|
||||||
|
|
||||||
|
@derive_prior
|
||||||
|
def transform_bulk_flow(x):
|
||||||
|
a, b = model_params_prior['bulk_flow']
|
||||||
|
return samplers.transform_uniform(x, a, b)
|
||||||
|
|
||||||
|
def inv_transform_bulk_flow(alpha):
|
||||||
|
a, b = model_params_prior['bulk_flow']
|
||||||
|
return samplers.inv_transform_uniform(alpha, a, b)
|
||||||
|
|
||||||
|
|
||||||
@borg.registerGravityBuilder
|
@borg.registerGravityBuilder
|
||||||
|
@ -451,6 +501,10 @@ def build_gravity_model(state: borg.likelihood.MarkovState, box: borg.forward.Bo
|
||||||
config.read(ini_file)
|
config.read(ini_file)
|
||||||
ai = float(config['model']['ai'])
|
ai = float(config['model']['ai'])
|
||||||
af = float(config['model']['af'])
|
af = float(config['model']['af'])
|
||||||
|
supersampling = int(config['model']['supersampling'])
|
||||||
|
|
||||||
|
if config['model']['gravity'] in ['pm', 'cola']:
|
||||||
|
forcesampling = int(config['model']['forcesampling'])
|
||||||
|
|
||||||
# Setup forward model
|
# Setup forward model
|
||||||
chain = borg.forward.ChainForwardModel(box)
|
chain = borg.forward.ChainForwardModel(box)
|
||||||
|
@ -470,7 +524,7 @@ def build_gravity_model(state: borg.likelihood.MarkovState, box: borg.forward.Bo
|
||||||
opts=dict(a_initial=af,
|
opts=dict(a_initial=af,
|
||||||
a_final=af,
|
a_final=af,
|
||||||
do_rsd=False,
|
do_rsd=False,
|
||||||
supersampling=1,
|
supersampling=supersampling,
|
||||||
lightcone=False,
|
lightcone=False,
|
||||||
part_factor=1.01,))
|
part_factor=1.01,))
|
||||||
elif config['model']['gravity'] == '2lpt':
|
elif config['model']['gravity'] == '2lpt':
|
||||||
|
@ -479,7 +533,7 @@ def build_gravity_model(state: borg.likelihood.MarkovState, box: borg.forward.Bo
|
||||||
opts=dict(a_initial=af,
|
opts=dict(a_initial=af,
|
||||||
a_final=af,
|
a_final=af,
|
||||||
do_rsd=False,
|
do_rsd=False,
|
||||||
supersampling=1,
|
supersampling=supersampling,
|
||||||
lightcone=False,
|
lightcone=False,
|
||||||
part_factor=1.01,))
|
part_factor=1.01,))
|
||||||
elif config['model']['gravity'] == 'pm':
|
elif config['model']['gravity'] == 'pm':
|
||||||
|
@ -488,9 +542,9 @@ def build_gravity_model(state: borg.likelihood.MarkovState, box: borg.forward.Bo
|
||||||
opts=dict(a_initial=af,
|
opts=dict(a_initial=af,
|
||||||
a_final=af,
|
a_final=af,
|
||||||
do_rsd=False,
|
do_rsd=False,
|
||||||
supersampling=1,
|
supersampling=supersampling,
|
||||||
part_factor=1.01,
|
part_factor=1.01,
|
||||||
forcesampling=2,
|
forcesampling=forcesampling,
|
||||||
pm_start_z=1/ai - 1,
|
pm_start_z=1/ai - 1,
|
||||||
pm_nsteps=int(config['model']['nsteps']),
|
pm_nsteps=int(config['model']['nsteps']),
|
||||||
tcola=False))
|
tcola=False))
|
||||||
|
@ -500,9 +554,9 @@ def build_gravity_model(state: borg.likelihood.MarkovState, box: borg.forward.Bo
|
||||||
opts=dict(a_initial=af,
|
opts=dict(a_initial=af,
|
||||||
a_final=af,
|
a_final=af,
|
||||||
do_rsd=False,
|
do_rsd=False,
|
||||||
supersampling=1,
|
supersampling=supersampling,
|
||||||
part_factor=1.01,
|
part_factor=1.01,
|
||||||
forcesampling=2,
|
forcesampling=forcesampling,
|
||||||
pm_start_z=1/ai - 1,
|
pm_start_z=1/ai - 1,
|
||||||
pm_nsteps=int(config['model']['nsteps']),
|
pm_nsteps=int(config['model']['nsteps']),
|
||||||
tcola=True))
|
tcola=True))
|
||||||
|
@ -531,8 +585,7 @@ def build_gravity_model(state: borg.likelihood.MarkovState, box: borg.forward.Bo
|
||||||
if velmodel_name == 'LinearModel':
|
if velmodel_name == 'LinearModel':
|
||||||
fwd_vel = velmodel(box, mod, af)
|
fwd_vel = velmodel(box, mod, af)
|
||||||
elif velmodel_name == 'CICModel':
|
elif velmodel_name == 'CICModel':
|
||||||
rsmooth = float(config['model']['rsmooth'])
|
rsmooth = float(config['model']['rsmooth']) # Mpc/h
|
||||||
print("I AM USING RSMOOTH", rsmooth)
|
|
||||||
fwd_vel = velmodel(box, mod, rsmooth)
|
fwd_vel = velmodel(box, mod, rsmooth)
|
||||||
else:
|
else:
|
||||||
fwd_vel = velmodel(box, mod)
|
fwd_vel = velmodel(box, mod)
|
||||||
|
@ -543,6 +596,7 @@ _glob_model = None
|
||||||
_glob_cosmo = None
|
_glob_cosmo = None
|
||||||
begin_model = None
|
begin_model = None
|
||||||
begin_cosmo = None
|
begin_cosmo = None
|
||||||
|
model_params_prior = {}
|
||||||
|
|
||||||
def check_model_sampling(loop):
|
def check_model_sampling(loop):
|
||||||
return loop.getStepID() > begin_model
|
return loop.getStepID() > begin_model
|
||||||
|
@ -571,7 +625,7 @@ def build_sampler(
|
||||||
List of samplers to use.
|
List of samplers to use.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
global _glob_model, _glob_cosmo, begin_model, begin_cosmo
|
global _glob_model, _glob_cosmo, begin_model, begin_cosmo, model_params_prior
|
||||||
borg.print_msg(borg.Level.std, "Hello sampler, loop is {l}, step_id={s}", l=loop, s=loop.getStepID())
|
borg.print_msg(borg.Level.std, "Hello sampler, loop is {l}, step_id={s}", l=loop, s=loop.getStepID())
|
||||||
|
|
||||||
myprint("Building sampler")
|
myprint("Building sampler")
|
||||||
|
@ -606,7 +660,6 @@ def build_sampler(
|
||||||
to_sample.remove(p)
|
to_sample.remove(p)
|
||||||
|
|
||||||
begin_cosmo = int(config['mcmc']['warmup_cosmo'])
|
begin_cosmo = int(config['mcmc']['warmup_cosmo'])
|
||||||
|
|
||||||
|
|
||||||
if len(params) > 0:
|
if len(params) > 0:
|
||||||
myprint('Adding cosmological parameter sampler')
|
myprint('Adding cosmological parameter sampler')
|
||||||
|
@ -623,6 +676,9 @@ def build_sampler(
|
||||||
params = []
|
params = []
|
||||||
initial_values = {}
|
initial_values = {}
|
||||||
prior = {}
|
prior = {}
|
||||||
|
transform_attributes = []
|
||||||
|
inv_transform_attributes = []
|
||||||
|
|
||||||
for p in to_sample:
|
for p in to_sample:
|
||||||
if p in config['prior'].keys():
|
if p in config['prior'].keys():
|
||||||
if p == 'sig_v':
|
if p == 'sig_v':
|
||||||
|
@ -634,6 +690,9 @@ def build_sampler(
|
||||||
prior[p] = np.array([xx if xx != 'inf' else np.inf for xx in x])
|
prior[p] = np.array([xx if xx != 'inf' else np.inf for xx in x])
|
||||||
else:
|
else:
|
||||||
prior[p] = np.array(ast.literal_eval(config['prior'][p]))
|
prior[p] = np.array(ast.literal_eval(config['prior'][p]))
|
||||||
|
model_params_prior[p] = prior[p]
|
||||||
|
transform_attributes.append(globals().get(f'transform_{p}'))
|
||||||
|
inv_transform_attributes.append(globals().get(f'inv_transform_{p}'))
|
||||||
elif p == 'bulk_flow':
|
elif p == 'bulk_flow':
|
||||||
for i, d in enumerate(['_x', '_y', '_z']):
|
for i, d in enumerate(['_x', '_y', '_z']):
|
||||||
myprint(f'Adding {p}{d} sampler')
|
myprint(f'Adding {p}{d} sampler')
|
||||||
|
@ -644,6 +703,9 @@ def build_sampler(
|
||||||
prior[f'{p}{d}'] = np.array([xx if xx != 'inf' else np.inf for xx in x])
|
prior[f'{p}{d}'] = np.array([xx if xx != 'inf' else np.inf for xx in x])
|
||||||
else:
|
else:
|
||||||
prior[f'{p}{d}'] = np.array(ast.literal_eval(config['prior'][p]))
|
prior[f'{p}{d}'] = np.array(ast.literal_eval(config['prior'][p]))
|
||||||
|
transform_attributes.append(globals().get(f'transform_{p}'))
|
||||||
|
inv_transform_attributes.append(globals().get(f'inv_transform_{p}'))
|
||||||
|
model_params_prior[p] = prior[f'{p}_x']
|
||||||
else:
|
else:
|
||||||
for i in range(nsamp):
|
for i in range(nsamp):
|
||||||
myprint(f'Adding {p}{i} sampler')
|
myprint(f'Adding {p}{i} sampler')
|
||||||
|
@ -654,24 +716,39 @@ def build_sampler(
|
||||||
prior[f'{p}{i}'] = np.array([xx if xx != 'inf' else np.inf for xx in x])
|
prior[f'{p}{i}'] = np.array([xx if xx != 'inf' else np.inf for xx in x])
|
||||||
else:
|
else:
|
||||||
prior[f'{p}{i}'] = np.array(ast.literal_eval(config['prior'][p]))
|
prior[f'{p}{i}'] = np.array(ast.literal_eval(config['prior'][p]))
|
||||||
|
transform_attributes.append(globals().get(f'transform_{p}'))
|
||||||
|
inv_transform_attributes.append(globals().get(f'inv_transform_{p}'))
|
||||||
|
model_params_prior[p] = prior[f'{p}{0}']
|
||||||
else:
|
else:
|
||||||
s = f'Could not find {p} prior, so will not sample'
|
s = f'Could not find {p} prior, so will not sample'
|
||||||
warnings.warn(s, stacklevel=2)
|
warnings.warn(s, stacklevel=2)
|
||||||
|
|
||||||
begin_model = int(config['mcmc']['warmup_model'])
|
begin_model = int(config['mcmc']['warmup_model'])
|
||||||
|
|
||||||
|
|
||||||
if len(params) > 0:
|
if len(params) > 0:
|
||||||
myprint('Adding model parameter sampler')
|
myprint('Adding model parameter sampler')
|
||||||
model_sampler = borg.samplers.ModelParamsSampler(prefix, params, likelihood, fwd_param, initial_values, prior)
|
if config['sampling']['algorithm'].lower() == 'slice':
|
||||||
|
model_sampler = borg.samplers.ModelParamsSampler(prefix, params, likelihood, fwd_param, initial_values, prior)
|
||||||
|
elif config['sampling']['algorithm'].lower() == 'hmc':
|
||||||
|
model_sampler = HMCBiasSampler(
|
||||||
|
prefix,
|
||||||
|
likelihood,
|
||||||
|
params,
|
||||||
|
transform_attributes=transform_attributes,
|
||||||
|
inv_transform_attributes=inv_transform_attributes,
|
||||||
|
prior = [p.prior for p in transform_attributes],
|
||||||
|
Nsteps = int(config['sampling']['nsteps']),
|
||||||
|
epsilon = float(config['sampling']['epsilon']),
|
||||||
|
refresh = float(config['sampling']['refresh']),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
model_sampler.setName("model_sampler")
|
model_sampler.setName("model_sampler")
|
||||||
_glob_model = model_sampler
|
_glob_model = model_sampler
|
||||||
loop.push(model_sampler)
|
loop.push(model_sampler)
|
||||||
all_sampler.append(model_sampler)
|
all_sampler.append(model_sampler)
|
||||||
loop.addToConditionGroup("warmup_model", "model_sampler")
|
loop.addToConditionGroup("warmup_model", "model_sampler")
|
||||||
loop.addConditionToConditionGroup("warmup_model", partial(check_model_sampling, loop))
|
loop.addConditionToConditionGroup("warmup_model", partial(check_model_sampling, loop))
|
||||||
|
|
||||||
print('Warmups:', begin_cosmo, begin_model)
|
|
||||||
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,6 @@ import configparser
|
||||||
import ast
|
import ast
|
||||||
|
|
||||||
import borg_velocity.utils as utils
|
import borg_velocity.utils as utils
|
||||||
from borg_velocity.utils import myprint
|
|
||||||
import borg_velocity.forwards as forwards
|
import borg_velocity.forwards as forwards
|
||||||
import borg_velocity.poisson_process as poisson_process
|
import borg_velocity.poisson_process as poisson_process
|
||||||
import borg_velocity.projection as projection
|
import borg_velocity.projection as projection
|
||||||
|
@ -62,16 +61,6 @@ def borg_mock(s_hat, state, fwd_model, fwd_vel, ini_file, seed=None):
|
||||||
cosmo = utils.get_cosmopar(ini_file)
|
cosmo = utils.get_cosmopar(ini_file)
|
||||||
output_vel = fwd_vel.getVelocityField()
|
output_vel = fwd_vel.getVelocityField()
|
||||||
|
|
||||||
nzero = np.sum(output_density == -1)
|
|
||||||
ntot = np.prod(output_density.shape)
|
|
||||||
print(f'Number of zeros: {nzero} of {ntot} - {nzero/ntot * 100} percent')
|
|
||||||
|
|
||||||
print("\nMy velocity field", output_vel.min(), output_vel.max(), '\n')
|
|
||||||
nnan = np.sum(np.isnan(output_vel))
|
|
||||||
ntot = np.prod(output_vel.shape)
|
|
||||||
print(f'Number of nans: {nnan} of {ntot} - {nnan/ntot * 100} percent')
|
|
||||||
quit()
|
|
||||||
|
|
||||||
# Add bulk flow
|
# Add bulk flow
|
||||||
bulk_flow = np.array(ast.literal_eval(config['model']['bulk_flow']))
|
bulk_flow = np.array(ast.literal_eval(config['model']['bulk_flow']))
|
||||||
output_vel = output_vel + bulk_flow.reshape((3, 1, 1, 1))
|
output_vel = output_vel + bulk_flow.reshape((3, 1, 1, 1))
|
||||||
|
|
206
borg_velocity/samplers.py
Normal file
|
@ -0,0 +1,206 @@
|
||||||
|
import numpy as np
|
||||||
|
import aquila_borg as borg
|
||||||
|
from typing import List
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jax
|
||||||
|
from borg_velocity.utils import myprint
|
||||||
|
|
||||||
|
|
||||||
|
class HMCBiasSampler(borg.samplers.PyBaseSampler):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prefix: str,
|
||||||
|
likelihood: borg.likelihood.BaseLikelihood,
|
||||||
|
attributes: List[str],
|
||||||
|
transform_attributes: List[object],
|
||||||
|
inv_transform_attributes: List[object],
|
||||||
|
prior: List[object],
|
||||||
|
algo="DKD",
|
||||||
|
refresh: float=0.1,
|
||||||
|
Nsteps: int=50,
|
||||||
|
epsilon: float=0.001,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.likelihood = likelihood
|
||||||
|
self.attributes = attributes
|
||||||
|
self.prefix = prefix
|
||||||
|
self.transform_attributes = transform_attributes
|
||||||
|
self.inv_transform_attributes = inv_transform_attributes
|
||||||
|
self.prior_attributes = prior
|
||||||
|
|
||||||
|
self.refresh = refresh
|
||||||
|
self.p_old = None
|
||||||
|
self.Nsteps = Nsteps
|
||||||
|
self.epsilon = epsilon
|
||||||
|
|
||||||
|
# Drift-Kick-Drift
|
||||||
|
if algo == "DKD":
|
||||||
|
self.coefs = [[0, 0.5], [1, 0.5]]
|
||||||
|
# Kick-Drift-Kick
|
||||||
|
elif algo == "KDK":
|
||||||
|
self.coefs = [[0.5, 1], [0.5, 0]]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unknown algo {algo=}")
|
||||||
|
|
||||||
|
def initialize(self, state: borg.likelihood.MarkovState):
|
||||||
|
self.restore(state)
|
||||||
|
for i in range(len(self.attributes)):
|
||||||
|
attr = self.likelihood.fwd_param.getModelParam(
|
||||||
|
'mod_null',
|
||||||
|
self.attributes[i]
|
||||||
|
)
|
||||||
|
self.y_real[i] = attr
|
||||||
|
if isinstance(attr, float):
|
||||||
|
attr = jnp.array(attr)
|
||||||
|
if self.inv_transform_attributes[i]:
|
||||||
|
self.y[i] = self.inv_transform_attributes[i](attr).item()
|
||||||
|
else:
|
||||||
|
self.y[i] = attr.item()
|
||||||
|
|
||||||
|
myprint(f"initial {self.y=}")
|
||||||
|
|
||||||
|
def _update_attrs(self, x):
|
||||||
|
for i in range(len(self.attributes)):
|
||||||
|
xnew = x[i]
|
||||||
|
if self.transform_attributes[i]:
|
||||||
|
xnew = self.transform_attributes[i](xnew)
|
||||||
|
self.likelihood.model_params[self.attributes[i]] = xnew
|
||||||
|
|
||||||
|
def _likelihood(self, x_hat, x, skip_density=False):
|
||||||
|
self._update_attrs(x)
|
||||||
|
return self.likelihood.logLikelihoodComplex(
|
||||||
|
x_hat, False, skip_density=skip_density, update_from_model=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _grad_likelihood(self, x_hat, x, skip_density=False):
|
||||||
|
grad_fn = jax.grad(self._likelihood, argnums=1)
|
||||||
|
grad_x = grad_fn(x_hat, x, skip_density=skip_density)
|
||||||
|
return grad_x
|
||||||
|
|
||||||
|
def _prior(self, x):
|
||||||
|
prior_value = 0.0
|
||||||
|
for i in range(len(self.attributes)):
|
||||||
|
if self.prior_attributes[i]:
|
||||||
|
prior_value += self.prior_attributes[i](x[i])
|
||||||
|
return prior_value
|
||||||
|
|
||||||
|
def _grad_prior(self, x):
|
||||||
|
grad_x = jax.grad(self._prior, argnums=0)(x)
|
||||||
|
return grad_x
|
||||||
|
|
||||||
|
|
||||||
|
def sample(self, state: borg.likelihood.MarkovState):
|
||||||
|
myprint(f"Sampling attributes {self.attributes}")
|
||||||
|
x_hat = state["s_hat_field"]
|
||||||
|
|
||||||
|
mass = 1.0
|
||||||
|
inv_mass = 1 / mass
|
||||||
|
|
||||||
|
params = jnp.array(self.y)
|
||||||
|
|
||||||
|
# Initialise the PRNG key
|
||||||
|
key = jax.random.PRNGKey(np.random.randint(0, 1e6))
|
||||||
|
|
||||||
|
# Split the key for independent random number generation
|
||||||
|
key, subkey = jax.random.split(key)
|
||||||
|
p_new = jax.random.normal(key, shape=params.shape) * jnp.sqrt(mass)
|
||||||
|
|
||||||
|
if self.p_old is not None:
|
||||||
|
key, subkey = jax.random.split(key) # Split the key again
|
||||||
|
p = self.p_old * self.refresh + jnp.sqrt(1 - self.refresh**2) * p_new
|
||||||
|
else:
|
||||||
|
p = p_new
|
||||||
|
|
||||||
|
Lstart = (
|
||||||
|
self._likelihood(x_hat, params, skip_density=False)
|
||||||
|
+ self._prior(params)
|
||||||
|
+ 0.5 * jnp.sum(p**2) * inv_mass
|
||||||
|
)
|
||||||
|
|
||||||
|
new_gradient = True
|
||||||
|
myprint('running model parameter hmc ...')
|
||||||
|
for i in range(self.Nsteps):
|
||||||
|
for c in self.coefs:
|
||||||
|
if c[0] != 0:
|
||||||
|
if new_gradient:
|
||||||
|
gradient = self._grad_likelihood(
|
||||||
|
x_hat, params, skip_density=True
|
||||||
|
) + self._grad_prior(params)
|
||||||
|
new_gradient = False
|
||||||
|
p -= self.epsilon * c[0] * gradient
|
||||||
|
if c[1] != 0:
|
||||||
|
params += self.epsilon * c[1] * p * inv_mass
|
||||||
|
new_gradient = True
|
||||||
|
|
||||||
|
Lend = self._likelihood(x_hat, params, skip_density=True)
|
||||||
|
|
||||||
|
deltaH = (
|
||||||
|
Lend - Lstart + self._prior(params) + 0.5 * jnp.sum(p**2) * inv_mass
|
||||||
|
)
|
||||||
|
myprint(f"# deltaH={deltaH}")
|
||||||
|
key, subkey = jax.random.split(key)
|
||||||
|
if jnp.log(jax.random.uniform(subkey)) <= -deltaH:
|
||||||
|
# accept
|
||||||
|
myprint(f"# accepting...")
|
||||||
|
self.y[:] = params
|
||||||
|
for i in range(len(self.attributes)):
|
||||||
|
xnew = params[i]
|
||||||
|
if self.transform_attributes[i]:
|
||||||
|
xnew = self.transform_attributes[i](xnew)
|
||||||
|
self.y_real[i] = xnew
|
||||||
|
|
||||||
|
to_set = {k:v for k, v in zip(self.attributes, self.y_real)}
|
||||||
|
self.likelihood.fwd_param.setModelParams(to_set)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# reject
|
||||||
|
myprint(f"# rejecting...")
|
||||||
|
|
||||||
|
self._update_attrs(params)
|
||||||
|
self.p_old = p
|
||||||
|
|
||||||
|
def restore(self, state: borg.likelihood.MarkovState):
|
||||||
|
|
||||||
|
# Define attribute names
|
||||||
|
attrname_real = f"{self.prefix}attributes"
|
||||||
|
attrname_var = f"{self.prefix}hmc_var"
|
||||||
|
|
||||||
|
# Initialize attributes in the state
|
||||||
|
state.newArray1d(attrname_real, len(self.attributes), True)
|
||||||
|
self.y_real = state[attrname_real]
|
||||||
|
|
||||||
|
state.newArray1d(attrname_var, len(self.attributes), True)
|
||||||
|
self.y = state[attrname_var]
|
||||||
|
|
||||||
|
def _loaded():
|
||||||
|
myprint(f"Reinjecting parameters for bias sampler: {self.attributes}")
|
||||||
|
# Ensure `self.y` is a JAX array
|
||||||
|
self._update_attrs(jnp.array(self.y))
|
||||||
|
|
||||||
|
# Subscribe to updates
|
||||||
|
state.subscribeLoaded(attrname_var, _loaded)
|
||||||
|
|
||||||
|
|
||||||
|
def derive_prior(f):
|
||||||
|
def _prior(x):
|
||||||
|
alpha = f(x)
|
||||||
|
grad_fn = jax.grad(f)
|
||||||
|
grad = grad_fn(x)
|
||||||
|
return -jnp.log(jnp.abs(grad))
|
||||||
|
f.prior = _prior
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
def transform_uniform(x, a, b):
|
||||||
|
"""
|
||||||
|
Take a x value, from ]-inf,inf[ and map onto alpha in ]a,b[
|
||||||
|
"""
|
||||||
|
return jax.nn.sigmoid(x) * (b - a) + a
|
||||||
|
|
||||||
|
|
||||||
|
def inv_transform_uniform(alpha, a, b):
|
||||||
|
"""
|
||||||
|
Take an alpha value, from ]a,b[ and map onto x in ]-inf,inf[
|
||||||
|
"""
|
||||||
|
x = (alpha - a) / (b - a)
|
||||||
|
return - jnp.log(1/x - 1)
|
|
@ -15,23 +15,23 @@ test_mode = true
|
||||||
seed_cpower = true
|
seed_cpower = true
|
||||||
|
|
||||||
[block_loop]
|
[block_loop]
|
||||||
hades_sampler_blocked = false
|
hades_sampler_blocked = true
|
||||||
bias_sampler_blocked= true
|
bias_sampler_blocked= false
|
||||||
nmean_sampler_blocked= true
|
nmean_sampler_blocked= true
|
||||||
sigma8_sampler_blocked = false
|
sigma8_sampler_blocked = true
|
||||||
muA_sampler_blocked = false
|
muA_sampler_blocked = false
|
||||||
omega_m_sampler_blocked = false
|
omega_m_sampler_blocked = true
|
||||||
alpha_sampler_blocked = false
|
alpha_sampler_blocked = false
|
||||||
sig_v_sampler_blocked = false
|
sig_v_sampler_blocked = false
|
||||||
bulk_flow_sampler_blocked = false
|
bulk_flow_sampler_blocked = false
|
||||||
ares_heat = 1.0
|
ares_heat = 1.0
|
||||||
|
|
||||||
[mcmc]
|
[mcmc]
|
||||||
number_to_generate = 10
|
number_to_generate = 10000
|
||||||
warmup_model = 3
|
warmup_model = 0
|
||||||
warmup_cosmo = 7
|
warmup_cosmo = 0
|
||||||
random_ic = false
|
random_ic = false
|
||||||
init_random_scaling = 0.1
|
init_random_scaling = 1.0
|
||||||
bignum = 1e300
|
bignum = 1e300
|
||||||
|
|
||||||
[hades]
|
[hades]
|
||||||
|
@ -40,8 +40,15 @@ max_epsilon = 0.01
|
||||||
max_timesteps = 50
|
max_timesteps = 50
|
||||||
mixing = 1
|
mixing = 1
|
||||||
|
|
||||||
|
[sampling]
|
||||||
|
algorithm = HMC
|
||||||
|
epsilon = 0.001
|
||||||
|
Nsteps = 20
|
||||||
|
refresh = 0.1
|
||||||
|
|
||||||
[model]
|
[model]
|
||||||
gravity = lpt
|
gravity = lpt
|
||||||
|
supersampling = 2
|
||||||
velocity = CICModel
|
velocity = CICModel
|
||||||
af = 1.0
|
af = 1.0
|
||||||
ai = 0.05
|
ai = 0.05
|
||||||
|
@ -49,7 +56,7 @@ nsteps = 20
|
||||||
smooth_R = 4
|
smooth_R = 4
|
||||||
bias_epsilon = 1e-7
|
bias_epsilon = 1e-7
|
||||||
interp_order = 1
|
interp_order = 1
|
||||||
rsmooth = 100.
|
rsmooth = 4.
|
||||||
sig_v = 150.
|
sig_v = 150.
|
||||||
R_lim = none
|
R_lim = none
|
||||||
Nint_points = 201
|
Nint_points = 201
|
||||||
|
@ -68,7 +75,7 @@ bulk_flow = [-200.0, 200.0]
|
||||||
omega_r = 0
|
omega_r = 0
|
||||||
fnl = 0
|
fnl = 0
|
||||||
omega_k = 0
|
omega_k = 0
|
||||||
omega_m = 0.315
|
omega_m = 0.335
|
||||||
omega_b = 0.049
|
omega_b = 0.049
|
||||||
omega_q = 0.685
|
omega_q = 0.685
|
||||||
h100 = 0.68
|
h100 = 0.68
|
||||||
|
@ -85,7 +92,7 @@ NCAT = 0
|
||||||
NSAMP = 2
|
NSAMP = 2
|
||||||
|
|
||||||
[mock]
|
[mock]
|
||||||
seed = 123
|
seed = 1234
|
||||||
R_max = 100
|
R_max = 100
|
||||||
|
|
||||||
[python]
|
[python]
|
||||||
|
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 14 KiB |
BIN
figs/gradient_test_16.png
Normal file
After Width: | Height: | Size: 83 KiB |
Before Width: | Height: | Size: 92 KiB After Width: | Height: | Size: 91 KiB |
Before Width: | Height: | Size: 17 KiB After Width: | Height: | Size: 18 KiB |
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 15 KiB |
Before Width: | Height: | Size: 13 KiB After Width: | Height: | Size: 15 KiB |
Before Width: | Height: | Size: 13 KiB After Width: | Height: | Size: 15 KiB |
29
scripts/build_borg.sh
Executable file
|
@ -0,0 +1,29 @@
|
||||||
|
#!/bin/sh
|
||||||
|
|
||||||
|
# Modules
|
||||||
|
module purge
|
||||||
|
module restore myborg
|
||||||
|
module load cuda/12.3
|
||||||
|
|
||||||
|
# Environment
|
||||||
|
source /home/bartlett/.bashrc
|
||||||
|
source /home/bartlett/anaconda3/etc/profile.d/conda.sh
|
||||||
|
conda deactivate
|
||||||
|
conda activate borg_env
|
||||||
|
|
||||||
|
# Kill job if there are any errors
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Path variables
|
||||||
|
BORG_DIR=/home/bartlett/borg
|
||||||
|
BUILD_DIR=/data101/bartlett/build_borg/
|
||||||
|
|
||||||
|
cd $BORG_DIR
|
||||||
|
# git pull
|
||||||
|
# rm -rf $BUILD_DIR
|
||||||
|
bash build.sh --c-compiler $(which x86_64-conda_cos6-linux-gnu-gcc) --cxx-compiler $(which x86_64-conda_cos6-linux-gnu-g++) --python --hades-python --build-dir $BUILD_DIR
|
||||||
|
cd $BUILD_DIR
|
||||||
|
make -j
|
||||||
|
|
||||||
|
|
||||||
|
exit 0
|
|
@ -15,7 +15,7 @@ conda activate borg_env
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
# Path variables
|
# Path variables
|
||||||
BORG=/home/bartlett/anaconda3/envs/borg_env/bin/hades_python
|
BORG=/data101/bartlett/build_borg/tools/hades_python/hades_python
|
||||||
RUN_DIR=/data101/bartlett/fsigma8/borg_velocity/basic_run
|
RUN_DIR=/data101/bartlett/fsigma8/borg_velocity/basic_run
|
||||||
|
|
||||||
mkdir -p $RUN_DIR
|
mkdir -p $RUN_DIR
|
||||||
|
|
47
scripts/submit_borg.sh
Executable file
|
@ -0,0 +1,47 @@
|
||||||
|
#!/bin/sh
|
||||||
|
#PBS -S /bin/sh
|
||||||
|
#PBS -N model_hmc_test
|
||||||
|
#PBS -j oe
|
||||||
|
#PBS -m ae
|
||||||
|
#PBS -l nodes=h13:has1gpu:ppn=40,walltime=24:00:00
|
||||||
|
|
||||||
|
# Modules
|
||||||
|
module purge
|
||||||
|
module restore myborg
|
||||||
|
module load cuda/12.3
|
||||||
|
|
||||||
|
# Environment
|
||||||
|
source /home/bartlett/.bashrc
|
||||||
|
source /home/bartlett/anaconda3/etc/profile.d/conda.sh
|
||||||
|
conda deactivate
|
||||||
|
conda activate borg_env
|
||||||
|
|
||||||
|
# Kill job if there are any errors
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Path variables
|
||||||
|
BORG=/data101/bartlett/build_borg/tools/hades_python/hades_python
|
||||||
|
RUN_DIR=/data101/bartlett/fsigma8/borg_velocity/basic_run
|
||||||
|
|
||||||
|
mkdir -p $RUN_DIR
|
||||||
|
cd $RUN_DIR
|
||||||
|
|
||||||
|
# Create a custom file descriptor (3) for tracing
|
||||||
|
exec 3>$RUN_DIR/trace_file.txt
|
||||||
|
|
||||||
|
# Redirect trace output to the custom file descriptor
|
||||||
|
BASH_XTRACEFD="3"
|
||||||
|
set -x
|
||||||
|
|
||||||
|
# Run BORG
|
||||||
|
INI_FILE=/home/bartlett/fsigma8/borg_velocity/conf/basic_ini.ini
|
||||||
|
cp $INI_FILE basic_ini.ini
|
||||||
|
$BORG INIT basic_ini.ini
|
||||||
|
|
||||||
|
conda deactivate
|
||||||
|
|
||||||
|
# Disable tracing and close the custom file descriptor
|
||||||
|
set +x
|
||||||
|
exec 3>&-
|
||||||
|
|
||||||
|
exit 0
|
|
@ -1,24 +0,0 @@
|
||||||
Memory still allocated at the end: 34.75 MB
|
|
||||||
|
|
||||||
Statistics per context (name, allocated, freed, peak)
|
|
||||||
======================
|
|
||||||
|
|
||||||
*none* 32.25 410.232 170.482
|
|
||||||
BORG LPT MODEL 113.12 0 348.602
|
|
||||||
BORGForwardModel::setup 0.000183105 0 137.982
|
|
||||||
BorgLptModel::BorgLptModel 96.75 0 137.982
|
|
||||||
BorgLptModel::~BorgLptModel 0 129.28 0
|
|
||||||
CICModel::getVelocityField 240.75 48.375 621.512
|
|
||||||
CICModel::getVelocityFieldAlpha 146.25 145.125 638.762
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/chain_forward_model.cpp]virtual void LibLSS::ChainForwardModel::forwardModel_v2(LibLSS::detail_input::ModelInput<3>) 65 16.25 235.482
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forward_model.cpp]void LibLSS::BORGForwardModel::setupDefault() 32.5 0 41.232
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forwards/borg_lpt.cpp]std::shared_ptr<LibLSS::BORGForwardModel> build_borg_lpt(std::shared_ptr<LibLSS::MPI_Communication>, const LibLSS::BoxModel&, const LibLSS::PropertyProxy&) [with Grid = LibLSS::ClassicCloudInCell<double>; LibLSS::BoxModel = LibLSS::NBoxModel<3>] 0 0 8.73196
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forwards/particle_balancer/balanceinfo.hpp]void LibLSS::BalanceInfo::allocate(LibLSS::MPI_Communication*, size_t) 48.16 0 380.762
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forwards/primordial_as.cpp]std::shared_ptr<LibLSS::BORGForwardModel> build_primordial_as(std::shared_ptr<LibLSS::MPI_Communication>, const LibLSS::BoxModel&, const LibLSS::PropertyProxy&) 4.36594 7.62939e-06 4.36601
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forwards/transfer_class.cpp]std::shared_ptr<LibLSS::BORGForwardModel> build_class(std::shared_ptr<LibLSS::MPI_Communication>, const LibLSS::BoxModel&, const LibLSS::PropertyProxy&) 4.36594 7.62939e-06 8.73197
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/samplers/core/gridLikelihoodBase.cpp]LibLSS::GridDensityLikelihoodBase<Dims>::GridDensityLikelihoodBase(LibLSS::MPI_Communication*, const GridSizes&, const GridLengths&) [with int Dims = 3; GridSizes = std::array<long unsigned int, 3>; GridLengths = std::array<double, 3>] 32.5 32.25 170.482
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/tools/mpi/ghost_planes.hpp]void LibLSS::GhostPlanes<T, Nd>::setup(LibLSS::MPI_Communication*, PlaneList&&, PlaneSet&&, DimList&&, size_t) [with PlaneList = std::set<long int>&; PlaneSet = std::set<long int>&; DimList = std::array<long int, 2>&; T = std::complex<double>; long unsigned int Nd = 2; size_t = long unsigned int] 7.62939e-06 0.000495911 5.34058e-05
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/python/pyforward.cpp]void do_get_density_final(LibLSS::BORGForwardModel*, pybind11::array) 16.25 16 364.762
|
|
||||||
dispatch_plane_map 0.000988007 0.000499725 0.00104141
|
|
||||||
exchanging nearby planes after projection 0.5 0.5 606.387
|
|
||||||
lpt_ic 16.25 16.25 381.012
|
|
|
@ -1,44 +0,0 @@
|
||||||
(fftw-3.3.10 fftw_wisdom #x3c273403 #x192df114 #x4d08727c #xe98e9b9d
|
|
||||||
(fftw_rdft2_rank_geq2_register 0 #x10bdd #x10bdd #x0 #x5ee9deaf #x803496c0 #xdcaa85b6 #x11dfd168)
|
|
||||||
(fftw_codelet_n1bv_128_avx 0 #x10bdd #x10bdd #x0 #x0dab20d0 #x841c928d #x8f95a044 #x8648861a)
|
|
||||||
(fftw_rdft2_thr_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #x27ed4958 #x7b810169 #x868b075a #xbdfeccde)
|
|
||||||
(fftw_rdft2_thr_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #xf7cc2e58 #xe8598af0 #xccbe32f3 #x01416c7a)
|
|
||||||
(fftw_rdft2_rank_geq2_register 0 #x10bdd #x10bdd #x0 #xfc62f116 #x4b25ceff #xf057784c #x1ea58257)
|
|
||||||
(fftw_codelet_r2cf_4 2 #x10bdd #x10bdd #x0 #x1a0eadc5 #x01b966bd #xba3ec0cb #x3bbf7704)
|
|
||||||
(fftw_rdft2_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #xbf003229 #x49df7e30 #xb57b18ee #xbdc8f696)
|
|
||||||
(fftw_codelet_n1fv_128_avx 0 #x10bdd #x10bdd #x0 #xf1051924 #x3cb53423 #xbc2120c2 #xdee4d17c)
|
|
||||||
(fftw_codelet_r2cbIII_2 2 #x10bdd #x10bdd #x0 #xee2cb083 #x182690ea #x841aa951 #xd87b9709)
|
|
||||||
(fftw_rdft2_rank_geq2_register 0 #x10bdd #x10bdd #x0 #xd18676aa #xa490b2be #x8a110381 #x0674ed90)
|
|
||||||
(fftw_rdft2_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #x5863400c #x3e549e36 #x70702991 #x4cb1a8b4)
|
|
||||||
(fftw_dft_nop_register 0 #x10bdd #x10bdd #x0 #xefadd2cf #x2599b1f7 #xf3609948 #x6f687984)
|
|
||||||
(fftw_codelet_t2bv_4_avx 0 #x10bdd #x10bdd #x0 #x222fe56a #x400e7062 #xaa0102ab #xdac525c7)
|
|
||||||
(fftw_rdft2_rank_geq2_register 0 #x10bdd #x10bdd #x0 #xb5717c43 #xea8d28e7 #x773b72ad #xab4f7e98)
|
|
||||||
(fftw_dft_thr_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #x74a4b38e #x3a9542bb #x7c53e952 #xd047b5ac)
|
|
||||||
(fftw_codelet_n1fv_128_avx 0 #x10bdd #x10bdd #x0 #x9ca2bad9 #x6195e6cd #xf7a9bd69 #x1a6adf98)
|
|
||||||
(fftw_rdft2_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #x626328cb #x27898515 #xac7524f5 #xe381e924)
|
|
||||||
(fftw_codelet_hc2cbdftv_2_avx 0 #x10bdd #x10bdd #x0 #xbcdfd9e5 #xf7782660 #x8651ac2f #xd4223c63)
|
|
||||||
(fftw_codelet_n2fv_32_avx 0 #x10bdd #x10bdd #x0 #x58b9641c #x9e30bdf8 #xf0f47ae8 #x7655e569)
|
|
||||||
(fftw_codelet_n1bv_128_avx 0 #x10bdd #x10bdd #x0 #x466c1810 #xae9f2666 #xc4baf3f0 #xab78a00c)
|
|
||||||
(fftw_rdft2_thr_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #x0b420fb7 #x88eec800 #x18e6f77b #xee6f47cc)
|
|
||||||
(fftw_codelet_n1bv_128_avx 0 #x10bdd #x10bdd #x0 #xfaefc863 #x3a3ba8fc #x031c855a #x015b826a)
|
|
||||||
(fftw_rdft2_rank_geq2_register 0 #x10bdd #x10bdd #x0 #x2b1f5925 #xf71b5dfe #xe9835918 #xebd9e013)
|
|
||||||
(fftw_rdft2_thr_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #xbd602caa #x29e01361 #xedd1a545 #x43ec5754)
|
|
||||||
(fftw_codelet_r2cfII_4 2 #x10bdd #x10bdd #x0 #xefbb9b83 #x7a1273f7 #x44738c83 #x09e09d33)
|
|
||||||
(fftw_codelet_hc2cfdftv_4_avx 0 #x10bdd #x10bdd #x0 #x181d5f02 #xe52fa6f8 #x2c40517b #x5c7e639e)
|
|
||||||
(fftw_rdft2_rank_geq2_register 0 #x10bdd #x10bdd #x0 #xac037af4 #x0cfa848c #x57252557 #xda3837a4)
|
|
||||||
(fftw_dft_thr_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #xa0f80596 #x3d1d8059 #xce975432 #x8f20aeca)
|
|
||||||
(fftw_codelet_r2cb_2 2 #x10bdd #x10bdd #x0 #x99718546 #x87fcd081 #x11825ef6 #x9398cd27)
|
|
||||||
(fftw_dft_nop_register 0 #x10bdd #x10bdd #x0 #x54cdef15 #xc5bb0b98 #x2bd11131 #x922134d4)
|
|
||||||
(fftw_rdft_rank0_register 3 #x10bdd #x10bdd #x0 #xec8516b0 #xa88a37cc #x2870b0c6 #xe1bc2e8e)
|
|
||||||
(fftw_dft_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #x007adcc2 #x9232f02f #x8145da1e #x6f71eb23)
|
|
||||||
(fftw_codelet_n2bv_16_avx 0 #x10bdd #x10bdd #x0 #x2bbbefa0 #xe1658c41 #x1edcd3e8 #xfc8538a3)
|
|
||||||
(fftw_rdft2_thr_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #x70aef0dd #xa106222c #x1a3885f9 #x770e380d)
|
|
||||||
(fftw_dft_buffered_register 1 #x10bdd #x10bdd #x0 #x5a8b97d2 #x7beb56fa #xd471ea51 #x958af4a2)
|
|
||||||
(fftw_codelet_n1fv_128_avx 0 #x10bdd #x10bdd #x0 #x4e3cf147 #x8f793c68 #x70b49925 #x52917fe0)
|
|
||||||
(fftw_dft_buffered_register 1 #x10bdd #x10bdd #x0 #x221ab00c #x079c5d40 #x5de8abb4 #xaf85805a)
|
|
||||||
(fftw_dft_r2hc_register 0 #x10bdd #x10bdd #x0 #x6b45fee6 #xba1b589d #x3066e6b5 #x33be97f3)
|
|
||||||
(fftw_rdft2_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #x201d20d9 #x0d2e56ec #xc45b8566 #xa8cd33c7)
|
|
||||||
(fftw_dft_r2hc_register 0 #x10bdd #x10bdd #x0 #xf9968ceb #xd3b2d43e #x0346b437 #xd743c3e1)
|
|
||||||
(fftw_rdft2_thr_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #x4052c515 #x478d7f95 #x5afa91a6 #x4aa8e677)
|
|
||||||
(fftw_dft_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #xeb6d2e88 #x6227e668 #xd7d170d2 #x1daa4985)
|
|
||||||
)
|
|
|
@ -37,12 +37,6 @@ mylike.updateCosmology(cosmo)
|
||||||
s_hat = np.fft.rfftn(np.random.randn(*box_in.N)) / box_in.Ntot ** (0.5)
|
s_hat = np.fft.rfftn(np.random.randn(*box_in.N)) / box_in.Ntot ** (0.5)
|
||||||
mylike.generateMockData(s_hat, state)
|
mylike.generateMockData(s_hat, state)
|
||||||
|
|
||||||
scale = 1.0
|
|
||||||
L = mylike.logLikelihoodComplex(scale * s_hat, None)
|
|
||||||
print(L)
|
|
||||||
|
|
||||||
quit()
|
|
||||||
|
|
||||||
if test_scaling:
|
if test_scaling:
|
||||||
all_scale = np.linspace(0.5, 1.5, 100)
|
all_scale = np.linspace(0.5, 1.5, 100)
|
||||||
all_lkl = np.empty(all_scale.shape)
|
all_lkl = np.empty(all_scale.shape)
|
||||||
|
@ -66,7 +60,7 @@ if test_scaling:
|
||||||
|
|
||||||
# Test sigma8
|
# Test sigma8
|
||||||
if test_sigma8:
|
if test_sigma8:
|
||||||
all_sigma8 = np.linspace(0.5, 1.2, 40)
|
all_sigma8 = np.linspace(0.5, 1.2, 100)
|
||||||
all_lkl = np.empty(all_sigma8.shape)
|
all_lkl = np.empty(all_sigma8.shape)
|
||||||
cosmo_true = mylike.fwd.getCosmoParams()
|
cosmo_true = mylike.fwd.getCosmoParams()
|
||||||
cosmo = mylike.fwd.getCosmoParams()
|
cosmo = mylike.fwd.getCosmoParams()
|
||||||
|
@ -93,7 +87,7 @@ if test_sigma8:
|
||||||
|
|
||||||
# Test sigma8
|
# Test sigma8
|
||||||
if test_omegam:
|
if test_omegam:
|
||||||
all_omegam = np.linspace(0.1, 0.6, 40)
|
all_omegam = np.linspace(0.1, 0.6, 100)
|
||||||
all_lkl = np.empty(all_omegam.shape)
|
all_lkl = np.empty(all_omegam.shape)
|
||||||
cosmo_true = mylike.fwd.getCosmoParams()
|
cosmo_true = mylike.fwd.getCosmoParams()
|
||||||
cosmo = mylike.fwd.getCosmoParams()
|
cosmo = mylike.fwd.getCosmoParams()
|
||||||
|
@ -119,7 +113,7 @@ if test_omegam:
|
||||||
|
|
||||||
# Test bias model
|
# Test bias model
|
||||||
if test_alpha:
|
if test_alpha:
|
||||||
all_alpha = np.linspace(-2.0, 5.0, 50)
|
all_alpha = np.linspace(-2.0, 5.0, 100)
|
||||||
all_lkl = np.empty(all_alpha.shape)
|
all_lkl = np.empty(all_alpha.shape)
|
||||||
for i, alpha in enumerate(all_alpha):
|
for i, alpha in enumerate(all_alpha):
|
||||||
mylike.fwd_param.setModelParams({'alpha0':alpha})
|
mylike.fwd_param.setModelParams({'alpha0':alpha})
|
||||||
|
@ -143,7 +137,7 @@ if test_alpha:
|
||||||
|
|
||||||
# Test bias model
|
# Test bias model
|
||||||
if test_muA:
|
if test_muA:
|
||||||
all_muA = np.linspace(0.95, 1.05, 50)
|
all_muA = np.linspace(0.95, 1.05, 100)
|
||||||
all_lkl = np.empty(all_muA.shape)
|
all_lkl = np.empty(all_muA.shape)
|
||||||
for i, muA in enumerate(all_muA):
|
for i, muA in enumerate(all_muA):
|
||||||
mylike.fwd_param.setModelParams({'mua0':muA})
|
mylike.fwd_param.setModelParams({'mua0':muA})
|
||||||
|
|
|
@ -1,65 +0,0 @@
|
||||||
ARES version 53e8df9fe11c732cda13f7ffc821238622068e57 modules
|
|
||||||
|
|
||||||
Cumulative timing spent in different context
|
|
||||||
--------------------------------------------
|
|
||||||
Context, Total time (seconds)
|
|
||||||
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forward_model.cpp]void LibLSS::ForwardModel::setCosmoParams(const LibLSS::CosmologicalParameters&) 34 2.57512
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forwards/borg_lpt.cpp]std::shared_ptr<LibLSS::BORGForwardModel> build_borg_lpt(std::shared_ptr<LibLSS::MPI_Communication>, const LibLSS::BoxModel&, const LibLSS::PropertyProxy&) [with Grid = LibLSS::ClassicCloudInCell<double>; LibLSS::BoxModel = LibLSS::NBoxModel<3>] 1 2.08847
|
|
||||||
BorgLptModel::BorgLptModel 1 2.08788
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forward_model.cpp]void LibLSS::BORGForwardModel::setupDefault() 1 2.07805
|
|
||||||
FFTW_Manager::create_r2c_plan 7 1.13805
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forwards/borg_lpt.cpp]void LibLSS::BorgLptModel<CIC>::updateCosmo() [with CIC = LibLSS::ClassicCloudInCell<double>] 5 1.03688
|
|
||||||
lightcone computation 1 1.01288
|
|
||||||
FFTW_Manager::create_c2r_plan 6 0.994615
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/cosmo.cpp]void LibLSS::Cosmology::precompute_d_plus() 1 0.917438
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forwards/primordial_as.cpp]std::shared_ptr<LibLSS::BORGForwardModel> build_primordial_as(std::shared_ptr<LibLSS::MPI_Communication>, const LibLSS::BoxModel&, const LibLSS::PropertyProxy&) 1 0.385783
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forwards/transfer_class.cpp]std::shared_ptr<LibLSS::BORGForwardModel> build_class(std::shared_ptr<LibLSS::MPI_Communication>, const LibLSS::BoxModel&, const LibLSS::PropertyProxy&) 1 0.38134
|
|
||||||
CICModel::getVelocityField 1 0.381026
|
|
||||||
Classic CIC projection 5 0.346875
|
|
||||||
CICModel::getVelocityFieldAlpha 3 0.289289
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forwards/transfer_class.cpp]virtual void LibLSS::ForwardClass::updateCosmo() 5 0.237377
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/class_cosmo.cpp]LibLSS::ClassCosmo::ClassCosmo(const LibLSS::CosmologicalParameters&, unsigned int, double, std::string, unsigned int, const std::map<std::__cxx11::basic_string<char>, std::__cxx11::basic_string<char> >&) 1 0.227272
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/chain_forward_model.cpp]virtual void LibLSS::ChainForwardModel::forwardModel_v2(LibLSS::detail_input::ModelInput<3>) 1 0.110549
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/cosmo.cpp]void LibLSS::Cosmology::precompute_com2a() 1 0.0689217
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forwards/lpt/borg_fwd_lpt.cpp]void LibLSS::BorgLptModel<CIC>::getDensityFinal(LibLSS::detail_output::ModelOutput<3>) [with CIC = LibLSS::ClassicCloudInCell<double>] 1 0.0646358
|
|
||||||
BORG LPT MODEL 1 0.0386933
|
|
||||||
BORG forward model 1 0.0373443
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forwards/primordial_as.cpp]virtual void LibLSS::ForwardPrimordial_As::updateCosmo() 10 0.0132038
|
|
||||||
lpt_ic 1 0.0129245
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forwards/primordial_as.cpp]void LibLSS::ForwardPrimordial_As::updatePower() 5 0.0126693
|
|
||||||
FFTW_Manager::execute_c2r 7 0.00893392
|
|
||||||
FFTW_Manager::execute_r2c 4 0.00713005
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/samplers/core/gridLikelihoodBase.cpp]LibLSS::GridDensityLikelihoodBase<Dims>::GridDensityLikelihoodBase(LibLSS::MPI_Communication*, const GridSizes&, const GridLengths&) [with int Dims = 3; GridSizes = std::array<long unsigned int, 3>; GridLengths = std::array<double, 3>] 1 0.00538298
|
|
||||||
FFTW_Manager::destroy_plan 12 0.00456276
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/tools/hermiticity_fixup.cpp]LibLSS::Hermiticity_fixer<T, Nd>::Hermiticity_fixer(Mgr_p) [with T = double; long unsigned int Nd = 3; Mgr_p = std::shared_ptr<LibLSS::FFTW_Manager<double, 3> >] 1 0.00376522
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/tools/mpi/ghost_planes.hpp]void LibLSS::GhostPlanes<T, Nd>::setup(LibLSS::MPI_Communication*, PlaneList&&, PlaneSet&&, DimList&&, size_t) [with PlaneList = std::set<long int>&; PlaneSet = std::set<long int>&; DimList = std::array<long int, 2>&; T = std::complex<double>; long unsigned int Nd = 2; size_t = long unsigned int] 1 0.0037254
|
|
||||||
dispatch_plane_map 1 0.00349579
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forwards/particle_balancer/balanceinfo.hpp]void LibLSS::BalanceInfo::allocate(LibLSS::MPI_Communication*, size_t) 2 0.00317642
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/hermitic.hpp]virtual void LibLSS::ForwardHermiticOperation::getDensityFinal(LibLSS::detail_output::ModelOutput<3>) 1 0.00280722
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/python/pyforward.cpp]void transfer_in(std::shared_ptr<LibLSS::FFTW_Manager<double, 3> >&, T&, U&, bool) [with T = boost::multi_array_ref<std::complex<double>, 3>; U = pybind11::detail::unchecked_reference<std::complex<double>, 3>] 1 0.00212912
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/python/pyforward.cpp]void do_get_density_final(LibLSS::BORGForwardModel*, pybind11::array) 1 0.00185894
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/class_cosmo.cpp]void LibLSS::ClassCosmo::retrieve_Tk(double) 2 0.00141643
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forwards/adapt_generic_bias.cpp]void {anonymous}::bias_registrator() 1 0.00135291
|
|
||||||
BorgLptModel::~BorgLptModel 1 0.00134733
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/class_cosmo.cpp]void LibLSS::ClassCosmo::reinterpolate(const array_ref_1d&, const array_ref_1d&, LibLSS::internal_auto_interp::auto_interpolator<double>&) 6 0.00125166
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/tools/hermiticity_fixup.cpp]void LibLSS::Hermiticity_fixer<T, Nd>::forward(CArrayRef&) [with T = double; long unsigned int Nd = 3; CArrayRef = boost::multi_array_ref<std::complex<double>, 3>] 1 0.000808169
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/tools/hermiticity_fixup.cpp]typename std::enable_if<(Dim != 1), void>::type fix_plane(Mgr&, Ghosts&&, CArray&&, size_t*) [with long unsigned int rank = 0; Mgr = LibLSS::FFTW_Manager<double, 3>; Ghosts = LibLSS::Hermiticity_fixer<double, 3>::forward(CArrayRef&)::<lambda(ssize_t)>; CArray = boost::detail::multi_array::multi_array_view<std::complex<double>, 2>; long unsigned int Dim = 2; typename std::enable_if<(Dim != 1), void>::type = void; size_t = long unsigned int] 2 0.000755922
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/chain_forward_model.cpp]virtual void LibLSS::ChainForwardModel::getDensityFinal(LibLSS::detail_output::ModelOutput<3>) 1 0.000752412
|
|
||||||
exchanging nearby planes after projection 4 0.000681672
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/model_io.cpp]LibLSS::detail_output::ModelOutputBase<Nd, Super>::~ModelOutputBase() [with long unsigned int Nd = 3; Super = LibLSS::detail_model::ModelIO<3>] 17 0.00031809
|
|
||||||
BORGForwardModel::setup 8 0.000195441
|
|
||||||
gather_peer_by_plane 1 0.000147003
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/model_io.cpp]void LibLSS::detail_output::ModelOutputBase<Nd, Super>::transfer(LibLSS::detail_output::ModelOutputBase<Nd, Super>&&) [with long unsigned int Nd = 3; Super = LibLSS::detail_model::ModelIO<3>] 11 0.000116935
|
|
||||||
Initializing peer system 14 7.2966e-05
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/model_io/base.hpp]void LibLSS::detail_model::ModelIO<Nd>::transfer(LibLSS::detail_model::ModelIO<Nd>&&) [with long unsigned int Nd = 3] 29 2.9749e-05
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/model_io.cpp]void LibLSS::detail_output::ModelOutputBase<Nd, Super>::close() [with long unsigned int Nd = 3; Super = LibLSS::detail_model::ModelIO<3>] 17 2.0688e-05
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forwards/transfer_class.cpp]virtual void LibLSS::ForwardClass::setModelParams(const LibLSS::ModelDictionnary&) 1 1.9715e-05
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forwards/transfer_class.cpp]virtual void LibLSS::ForwardClass::forwardModel_v2(LibLSS::detail_input::ModelInput<3>) 1 1.7707e-05
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forwards/primordial_as.cpp]virtual void LibLSS::ForwardPrimordial_As::forwardModel_v2(LibLSS::detail_input::ModelInput<3>) 1 1.6895e-05
|
|
||||||
ghost synchronize 1 1.2845e-05
|
|
||||||
particle distribution 2 8.438e-06
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/model_io.cpp]void LibLSS::detail_output::ModelOutputBase<Nd, Super>::setRequestedIO(LibLSS::PreferredIO) [with long unsigned int Nd = 3; Super = LibLSS::detail_model::ModelIO<3>] 5 6.215e-06
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/model_io.cpp]void LibLSS::detail_input::ModelInputBase<Nd, Super>::setRequestedIO(LibLSS::PreferredIO) [with long unsigned int Nd = 3; Super = LibLSS::detail_model::ModelIO<3>] 4 5.24e-06
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/forward_model.cpp]virtual void LibLSS::ForwardModel::setModelParams(const LibLSS::ModelDictionnary&) 1 2.907e-06
|
|
||||||
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1717878335917/work/libLSS/physics/model_io.cpp]void LibLSS::detail_input::ModelInputBase<Nd, Super>::needDestroyInput() [with long unsigned int Nd = 3; Super = LibLSS::detail_model::ModelIO<3>] 1 2.668e-06
|
|