Only update bulk flow prior when cosmology changed
|
@ -189,6 +189,8 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
|
|||
self.fwd.setCosmoParams(cpar)
|
||||
self.fwd_param.setCosmoParams(cpar)
|
||||
|
||||
self.sigma_bulk = get_sigma_bulk(self.L[0], self.fwd.getCosmoParams())
|
||||
|
||||
|
||||
def generateMBData(self) -> None:
|
||||
"""
|
||||
|
@ -355,8 +357,7 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
|
|||
# self.lkl_ind[i] = temp_lkl_ind.copy()
|
||||
|
||||
# Add in bulk flow prior
|
||||
# sigma_bulk = get_sigma_bulk(self.L[0], self.fwd.getCosmoParams())
|
||||
# lkl += jnp.sum(0.5 * jnp.log(2 * np.pi) + jnp.log(sigma_bulk) + self.bulk_flow ** 2 / 2. / sigma_bulk ** 2)
|
||||
lkl += jnp.sum(0.5 * jnp.log(2 * np.pi) + jnp.log(self.sigma_bulk) + self.bulk_flow ** 2 / 2. / self.sigma_bulk ** 2)
|
||||
|
||||
# lkl = jnp.clip(lkl, -self.bignum, self.bignum)
|
||||
lkl = lax.cond(
|
||||
|
|
|
@ -523,13 +523,19 @@ class TransformedBlackJaxBiasSampler(borg.samplers.PyBaseSampler):
|
|||
adaptation_info_fn=get_info
|
||||
)
|
||||
rng_key, warmup_key, sample_key = jax.random.split(self.rng_key, 3)
|
||||
(state, self.parameters), info = warmup.run(warmup_key, self.y, num_steps=self.warmup_nsteps)
|
||||
(bj_state, self.parameters), info = warmup.run(warmup_key, self.y, num_steps=self.warmup_nsteps)
|
||||
# x = info.state.position
|
||||
self.y[:] = state.position
|
||||
self.y[:] = bj_state.position
|
||||
|
||||
# Save results to the borg state object
|
||||
state["inverse_mass_matrix"][:] = self.parameters['inverse_mass_matrix']
|
||||
state["step_size"] = float(self.parameters['step_size'])
|
||||
|
||||
# Make and save summary figures
|
||||
dirname = os.getcwd()
|
||||
trace_plot(self.attributes, info.state.position, f'{dirname}/tuning_params.png', ncol=4)
|
||||
trace_plot(self.attributes, info.adaptation_state.inverse_mass_matrix, f'{dirname}/tuning_inverse_mass_matrix.png', ncol=4)
|
||||
trace_plot(['stepsize'], info.adaptation_state.step_size, f'{dirname}/tuning_step_size.png', ncol=4)
|
||||
|
||||
|
||||
nuts = blackjax.nuts(logdensity_fn, **self.parameters)
|
||||
|
|
|
@ -29,7 +29,7 @@ ares_heat = 1.0
|
|||
|
||||
[mcmc]
|
||||
number_to_generate = 15000
|
||||
warmup_model = 1000
|
||||
warmup_model = 500
|
||||
warmup_cosmo = 0
|
||||
random_ic = false
|
||||
init_random_scaling = 0.1
|
||||
|
@ -42,12 +42,12 @@ max_timesteps = 50
|
|||
mixing = 1
|
||||
|
||||
[sampling]
|
||||
algorithm = BlackJax
|
||||
algorithm = TransformedBlackJax
|
||||
epsilon = 0.005
|
||||
Nsteps = 20
|
||||
refresh = 0.1
|
||||
rng_seed = 1
|
||||
warmup_nsteps = 200
|
||||
warmup_nsteps = 400
|
||||
warmup_target_acceptance_rate = 0.7
|
||||
mvlam_mua = 0.01
|
||||
mvlam_alpha = 0.5
|
||||
|
|
Before Width: | Height: | Size: 24 KiB After Width: | Height: | Size: 24 KiB |
Before Width: | Height: | Size: 170 KiB After Width: | Height: | Size: 172 KiB |
Before Width: | Height: | Size: 100 KiB After Width: | Height: | Size: 100 KiB |
Before Width: | Height: | Size: 53 KiB After Width: | Height: | Size: 49 KiB |
Before Width: | Height: | Size: 94 KiB After Width: | Height: | Size: 84 KiB |
Before Width: | Height: | Size: 82 KiB After Width: | Height: | Size: 80 KiB |
BIN
figs/trace.png
Before Width: | Height: | Size: 359 KiB After Width: | Height: | Size: 334 KiB |
|
@ -16,7 +16,7 @@ set -e
|
|||
|
||||
# Path variables
|
||||
BORG=/data101/bartlett/build_borg/tools/hades_python/hades_python
|
||||
RUN_DIR=/data101/bartlett/fsigma8/borg_velocity/blackjax_model_ic_v2
|
||||
RUN_DIR=/data101/bartlett/fsigma8/borg_velocity/blackjax_model_ic_bf_prior
|
||||
|
||||
mkdir -p $RUN_DIR
|
||||
cd $RUN_DIR
|
||||
|
@ -30,6 +30,6 @@ set -x
|
|||
|
||||
# Just ICs
|
||||
INI_FILE=/home/bartlett/fsigma8/borg_velocity/conf/supranta_ini.ini
|
||||
cp $INI_FILE ini_file.ini
|
||||
$BORG INIT ini_file.ini
|
||||
# $BORG RESUME ini_file.ini
|
||||
# cp $INI_FILE ini_file.ini
|
||||
# $BORG INIT ini_file.ini
|
||||
$BORG RESUME ini_file.ini
|
|
@ -1,5 +1,5 @@
|
|||
#!/bin/bash
|
||||
#SBATCH --job-name=blackjax_model_ic_v2
|
||||
#SBATCH --job-name=blackjax_model_ic_bf_prior
|
||||
#SBATCH --nodes=1
|
||||
#SBATCH --exclusive
|
||||
#SBATCH --ntasks=40
|
||||
|
@ -28,7 +28,7 @@ set -e
|
|||
|
||||
# Path variables
|
||||
BORG=/data101/bartlett/build_borg/tools/hades_python/hades_python
|
||||
RUN_DIR=/data101/bartlett/fsigma8/borg_velocity/blackjax_model_ic_v2
|
||||
RUN_DIR=/data101/bartlett/fsigma8/borg_velocity/blackjax_model_ic_bf_prior
|
||||
|
||||
mkdir -p $RUN_DIR
|
||||
cd $RUN_DIR
|
||||
|
@ -42,9 +42,9 @@ set -x
|
|||
|
||||
# Run BORG
|
||||
INI_FILE=/home/bartlett/fsigma8/borg_velocity/conf/supranta_ini.ini
|
||||
cp $INI_FILE ini_file.ini
|
||||
$BORG INIT ini_file.ini
|
||||
# $BORG RESUME ini_file.ini
|
||||
# cp $INI_FILE ini_file.ini
|
||||
# $BORG INIT ini_file.ini
|
||||
$BORG RESUME ini_file.ini
|
||||
|
||||
conda deactivate
|
||||
|
||||
|
|