Generate mock data from borg, create MB data and prepare likelihood

Deaglan Bartlett 2024-04-23 22:43:49 +02:00
parent 874d26e38c
commit 85b93c539f
15 changed files with 1192 additions and 11 deletions

borg_velocity/forwards.py

@ -1,8 +1,53 @@
import aquila_borg as borg
import jax
import jax.numpy as jnp
from functools import partial
from borg_velocity.utils import myprint  # myprint is used in setModelParams below and needs this import
class NullForward(borg.forward.BaseForwardModel):
"""
BORG forward model which does nothing except store
the values of the parameters to be used by the likelihood
"""
def __init__(self, box: borg.forward.BoxModel) -> None:
"""
Initialise the NullForward class
Args:
box (borg.forward.BoxModel): The input box model.
"""
super().__init__(box, box)
self.setName("nullforward")
self.params = {}
self.setCosmoParams(borg.cosmo.CosmologicalParameters())
cosmo = self.getCosmoParams()
cosmo.n_s = 0.96241
self.setCosmoParams(cosmo)
def setModelParams(self, params: dict) -> None:
"""
Change the values of the model parameters to those given by params
Args:
params (dict): Dictionary of updated model parameters.
"""
for k, v in params.items():
self.params[k] = v
print(" ")
myprint(f'Updated model parameters: {self.params}')
def getModelParam(self, model, keyname: str):
"""
Query the current value of the parameter keyname in the given model.
Args:
model: The model
keyname (str): The name of the parameter of interest
"""
return self.params[keyname]
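A minimal usage sketch of this stub (illustrative values; in a real run the ModelParamsSampler built in likelihood.py calls setModelParams itself):
import aquila_borg as borg
from borg_velocity.forwards import NullForward
# Hypothetical box matching conf/basic_ini.ini
box = borg.forward.BoxModel()
box.L = (500.0, 500.0, 500.0)
box.N = (32, 32, 32)
box.xmin = (-250.0, -250.0, -250.0)
param_model = NullForward(box)
param_model.setModelParams({'sig_v': 160.0, 'alpha0': 1.4})  # store proposed values
print(param_model.getModelParam('nullforward', 'sig_v'))     # -> 160.0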
@partial(jax.jit, static_argnums=(2,3))
def dens2vel_linear(delta, f, Lbox, smooth_R):
"""
@ -44,7 +89,7 @@ def dens2vel_linear(delta, f, Lbox, smooth_R):
update_index_imag_ny = jnp.index_exp[1,N_Z:,:,N_Z-1]
flip_indices = -jnp.arange(N)
flip_indices = flip_indices.at[N_Z-1].set(-flip_indices[N_Z-1])
flip_indices = jnp.array(flip_indices.tolist())
# flip_indices = jnp.array(flip_indices.tolist())
# Symmetrise
delta_k = Fourier_mask[jnp.newaxis] * delta_k
@ -65,8 +110,8 @@ def dens2vel_linear(delta, f, Lbox, smooth_R):
# Get k grid
k = 2*jnp.pi*jnp.fft.fftfreq(N, d=Lbox/N)
k_norm = jnp.sqrt(k[:,None,None]**2 + k[None,:,None]**2 + kz_vec[None,None,:]**2)
k_norm = k_norm.at[(k_norm < 1e-10)].set(1e-15)
k_norm = jnp.sqrt(k[:,None,None]**2 + k[None,:,None]**2 + k[None,None,:N_Z]**2)
k_norm = jnp.where(k_norm < 1e-10, 1e-15, k_norm)
# Filter
k_filter = jnp.exp(-0.5 * (k_norm[:,:,:N_Z] * smooth_R) ** 2)
@ -74,7 +119,7 @@ def dens2vel_linear(delta, f, Lbox, smooth_R):
vx = (
smooth_filter * k_filter
* jnp.array(complex(0, 1)) * 100 * f
* delta_k_complex
* delta_k
* jnp.tile(k[:,None,None],(1,N,N_Z))
/ k_norm**2
)
@ -83,8 +128,8 @@ def dens2vel_linear(delta, f, Lbox, smooth_R):
vy = (
smooth_filter * k_filter
* jnp.array(complex(0, 1)) * 100 * f
* delta_k_complex
* jnp.tile(ky[None,:,None], (N,1,N_Z))
* delta_k
* jnp.tile(k[None,:,None], (N,1,N_Z))
/ k_norm**2
)
vy = (jnp.fft.irfftn(vy)*V/dV)
@ -92,8 +137,8 @@ def dens2vel_linear(delta, f, Lbox, smooth_R):
vz = (
smooth_filter * k_filter
* jnp.array(complex(0, 1)) * 100 * f
* delta_k_complex
* jnp.tile(kz[None,None,:N_Z], (N,N,1))
* delta_k
* jnp.tile(k[None,None,:N_Z], (N,N,1))
/ k_norm**2
)
vz = (jnp.fft.irfftn(vz)*V/dV)
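A minimal usage sketch of dens2vel_linear (mirroring tests/test_forward.py below; the input field and numbers are illustrative):
import numpy as np
import borg_velocity.forwards as forwards
delta = 0.1 * np.random.randn(32, 32, 32)    # overdensity field on a 32^3 grid
f, Lbox, smooth_R = 0.53, 500.0, 4.0         # growth rate, box side and smoothing scale (Mpc/h)
vel = forwards.dens2vel_linear(delta, f, Lbox, smooth_R)
# vel stacks the three velocity components along the first axis, in km/s
# (the 100*f factor corresponds to H0 = 100 h km/s/Mpc with distances in Mpc/h)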

borg_velocity/likelihood.py Normal file (475 lines added)

@ -0,0 +1,475 @@
import numpy as np
import jax.numpy as jnp
import configparser
import warnings
import ast  # needed for ast.literal_eval in build_sampler below
import aquila_borg as borg
import symbolic_pofk.linear
import borg_velocity.utils as utils
from borg_velocity.utils import myprint
import borg_velocity.forwards as forwards
import borg_velocity.mock_maker as mock_maker
import borg_velocity.projection as projection
class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
"""
HADES likelihood for distance tracers
"""
def __init__(self, fwd: borg.forward.BaseForwardModel, param_model: forwards.NullForward, ini_file: str) -> None:
"""
Initialises the VelocityBORGLikelihood class
Args:
- fwd (borg.forward.BaseForwardModel): The forward model to be used in the likelihood.
- param_model (forwards.NullForward): An empty forward model for storing model parameters.
- ini_file (str): The name of the ini file containing the model and borg parameters.
"""
self.ini_file = ini_file
myprint("Reading from configuration file: " + ini_file)
config = configparser.ConfigParser()
config.read(ini_file)
# Grid parameters
self.N = [int(config['system'][f'N{i}']) for i in range(3)]
self.L = [float(config['system'][f'L{i}']) for i in range(3)]
# Catalogues
self.nsamp = int(config['run']['nsamp'])
assert self.nsamp > 0, "Need at least one sample to run"
self.Nt = [int(config[f'sample_{i}']['Nt']) for i in range(self.nsamp)]
self.alpha = [float(config[f'sample_{i}']['alpha']) for i in range(self.nsamp)]
self.muA = [float(config[f'sample_{i}']['muA']) for i in range(self.nsamp)]
self.frac_sig_rhMpc = [float(config[f'sample_{i}']['frac_sig_rhMpc']) for i in range(self.nsamp)]
# What type of run we're doing
self.run_type = config['run']['run_type']
# Model parameters
self.sig_v = float(config['model']['sig_v'])
if config['model']['R_lim'] == 'none':
self.R_lim = self.L[0]/2
else:
self.R_lim = float(config['model']['R_lim'])
self.Nint_points = int(config['model']['nint_points'])
self.Nsig = float(config['model']['Nsig'])
myprint(f" Init {self.N}, {self.L}")
super().__init__(fwd, self.N, self.L)
# Define the forward models
self.fwd = fwd
self.fwd_param = param_model
def initializeLikelihood(self, state: borg.likelihood.MarkovState) -> None:
"""
Initialise the likelihood internal variables and MarkovState variables.
Args:
- state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
"""
myprint("Init likelihood")
# for i in range(self.nsamp):
# state.newArray1d(f"tracer_vr_{i}", self.Nt[i], True)
# if self.run_type != 'data':
# state.newArray1d(f"data_x_{i}", self.Nt[i], True)
# state.newArray1d(f"data_y_{i}", self.Nt[i], True)
# state.newArray1d(f"data_z_{i}", self.Nt[i], True)
# state.newArray1d(f"true_x_{i}", self.Nt[i], True)
# state.newArray1d(f"true_y_{i}", self.Nt[i], True)
# state.newArray1d(f"true_z_{i}", self.Nt[i], True)
# self.data = [state[f"tracer_vr_{i}"] for i in range(self.nsamp)]
state.newArray3d("BORG_final_density", *self.fwd.getOutputBoxModel().N, True)
def updateMetaParameters(self, state: borg.likelihood.MarkovState) -> None:
"""
Update the meta parameters of the sampler (not sampled) from the MarkovState.
Args:
- state (borg.likelihood.MarkovState): The state object to be used in the likelihood.
"""
cpar = state['cosmology']
cpar.omega_q = 1. - cpar.omega_m - cpar.omega_k
self.fwd.setCosmoParams(cpar)
self.fwd_param.setCosmoParams(cpar)
def updateCosmology(self, cosmo: borg.cosmo.CosmologicalParameters) -> None:
"""
Updates the forward model's cosmological parameters with the given values.
Args:
- cosmo (borg.cosmo.CosmologicalParameters): The cosmological parameters.
"""
cpar = cosmo
# Convert sigma8 to As
cpar.A_s = 1.e-9 * symbolic_pofk.linear.sigma8_to_As(
cpar.sigma8, cpar.omega_m, cpar.omega_b, cpar.h, cpar.n_s)
myprint(f"Updating cosmology Om = {cosmo.omega_m}, sig8 = {cosmo.sigma8}, As = {cosmo.A_s}")
cpar.omega_q = 1. - cpar.omega_m - cpar.omega_k
self.fwd.setCosmoParams(cpar)
self.fwd_param.setCosmoParams(cpar)
def generateMBData(self) -> None:
"""
Generate points along the line of sight of each tracer to enable marginalisation
over distance uncertainty. The central distance is given such that the observed
redshift equals the cosmological redshift at this distance. The range is then
+/- Nsig * sig, where
sig^2 = (sig_v/100)^2 + sig_r^2
and sig_v is the velocity uncertainty in km/s
"""
self.MB_pos = [None] * self.nsamp
self.r = [None] * self.nsamp
self.RA = [None] * self.nsamp
self.DEC = [None] * self.nsamp
for i in range(self.nsamp):
myprint(f"Making MB data for sample {i}")
# Get angular coordinates of all points
r_hat = projection.get_radial_vectors(self.coord_meas[i])
# Get min and max distance to integrate over
robs = self.cz_obs[i] / 100
sigr = np.sqrt((self.sig_v / 100) ** 2 + (self.frac_sig_rhMpc[i] * robs)**2)
rmin = robs - self.Nsig * sigr
rmin = rmin.at[rmin <= 0].set(self.L[0] / self.N[0] / 100.)
rmax = robs + self.Nsig * sigr
rmax = rmax.at[rmax > self.R_lim].set(self.R_lim)
# Compute coordinates of integration points
self.r[i] = np.linspace(rmin, rmax, self.Nint_points)
cartesian_pos_MB = np.expand_dims(self.r[i], axis=2) * r_hat
self.MB_pos[i] = cartesian_pos_MB
self.MB_pos[i] = jnp.transpose(self.MB_pos[i], (2, 0, 1))
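A worked example of the integration range defined above, with illustrative numbers rather than values read from the config:
import numpy as np
# Hypothetical tracer: cz_obs = 3000 km/s, sig_v = 150 km/s, frac_sig_rhMpc = 0.07, Nsig = 10
robs = 3000.0 / 100.0                                  # central distance: 30 Mpc/h
sigr = np.sqrt((150.0 / 100.0)**2 + (0.07 * robs)**2)  # ~2.58 Mpc/h
rmin, rmax = robs - 10 * sigr, robs + 10 * sigr        # ~4.2 to ~55.8 Mpc/h
r_int = np.linspace(rmin, rmax, 201)                   # Nint_points radial points per tracer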
def generateMockData(self, s_hat: np.ndarray, state: borg.likelihood.MarkovState, make_plot: bool=True) -> None:
"""
Generates mock data by simulating the forward model with the given white noise,
drawing distance tracers from the density field, computing their distance
moduli and radial velocities, and adding Gaussian noise to the appropriate
variables. Also calculates the initial negative log-likelihood of the data.
Args:
- s_hat (np.ndarray): The input (initial) density field.
- state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
- make_plot (bool, default=True): Whether to make diagnostic plots for the mock data generation
"""
if self.run_type == 'data':
raise NotImplementedError
elif self.run_type == 'velmass':
raise NotImplementedError
elif self.run_type == 'mock':
self.coord_true, self.coord_meas, self.sig_mu, self.vr_true, self.cz_obs = \
mock_maker.borg_mock(s_hat, state, self.fwd, self.ini_file)
else:
raise NotImplementedError
self.generateMBData()
quit()
def logLikelihoodComplex(self, s_hat: np.ndarray, gradientIsNext: bool):
"""
Calculates the negative log-likelihood of the data.
Args:
- s_hat (np.ndarray): The input white noise.
- gradientIsNext (bool): If True, prepares the forward model for gradient calculations.
Returns:
The negative log-likelihood value.
"""
N = self.fwd.getBoxModel().N[0]
L = self.fwd.getOutputBoxModel().L[0]
# Run BORG density field
output_density = np.zeros((N,N,N))
self.fwd.forwardModel_v2(s_hat)
self.fwd.getDensityFinal(output_density)
self.delta = output_density
# L = self.dens2like(output_density)
L = 1.
myprint(f"var(s_hat): {np.var(s_hat)}, Call to logLike: {L}")
return L
def gradientLikelihoodComplex(self, s_hat: np.ndarray):
"""
Calculates the adjoint negative log-likelihood of the data.
Args:
- s_hat (np.ndarray): The input white noise.
Returns:
The adjoint negative log-likelihood gradient.
"""
N = self.fwd.getBoxModel().N[0]
L = self.fwd.getOutputBoxModel().L[0]
# Run BORG density field
output_density = np.zeros((N,N,N))
self.fwd.forwardModel_v2(s_hat)
self.fwd.getDensityFinal(output_density)
# mygradient = self.grad_like(output_density)
# mygradient = np.array(mygradient, dtype=np.float64)
# self.fwd.adjointModel_v2(mygradient)
mygrad_hat = np.zeros(s_hat.shape, dtype=np.complex128)
self.fwd.getAdjointModel(mygrad_hat)
return mygrad_hat
def commitAuxiliaryFields(self, state: borg.likelihood.MarkovState) -> None:
"""
Commits the final density field to the Markov state.
Args:
- state (borg.likelihood.MarkovState): The state object containing the final density field.
"""
self.updateCosmology(self.fwd.getCosmoParams())
self.dens2like(self.delta)
state["BORG_final_density"][:] = self.delta
@borg.registerGravityBuilder
def build_gravity_model(state: borg.likelihood.MarkovState, box: borg.forward.BoxModel, ini_file=None) -> borg.forward.BaseForwardModel:
"""
Builds the gravity model and returns the forward model chain.
Args:
- state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
- box (borg.forward.BoxModel): The input box model.
- ini_file (str, default=None): The location of the ini file. If None, use borg.getIniConfigurationFilename()
Returns:
borg.forward.BaseForwardModel: The forward model.
"""
global chain, fwd_param
myprint("Building gravity model")
if ini_file is None:
myprint("Reading from configuration file: " + borg.getIniConfigurationFilename())
config = configparser.ConfigParser()
config.read(borg.getIniConfigurationFilename())
else:
myprint("Reading from configuration file: " + ini_file)
config = configparser.ConfigParser()
config.read(ini_file)
ai = float(config['model']['ai'])
af = float(config['model']['af'])
# Setup forward model
chain = borg.forward.ChainForwardModel(box)
# CLASS transfer function
chain @= borg.forward.model_lib.M_PRIMORDIAL_AS(box)
transfer_class = borg.forward.model_lib.M_TRANSFER_CLASS(box, opts=dict(a_transfer=1.0))
transfer_class.setModelParams({"extra_class_arguments":{'YHe':'0.24'}})
chain @= transfer_class
if config['model']['gravity'] == 'linear':
raise NotImplementedError(config['model']['gravity'])
elif config['model']['gravity'] == 'lpt':
chain @= borg.forward.model_lib.M_LPT_CIC(
box,
opts=dict(a_initial=af,
a_final=af,
do_rsd=False,
supersampling=1,
lightcone=False,
part_factor=1.01,))
elif config['model']['gravity'] == '2lpt':
chain @= borg.forward.model_lib.M_2LPT_CIC(
box,
opts=dict(a_initial=af,
a_final=af,
do_rsd=False,
supersampling=1,
lightcone=False,
part_factor=1.01,))
elif config['model']['gravity'] == 'pm':
chain @= borg.forward.model_lib.M_PM_CIC(
box,
opts=dict(a_initial=af,
a_final=af,
do_rsd=False,
supersampling=1,
part_factor=1.01,
forcesampling=2,
pm_start_z=1/ai - 1,
pm_nsteps=int(config['model']['nsteps']),
tcola=False))
elif config['model']['gravity'] == 'cola':
chain @= borg.forward.model_lib.M_PM_CIC(
box,
opts=dict(a_initial=af,
a_final=af,
do_rsd=False,
supersampling=1,
part_factor=1.01,
forcesampling=2,
pm_start_z=1/ai - 1,
pm_nsteps=int(config['model']['nsteps']),
tcola=True))
else:
raise NotImplementedError(config['model']['gravity'])
# Cosmological parameters
if ini_file is None:
cpar = utils.get_cosmopar(borg.getIniConfigurationFilename())
else:
cpar = utils.get_cosmopar(ini_file)
chain.setCosmoParams(cpar)
# This is the forward model for the model parameters
fwd_param = borg.forward.ChainForwardModel(box)
mod = forwards.NullForward(box)
fwd_param.addModel(mod)
fwd_param.setCosmoParams(cpar)
return chain
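For reference, a sketch of how this builder can be driven outside hades_python (it mirrors tests/test_forward.py added in this commit; the box values are those of conf/basic_ini.ini):
import numpy as np
import aquila_borg as borg
import borg_velocity.likelihood as likelihood
import borg_velocity.utils as utils
box = borg.forward.BoxModel()
box.L = (500.0, 500.0, 500.0)
box.N = (32, 32, 32)
box.xmin = (-250.0, -250.0, -250.0)
fwd = likelihood.build_gravity_model(None, box, ini_file='conf/basic_ini.ini')
fwd.setCosmoParams(utils.get_cosmopar('conf/basic_ini.ini'))
s_hat = np.fft.rfftn(np.random.randn(*box.N)) / box.Ntot ** 0.5  # white-noise ICs
delta = np.zeros(box.N)
fwd.forwardModel_v2(s_hat)
fwd.getDensityFinal(delta)  # final density contrast on the 32^3 grid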
@borg.registerSamplerBuilder
def build_sampler(
state: borg.likelihood.MarkovState,
info: borg.likelihood.LikelihoodInfo,
):
"""
Builds the sampler and returns it.
The parameters to sample are specified in the ini file.
We assume a parameter is NOT to be sampled unless the ini file contains "XX_sampler_blocked = false"
Args:
- state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
- info (borg.likelihood.LikelihoodInfo): The likelihood information.
Returns:
List of samplers to use.
"""
myprint("Building sampler")
myprint("Reading from configuration file: " + borg.getIniConfigurationFilename())
config = configparser.ConfigParser()
config.read(borg.getIniConfigurationFilename())
end = '_sampler_blocked'
to_sample = [k[:-len(end)] for (k, v) in config['block_loop'].items() if k[-len(end):] == end and v.lower() == 'false']
myprint(f'Parameters to sample: {to_sample}')
nsamp = int(config['run']['nsamp'])
all_sampler = []
# Cosmology sampler arguments
prefix = ""
params = []
initial_values = {}
prior = {}
for p in ["omega_m", "sigma8"]:
if p not in to_sample:
continue
if p in config['prior'].keys() and p in config['cosmology'].keys():
myprint(f'Adding {p} sampler')
params.append(f"cosmology.{p}")
initial_values[f"cosmology.{p}"] = float(config['cosmology'][p])
prior[f"cosmology.{p}"] = np.array(ast.literal_eval(config['prior'][p]))
else:
s = f'Could not find {p} prior and/or default, so will not sample'
warnings.warn(s, stacklevel=2)
# Remove for later to prevent duplication
to_sample.remove(p)
if len(params) > 0:
myprint('Adding cosmological parameter sampler')
all_sampler.append(borg.samplers.ModelParamsSampler(prefix, params, likelihood, chain, initial_values, prior))
# Model parameter sampler
prefix = ""
params = []
initial_values = {}
prior = {}
for p in to_sample:
if p in config['prior'].keys():
if p == 'sig_v':
myprint(f'Adding {p} sampler')
params.append(p)
initial_values[p] = float(config['model'][p])
if 'inf' in config['prior'][p]:
x = ast.literal_eval(config['prior'][p].replace('inf', '"inf"'))
prior[p] = np.array([xx if xx != 'inf' else np.inf for xx in x])
else:
prior[p] = np.array(ast.literal_eval(config['prior'][p]))
elif p == 'bulk_flow':
for i, d in enumerate(['_x', '_y', '_z']):
myprint(f'Adding {p}{d} sampler')
params.append(f'{p}{d}')
initial_values[f'{p}{d}'] = np.array(ast.literal_eval(config['model']['bulk_flow']))[i]
if 'inf' in config['prior'][p]:
x = ast.literal_eval(config['prior'][p].replace('inf', '"inf"'))
prior[f'{p}{d}'] = np.array([xx if xx != 'inf' else np.inf for xx in x])
else:
prior[f'{p}{d}'] = np.array(ast.literal_eval(config['prior'][p]))
else:
for i in range(nsamp):
myprint(f'Adding {p}{i} sampler')
params.append(f'{p}{i}')
initial_values[f'{p}{i}'] = float(config[f'sample_{i}'][p])
if 'inf' in config['prior'][p]:
x = ast.literal_eval(config['prior'][p].replace('inf', '"inf"'))
prior[f'{p}{i}'] = np.array([xx if xx != 'inf' else np.inf for xx in x])
else:
prior[f'{p}{i}'] = np.array(ast.literal_eval(config['prior'][p]))
else:
s = f'Could not find {p} prior, so will not sample'
warnings.warn(s, stacklevel=2)
if len(params) > 0:
myprint('Adding model parameter sampler')
all_sampler.append(borg.samplers.ModelParamsSampler(prefix, params, likelihood, fwd_param, initial_values, prior))
return all_sampler
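A small illustration of the blocking convention described above (a sketch; the [block_loop] contents here are abridged, not the full conf/basic_ini.ini):
import configparser
cfg = configparser.ConfigParser()
cfg.read_string("""
[block_loop]
hades_sampler_blocked = false
sigma8_sampler_blocked = true
alpha_sampler_blocked = false
""")
end = '_sampler_blocked'
to_sample = [k[:-len(end)] for (k, v) in cfg['block_loop'].items()
if k[-len(end):] == end and v.lower() == 'false']
print(to_sample)  # ['hades', 'alpha']: only explicitly un-blocked parameters get samplers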
@borg.registerLikelihoodBuilder
def build_likelihood(state: borg.likelihood.MarkovState, info: borg.likelihood.LikelihoodInfo) -> borg.likelihood.BaseLikelihood:
"""
Builds the likelihood object and returns it.
Args:
- state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
- info (borg.likelihood.LikelihoodInfo): The likelihood information.
Returns:
borg.likelihood.BaseLikelihood: The likelihood object.
"""
global likelihood, fwd_param
myprint("Building likelihood")
myprint(chain.getCosmoParams())
boxm = chain.getBoxModel()
likelihood = VelocityBORGLikelihood(chain, fwd_param, borg.getIniConfigurationFilename())
return likelihood

borg_velocity/mock_maker.py Normal file (141 lines added)

@ -0,0 +1,141 @@
import aquila_borg as borg
import numpy as np
from astropy.coordinates import SkyCoord
import astropy.units as apu
import configparser
import ast
import borg_velocity.utils as utils
from borg_velocity.utils import myprint
import borg_velocity.forwards as forwards
import borg_velocity.poisson_process as poisson_process
import borg_velocity.projection as projection
def radially_scatter(xtrue, frac_sig_x):
"""
Radially perturb true positions by a fractional amount
frac_sig_x, where the observer sits at x = (0, 0, 0).
Args:
:xtrue (np.ndarray): The true coordinates (Mpc/h), shape=(3,Nt)
:frac_sig_x (float): The fractional uncertainty in the radial direction
Returns:
:xmeas (np.ndarray): The observed coordinates (Mpc/h) of the tracers.
:sigma_mu (float): The uncertainty in the distance moduli of the tracers.
"""
# Convert to RA, Dec, Distance
rtrue = np.sqrt(np.sum(xtrue** 2, axis=0)) # Mpc/h
c = SkyCoord(x=xtrue[0], y=xtrue[1], z=xtrue[2], representation_type='cartesian')
RA = c.spherical.lon.degree
Dec = c.spherical.lat.degree
r_hat = np.array(SkyCoord(ra=RA*apu.deg, dec=Dec*apu.deg).cartesian.xyz)
# Add noise to radial direction
sigma_mu = 5. / np.log(10) * frac_sig_x
mutrue = 5 * np.log10(rtrue * 1.e6 / 10)
mumeas = mutrue + np.random.normal(size=len(mutrue)) * sigma_mu
rmeas = 10 ** (mumeas / 5.) * 10 / 1.e6
xmeas, ymeas, zmeas = rmeas[None,:] * r_hat
xmeas = np.array([xmeas, ymeas, zmeas])
return xmeas, sigma_mu
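The scatter above uses the standard distance-modulus relation; a short self-contained check of the round trip (illustrative values):
import numpy as np
rtrue = np.array([30.0, 80.0])              # distances in Mpc/h
mutrue = 5 * np.log10(rtrue * 1.e6 / 10)    # mu = 5 log10(d / 10 pc) with d in pc
rback = 10 ** (mutrue / 5.) * 10 / 1.e6     # invert back to Mpc/h
assert np.allclose(rback, rtrue)
# A fractional distance error frac_sig_x maps to sigma_mu = 5 / ln(10) * frac_sig_x magnitudes
print(5. / np.log(10) * 0.07)               # ~0.152 mag for frac_sig_rhMpc = 0.07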
def borg_mock(s_hat, state, fwd_model, ini_file):
config = configparser.ConfigParser()
config.read(ini_file)
nsamp = int(config['run']['nsamp'])
# Run BORG density field
output_density = np.zeros(fwd_model.getOutputBoxModel().N)
fwd_model.forwardModel_v2(s_hat)
fwd_model.getDensityFinal(output_density)
state["BORG_final_density"][:] = output_density
# Get growth rate
cosmo = utils.get_cosmopar(ini_file)
cosmology = borg.cosmo.Cosmology(cosmo)
af = float(config['model']['af'])
f = cosmology.gplus(af) # dD / da
f *= af / cosmology.d_plus(af) # f = dlnD / dlna
# Get velocity
smooth_R = float(config['model']['smooth_R'])
output_vel = forwards.dens2vel_linear(output_density, f,
fwd_model.getOutputBoxModel().L[0], smooth_R)
# Add bulk flow
bulk_flow = np.array(ast.literal_eval(config['mock']['bulk_flow']))
output_vel = output_vel + bulk_flow.reshape((3, 1, 1, 1))
# Sample positions according to bias model
bias_epsilon = float(config['model']['bias_epsilon'])
R_max = float(config['mock']['R_max'])
coord_true = [None] * nsamp
coord_meas = [None] * nsamp
sig_mu = [None] * nsamp
for i in range(nsamp):
frac_sig_x = float(config[f'sample_{i}']['frac_sig_rhMpc'])
alpha = float(config[f'sample_{i}']['alpha'])
Nt = int(config[f'sample_{i}']['Nt'])
phi = (1. + output_density + bias_epsilon) ** alpha
coord_true[i] = np.zeros((3,Nt))
coord_meas[i] = np.zeros((3,Nt))
nmade = 0
while (nmade < Nt):
ctrue = poisson_process.sample_3d(phi, Nt,
fwd_model.getOutputBoxModel().L[0], fwd_model.getOutputBoxModel().xmin)
cmeas, sig_mu[i] = radially_scatter(ctrue, frac_sig_x)
# Only use tracers which lie within R_max
r = np.sqrt(np.sum(ctrue**2, axis=0))
m = r < R_max
nnew = m.sum()
if nmade + nnew > Nt:
coord_true[i][:,nmade:] = ctrue[:,m][:,:Nt-nmade]
coord_meas[i][:,nmade:] = cmeas[:,m][:,:Nt-nmade]
else:
coord_true[i][:,nmade:nmade+nnew] = ctrue[:,m]
coord_meas[i][:,nmade:nmade+nnew] = cmeas[:,m]
nmade = min(nmade + m.sum(), Nt)
# Interpolate velocities to tracers
interp_order = int(config['model']['interp_order'])
vr_true = [None] * nsamp
for i in range(nsamp):
tracer_vel = projection.interp_field(
output_vel,
np.expand_dims(coord_true[i], axis=2),
fwd_model.getOutputBoxModel().L[0],
np.array(fwd_model.getOutputBoxModel().xmin),
interp_order
)
vr_true[i] = projection.project_radial(
tracer_vel,
np.expand_dims(coord_true[i], axis=2),
np.zeros(3,)
)
# Compute observed redshifts (including noise)
sig_v = float(config['model']['sig_v'])
cz_obs = [None] * nsamp
for i in range(nsamp):
rtrue = np.sqrt(np.sum(coord_true[i] ** 2, axis=0))  # per-tracer distance, shape (Nt,)
zco = utils.z_cos(rtrue, cosmo.omega_m)
cz_obs[i] = utils.speed_of_light * zco + (1 + zco) * vr_true[i]
cz_obs[i] += np.random.normal(size=cz_obs[i].shape) * sig_v # CHECK THIS LINE!!!!!
# Add observational systematic due to incorrect distance estimate
# \mu -> \mu + 5 log10(A), or equivalently d -> A d
for i in range(nsamp):
muA = float(config[f'sample_{i}']['muA'])
coord_meas[i] = coord_meas[i] * muA
return coord_true, coord_meas, sig_mu, vr_true, cz_obs

borg_velocity/poisson_process.py Normal file (208 lines added)

@ -0,0 +1,208 @@
import numpy as np
import scipy.stats
def sample_uniform(N: int, Nt: int, L: float, origin: np.ndarray):
"""
Generate Nt points uniformly sampled from a box of side length L.
Args:
- N (int): The number of grid points per side.
- Nt (int): The number of tracers to generate.
- L (float): The side-length of the box (Mpc/h).
- origin (np.ndarray): The coordinates of the origin of the box (Mpc/h).
Returns:
- xtrue (np.ndarray): The true coordinates (Mpc/h) of the tracers.
"""
h = 1
# Sample in grid units [0, N) so the rescaling below keeps points inside the box
xtrue = np.random.uniform(low=0.0, high=N, size=Nt)
ytrue = np.random.uniform(low=0.0, high=N, size=Nt)
ztrue = np.random.uniform(low=0.0, high=N, size=Nt)
# Convert to coordinates, and move relative to origin
xtrue *= L / N # Mpc/h
ytrue *= L / N # Mpc/h
ztrue *= L / N # Mpc/h
xtrue += origin[0]
ytrue += origin[1]
ztrue += origin[2]
xtrue = np.array([xtrue, ytrue, ztrue])
return xtrue
def draw_linear(nsamp: int, alpha: float, beta: float, u0: float, u1: float) -> np.ndarray:
"""
Draw nsamp samples from the probability distribution:
p(u) \propto alpha (u1 - u) + beta (u - u0)
for u0 <= u <= u1 and p(u) = 0 otherwise.
Args:
- nsamp (int): Number of samples to draw.
- alpha (float): The coefficient of (u1 - u) in p(u).
- beta (float): The coefficient of (u - u0) in p(u).
- u0 (float): The minimum allowed value of u.
- u1 (float): The maximum allowed value of u.
Return:
- np.ndarray: The samples from p(u).
"""
n = scipy.stats.uniform(0, 1).rvs(nsamp)
if isinstance(alpha, np.ndarray):
res = np.zeros(alpha.shape)
m = alpha != beta
res[m] = ((u1 - u0) * np.sqrt(n * (beta ** 2 - alpha ** 2) + alpha ** 2) - u1 * alpha + u0 * beta)[m] / (beta - alpha)[m]
res[~m] = (u0 + (u1 - u0) * n)[~m]
return res
else:
if alpha != beta:
return ((u1 - u0) * np.sqrt(n * (beta ** 2 - alpha ** 2) + alpha ** 2) - u1 * alpha + u0 * beta) / (beta - alpha)
else:
return u0 + (u1 - u0) * n
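A quick Monte-Carlo sanity check of the inverse-CDF formula above; the analytic mean of p(u) is u0 + (u1 - u0) * (alpha + 2*beta) / (3*(alpha + beta)) (illustrative values):
import numpy as np
from borg_velocity.poisson_process import draw_linear
alpha, beta, u0, u1 = 1.0, 3.0, 0.0, 1.0
samples = draw_linear(100000, alpha, beta, u0, u1)
mean_analytic = u0 + (u1 - u0) * (alpha + 2 * beta) / (3 * (alpha + beta))
print(samples.mean(), mean_analytic)  # both should be close to 7/12 ~ 0.583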
def periodic_index(index: np.ndarray, shape: tuple) -> np.ndarray:
"""
Apply periodic boundary conditions to an array of indices.
Args:
- index (np.ndarray): The indices to transform. Shape = (ndim, nvals).
- shape (tuple): The shape of the box used for periodic boundary conditions (N0, N1, ...)
Returns:
- new_index (np.ndarray): The values in index after applying periodic boundary conditions, such that for dimension i, the values are in the range [0, Ni)
"""
assert index.shape[0] == len(shape)
new_index = index.copy()
for i in range(len(shape)):
new_index[i,:] = np.mod(new_index[i,:], shape[i])
return new_index
def get_new_index(index: np.ndarray, shape: tuple, subscript: tuple) -> np.ndarray:
"""
Treating each entry of index as the (0,0,0) corner of a grid cell, find the index of the point offset by subscript (with periodic wrapping).
Args:
- index (np.ndarray): The indices to transform. Shape = (ndim, nvals).
- shape (tuple): The shape of the box used (N0, N1, ...).
- subscript (tuple): The coordinate to find, relative to the values given in index.
Returns:
- new_index (np.ndarray): The new index values.
"""
new_index = index.copy()
for i in range(len(subscript)):
new_index[i,:] += subscript[i]
new_index = periodic_index(new_index, shape)
return new_index
def sample_3d(phi: np.ndarray, Nt: int, L: float, origin: np.ndarray) -> np.ndarray:
"""
Sample Nt points, assuming that the points are drawn from a Poisson process given by the field phi.
phi gives the value of the field at the grid points, and we assume linear interpolation between points.
Args:
- phi (np.ndarray): The field which defines the mean of the Poisson process.
- Nt (int): The number of tracers to generate.
- L (float): The side-length of the box (Mpc/h).
- origin (np.ndarray): The coordinates of the origin of the box (Mpc/h).
Returns:
- xtrue (np.ndarray): The true coordinates (Mpc/h) of the tracers.
"""
N = phi.shape[0]
h = 1
# (1) Find which cell each point lives in
mean = phi + \
np.roll(phi, -1, axis=0) + \
np.roll(phi, -1, axis=1) + \
np.roll(phi, -1, axis=2) + \
np.roll(phi, -1, axis=(0,1)) + \
np.roll(phi, -1, axis=(0,2)) + \
np.roll(phi, -1, axis=(1,2)) + \
np.roll(phi, -1, axis=(0,1,2))
prob = mean.flatten() / mean.sum()
i = np.arange(prob.shape[0])
a1d = np.random.choice(i, Nt, p=prob)
a3d = np.array(np.unravel_index(a1d, (N,N,N)))
# (2) Find the x values
shape = (N, N, N)
alpha = np.zeros(Nt)
for subscript in [(0,0,0), (0,0,1), (0,1,0), (0,1,1)]:
idx = get_new_index(a3d, shape, subscript)
alpha += phi[idx[0,:], idx[1,:], idx[2,:]]
beta = np.zeros(Nt)
for subscript in [(1,0,0), (1,0,1), (1,1,0), (1,1,1)]:
idx = get_new_index(a3d, shape, subscript)
beta += phi[idx[0,:], idx[1,:], idx[2,:]]
u0 = a3d[0,:]
u1 = a3d[0,:] + 1
xtrue = draw_linear(Nt, alpha, beta, u0, u1)
# (3) Find the y values
shape = (N, N, N)
alpha = np.zeros(Nt)
for subscript in [(0,0,0), (0,0,1)]:
idx = get_new_index(a3d, shape, subscript)
alpha += phi[idx[0,:], idx[1,:], idx[2,:]] * (a3d[0,:] + 1 - xtrue)
for subscript in [(1,0,0), (1,0,1)]:
idx = get_new_index(a3d, shape, subscript)
alpha += phi[idx[0,:], idx[1,:], idx[2,:]] * (xtrue - a3d[0,:])
beta = np.zeros(Nt)
for subscript in [(0,1,0), (0,1,1)]:
idx = get_new_index(a3d, shape, subscript)
beta += phi[idx[0,:], idx[1,:], idx[2,:]] * (a3d[0,:] + 1 - xtrue)
for subscript in [(1,1,0), (1,1,1)]:
idx = get_new_index(a3d, shape, subscript)
beta += phi[idx[0,:], idx[1,:], idx[2,:]] * (xtrue - a3d[0,:])
u0 = a3d[1,:]
u1 = a3d[1,:] + 1
ytrue = draw_linear(Nt, alpha, beta, u0, u1)
# (4) Find the z values
xd = (xtrue - a3d[0,:]) # x1-x0=1 so xd = x - x0
yd = (ytrue - a3d[1,:]) # y1-y0=1 so yd = y - y0
ia = get_new_index(a3d, shape, (0,0,0))
ib = get_new_index(a3d, shape, (1,0,0))
phi00 = phi[ia[0,:], ia[1,:], ia[2,:]] * (1 - xd) + \
phi[ib[0,:], ib[1,:], ib[2,:]] * xd
ia = get_new_index(a3d, shape, (0,0,1))
ib = get_new_index(a3d, shape, (1,0,1))
phi01 = phi[ia[0,:], ia[1,:], ia[2,:]] * (1 - xd) + \
phi[ib[0,:], ib[1,:], ib[2,:]] * xd
ia = get_new_index(a3d, shape, (0,1,0))
ib = get_new_index(a3d, shape, (1,1,0))
phi10 = phi[ia[0,:], ia[1,:], ia[2,:]] * (1 - xd) + \
phi[ib[0,:], ib[1,:], ib[2,:]] * xd
ia = get_new_index(a3d, shape, (0,1,1))
ib = get_new_index(a3d, shape, (1,1,1))
phi11 = phi[ia[0,:], ia[1,:], ia[2,:]] * (1 - xd) + \
phi[ib[0,:], ib[1,:], ib[2,:]] * xd
alpha = phi00 * (1 - yd) + phi10 * yd # alpha = phi0
beta = phi01 * (1 - yd) + phi11 * yd # beta = phi1
u0 = a3d[2,:]
u1 = a3d[2,:] + 1
ztrue = draw_linear(Nt, alpha, beta, u0, u1)
# Convert to coordinates, and move relative to origin
xtrue *= L / N # Mpc/h
ytrue *= L / N # Mpc/h
ztrue *= L / N # Mpc/h
xtrue += origin[0]
ytrue += origin[1]
ztrue += origin[2]
xtrue = np.array([xtrue, ytrue, ztrue])
return xtrue

borg_velocity/projection.py

@ -2,6 +2,8 @@ import jax.numpy as jnp
import jax.scipy.ndimage
import jax
from functools import partial
from astropy.coordinates import SkyCoord
import astropy.units as apu
@partial(jax.jit, static_argnames=['order'])
def jit_map_coordinates(image: jnp.ndarray, coords:jnp.ndarray, order: int) -> jnp.ndarray:
@ -48,7 +50,7 @@ def interp_field(input_array: jnp.ndarray, coords: jnp.ndarray, L: float, origin
else:
def fun_to_vmap(arr):
return jax.scipy.ndimage.map_coordinates(arr, pos, order=order, mode='wrap')
if len(input_array.shape) == coords.shape[0]:
out_array = jit_map_coordinates(input_array, pos, order)
elif len(input_array.shape) == coords.shape[0]+1:
@ -77,4 +79,18 @@ def project_radial(vec: jnp.ndarray, coords: jnp.ndarray, origin: jnp.ndarray) -
x = x / jnp.expand_dims(r, axis=0)
vr = jnp.sum(x * vec, axis=0)
return vr
return vr
def get_radial_vectors(coord_meas):
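"""
Compute unit vectors along the line of sight to a set of tracers.
Args:
:coord_meas (np.ndarray): Cartesian coordinates of the tracers (Mpc/h), shape = (3, Nt)
Returns:
:r_hat (jnp.ndarray): Unit vectors along the lines of sight, shape = (1, 3, Nt)
"""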
c = SkyCoord(x=coord_meas[0], y=coord_meas[1], z=coord_meas[2],
representation_type='cartesian')
RA = c.spherical.lon.degree
DEC = c.spherical.lat.degree
# Get unit vectors along line of sight
r_hat = jnp.array(SkyCoord(ra=RA*apu.deg, dec=DEC*apu.deg).cartesian.xyz)
r_hat = jnp.expand_dims(r_hat, axis=0)
return r_hat

borg_velocity/utils.py

@ -1,6 +1,10 @@
import aquila_borg as borg
import configparser
import os
import symbolic_pofk.linear
from functools import partial
import jax
import jax.numpy as jnp
# Output stream management
cons = borg.console()
@ -33,6 +37,8 @@ def get_cosmopar(ini_file):
cpar.n_s = float(config['cosmology']['n_s'])
cpar.w = float(config['cosmology']['w'])
cpar.wprime = float(config['cosmology']['wprime'])
cpar.A_s = 1.e-9 * symbolic_pofk.linear.sigma8_to_As(
cpar.sigma8, cpar.omega_m, cpar.omega_b, cpar.h, cpar.n_s)
return cpar
@ -65,4 +71,23 @@ def get_action():
last_line = last_line[:idx].upper()
myprint(f'Running BORG mode: {last_line}')
return last_line
return last_line
speed_of_light = 299792 # km/s
@partial(jax.jit,)
def z_cos(r_hMpc: float, Omega_m: float) -> float:
"""
Convert distance in Mpc/h to cosmological redshift assuming LCDM
Args:
- r_hMpc (jnp.ndarray): Distances to tracers in Mpc/h
- Omega_m (float): The value of Omega_m at z=0
Returns:
- jnp.ndarray: The redshifts.
"""
Omega_L = 1. - Omega_m
q0 = Omega_m/2.0 - Omega_L
return (1.0 - jnp.sqrt(1 - 2*r_hMpc*100*(1 + q0)/speed_of_light))/(1.0 + q0)
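The expression above is the inversion of the second-order distance-redshift expansion d(z) = (c/H0) * (z - (1 + q0) * z^2 / 2); a small numeric check with illustrative values:
import numpy as np
c = 299792.0                           # km/s, matching speed_of_light above
Omega_m = 0.315
q0 = Omega_m / 2.0 - (1.0 - Omega_m)   # deceleration parameter for flat LCDM
r = 100.0                              # Mpc/h
z = (1.0 - np.sqrt(1 - 2 * r * 100 * (1 + q0) / c)) / (1.0 + q0)
d = (c / 100.0) * (z - (1 + q0) * z**2 / 2)
print(z, d)                            # z ~ 0.0336, and d recovers 100.0 Mpc/h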

conf/basic_ini.ini Normal file (91 lines added)

@ -0,0 +1,91 @@
[system]
VERBOSE_LEVEL = 2
N0 = 32
N1 = 32
N2 = 32
L0 = 500.0
L1 = 500.0
L2 = 500.0
corner0 = -250.0
corner1 = -250.0
corner2 = -250.0
NUM_MODES = 100
test_mode = true
seed_cpower = true
[block_loop]
hades_sampler_blocked = false
bias_sampler_blocked = true
nmean_sampler_blocked = true
sigma8_sampler_blocked = true
muA_sampler_blocked = true
omega_m_sampler_blocked = true
alpha_sampler_blocked = true
sig_v_sampler_blocked = true
bulk_flow_sampler_blocked = true
ares_heat = 1.0
[mcmc]
number_to_generate = 10
random_ic = false
init_random_scaling = 0.1
[model]
gravity = lpt
af = 1.0
ai = 0.05
nsteps = 20
smooth_R = 4
bias_epsilon = 1e-7
interp_order = 1
sig_v = 150.
R_lim = none
Nint_points = 201
Nsig = 10
[prior]
omega_m = [0.1, 0.8]
sigma8 = [0.1, 1.5]
muA = [0.5, 1.5]
alpha = [0.0, 10.0]
sig_v = [50.0, 200.0]
bulk_flow = [-200.0, 200.0]
[cosmology]
omega_r = 0
fnl = 0
omega_k = 0
omega_m = 0.315
omega_b = 0.049
omega_q = 0.685
h100 = 0.68
sigma8 = 0.81
n_s = 0.97
w = -1
wprime = 0
beta = 1.5
z0 = 0
[run]
run_type = mock
NCAT = 0
NSAMP = 2
[mock]
R_max = 100
bulk_flow = [0.0, 0.0, 0.0]
[python]
likelihood_path = /home/bartlett/fsigma8/borg_velocity/borg_velocity/likelihood.py
[sample_0]
Nt = 345
muA = 1.0
alpha = 1.4
frac_sig_rhMpc = 0.07
[sample_1]
Nt = 1682
muA = 1.0
alpha = 1.4
frac_sig_rhMpc = 0.07

figs/test_dens_ic.png Normal file (binary image, 71 KiB)

figs/test_vel.png Normal file (binary image, 96 KiB)

scripts/run_borg.sh Executable file (34 lines added)

@ -0,0 +1,34 @@
#!/bin/sh
# Modules
module purge
module restore myborg
module load cuda/12.3
# Environment
source /home/bartlett/.bashrc
source /home/bartlett/anaconda3/etc/profile.d/conda.sh
conda deactivate
conda activate borg_env
# Kill job if there are any errors
set -e
# Path variables
BORG=/home/bartlett/anaconda3/envs/borg_env/bin/hades_python
RUN_DIR=/data101/bartlett/fsigma8/borg_velocity/basic_run
mkdir -p $RUN_DIR
cd $RUN_DIR
# Create a custom file descriptor (3) for tracing
exec 3>$RUN_DIR/trace_file.txt
# Redirect trace output to the custom file descriptor
BASH_XTRACEFD="3"
set -x
# Just ICs
INI_FILE=/home/bartlett/fsigma8/borg_velocity/conf/basic_ini.ini
cp $INI_FILE basic_ini.ini
$BORG INIT basic_ini.ini

setup.py Normal file (22 lines added)

@ -0,0 +1,22 @@
from setuptools import setup
setup(
name='borg_velocity',
version='0.1.0',
description='BORG Inference with Velocity Tracers',
url='https://github.com/DeaglanBartlett/borg_velocity',
author='Deaglan Bartlett',
author_email='deaglan.bartlett@iap.fr',
license='MIT licence',
packages=['borg_velocity'],
install_requires=[
'numpy',
],
classifiers=[
'Development Status :: 1 - Planning',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: MIT License',
'Operating System :: POSIX :: Linux',
'Programming Language :: Python :: 3',
],
)


@ -0,0 +1,5 @@
Memory still allocated at the end: 0 MB
Statistics per context (name, allocated, freed, peak)
======================

tests/fft_wisdom Normal file (28 lines added)

@ -0,0 +1,28 @@
(fftw-3.3.10 fftw_wisdom #x3c273403 #x192df114 #x4d08727c #xe98e9b9d
(fftw_dft_thr_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #xe7f77f6a #xaf2de8b8 #xad19bc70 #x80305f29)
(fftw_codelet_n1bv_32_avx 0 #x10bdd #x10bdd #x0 #x6d197f20 #xfc9cbd23 #x91ddb367 #x208619cb)
(fftw_dft_thr_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #x1f2e97fe #x61895cd8 #x6073a2f5 #x6ada2663)
(fftw_dft_thr_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #x84033142 #x81339a41 #xb78a491e #x66362e05)
(fftw_rdft2_thr_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #x40ffeb6f #x4d232a35 #x49c61e65 #x4d75fa83)
(fftw_codelet_n1bv_32_avx 0 #x10bdd #x10bdd #x0 #x35d0d312 #x6b498ae1 #x1ddcffdc #x4a1a1998)
(fftw_dft_thr_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #xbffceb36 #x5b340e87 #xc2433c88 #x10e155b2)
(fftw_rdft2_thr_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #x7ec9785e #x02957b55 #xab1017dc #xdcd04ed7)
(fftw_codelet_r2cf_32 2 #x10bdd #x10bdd #x0 #xe5219ff5 #x7cc0cc2f #x9ce07377 #x12d27b02)
(fftw_dft_thr_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #xd78cc60c #x6e1210c6 #x5868829d #x70ada990)
(fftw_codelet_r2cf_32 2 #x10bdd #x10bdd #x0 #x68269cfc #xb89b69b3 #x4eaad8fa #x9807c679)
(fftw_rdft2_rank_geq2_register 0 #x10bdd #x10bdd #x0 #x7446ec55 #x3f800a5f #xba25afcf #xc0e9d5c1)
(fftw_dft_thr_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #x0ac209ed #x737616a2 #xc31f0ad8 #x13c3716f)
(fftw_dft_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #x4b54e3ca #x4f94ebf3 #x244f4da3 #x2412ca79)
(fftw_rdft2_rank_geq2_register 0 #x10bdd #x10bdd #x0 #x68900aea #xb640ce9e #xcd3b0e06 #x8170fa63)
(fftw_dft_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #x404fdd72 #x2323d034 #xc860c577 #x4779492a)
(fftw_rdft2_rank_geq2_register 0 #x10bdd #x10bdd #x0 #x3c2e2a1a #x07c08954 #x35c337d9 #x80864862)
(fftw_codelet_n1fv_32_sse2 0 #x10bdd #x10bdd #x0 #xe61c7c8d #x2cea019e #x8489a633 #x8d6543c6)
(fftw_rdft2_rank_geq2_register 0 #x10bdd #x10bdd #x0 #x962543ac #xb000f636 #xb27fc586 #xd4a83bb7)
(fftw_codelet_n1fv_32_avx 0 #x10bdd #x10bdd #x0 #x94cb38f8 #xed5987e0 #xa3d4151a #xeb412d04)
(fftw_codelet_r2cb_32 2 #x10bdd #x10bdd #x0 #x92bf92d5 #xdc456f1e #x5a32a424 #xe1f76e14)
(fftw_codelet_n1fv_32_avx 0 #x10bdd #x10bdd #x0 #xb5d7d23e #x26089494 #x55133ef3 #x8ac38174)
(fftw_dft_thr_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #xe0a3b250 #xab7e7c07 #xf0935dde #x1568a95f)
(fftw_codelet_r2cb_32 2 #x10bdd #x10bdd #x0 #x4e6e3714 #xebce55aa #x0ede5253 #x4faf4524)
(fftw_codelet_n1bv_32_sse2 0 #x10bdd #x10bdd #x0 #x902cd310 #xa659999d #x6fde2637 #xb23e4fd2)
(fftw_dft_thr_vrank_geq1_register 0 #x10bdd #x10bdd #x0 #x21347a5d #x286e0d10 #xabf9ff02 #xccdf80a5)
)

tests/test_forward.py Normal file (84 lines added)

@ -0,0 +1,84 @@
import numpy as np
import aquila_borg as borg
import configparser
import matplotlib.pyplot as plt
import borg_velocity.likelihood as likelihood
import borg_velocity.utils as utils
import borg_velocity.forwards as forwards
ini_file = '../conf/basic_ini.ini'
# Setup config
config = configparser.ConfigParser()
config.read(ini_file)
# Cosmology
cosmo = utils.get_cosmopar(ini_file)
# Input box
box_in = borg.forward.BoxModel()
box_in.L = (float(config['system']['L0']), float(config['system']['L1']), float(config['system']['L2']))
box_in.N = (int(config['system']['N0']), int(config['system']['N1']), int(config['system']['N2']))
box_in.xmin = (float(config['system']['corner0']), float(config['system']['corner1']), float(config['system']['corner2']))
# Load forward model
fwd_model = likelihood.build_gravity_model(None, box_in, ini_file=ini_file)
fwd_model.setCosmoParams(cosmo)
# Make some initial conditions
s_hat = np.fft.rfftn(np.random.randn(*box_in.N)) / box_in.Ntot ** (0.5)
# Get the real version of this
s_real = np.fft.irfftn(s_hat, norm="ortho")
# Run BORG density field
output_density = np.zeros(box_in.N)
fwd_model.forwardModel_v2(s_hat)
fwd_model.getDensityFinal(output_density)
# Get growth rate
cosmology = borg.cosmo.Cosmology(cosmo)
af = float(config['model']['af'])
f = cosmology.gplus(af) # dD / da
f *= af / cosmology.d_plus(af) # f = dlnD / dlna
# Get velocity
smooth_R = float(config['model']['smooth_R'])
output_vel = forwards.dens2vel_linear(output_density, f, box_in.L[0], smooth_R)
print(output_vel.shape)
# Plot the IC and dens fields
fig, axs = plt.subplots(2, 3, figsize=(15,10))
for i, field in enumerate([s_real, np.log10(2 + output_density)]):
vmin, vmax = field.min(), field.max()
pc = axs[i,0].pcolor(field[box_in.N[0]//2], vmin=vmin, vmax=vmax)
fig.colorbar(pc, ax=axs[i,0])
pc = axs[i,1].pcolor(field[:,box_in.N[1]//2,:], vmin=vmin, vmax=vmax)
fig.colorbar(pc, ax=axs[i,1])
pc = axs[i,2].pcolor(field[:,:,box_in.N[2]//2], vmin=vmin, vmax=vmax)
fig.colorbar(pc, ax=axs[i,2])
axs[0,1].set_title('Initial Conditions')
axs[1,1].set_title(r'$\log_{10} (2 + \delta)$')
for ax in axs.flatten():
ax.set_aspect('equal')
fig.savefig('../figs/test_dens_ic.png')
# Plot the velocity fields
fig, axs = plt.subplots(3, 3, figsize=(15,15))
vmin, vmax = output_vel.min(), output_vel.max()
for i in range(3):
pc = axs[i,0].pcolor(output_vel[i,box_in.N[0]//2], vmin=vmin, vmax=vmax)
fig.colorbar(pc, ax=axs[i,0])
pc = axs[i,1].pcolor(output_vel[i,:,box_in.N[1]//2,:], vmin=vmin, vmax=vmax)
fig.colorbar(pc, ax=axs[i,1])
pc = axs[i,2].pcolor(output_vel[i,:,:,box_in.N[2]//2], vmin=vmin, vmax=vmax)
fig.colorbar(pc, ax=axs[i,2])
axs[0,1].set_title(r'$v_x$')
axs[1,1].set_title(r'$v_y$')
axs[2,1].set_title(r'$v_z$')
for ax in axs.flatten():
ax.set_aspect('equal')
fig.savefig('../figs/test_vel.png')

tests/timing_stats_0.txt Normal file (7 lines added)

@ -0,0 +1,7 @@
ARES version c6de4f62faad20ede0bb40aa3678551dceee637b modules
Cumulative timing spent in different context
--------------------------------------------
Context, Total time (seconds)
[/build/jenkins/miniconda3/envs/builder/conda-bld/aquila_borg_1686906696789/work/libLSS/physics/forwards/adapt_generic_bias.cpp]void {anonymous}::bias_registrator() 1 0.000723925