Only update bulk flow prior when cosmology changed

This commit is contained in:
Deaglan Bartlett 2024-11-08 10:38:36 +01:00
parent cbcd1fce0f
commit 9c07a87b8b
13 changed files with 60 additions and 60 deletions

View file

@ -189,6 +189,8 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
self.fwd.setCosmoParams(cpar)
self.fwd_param.setCosmoParams(cpar)
self.sigma_bulk = get_sigma_bulk(self.L[0], self.fwd.getCosmoParams())
def generateMBData(self) -> None:
"""
@ -355,8 +357,7 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
# self.lkl_ind[i] = temp_lkl_ind.copy()
# Add in bulk flow prior
# sigma_bulk = get_sigma_bulk(self.L[0], self.fwd.getCosmoParams())
# lkl += jnp.sum(0.5 * jnp.log(2 * np.pi) + jnp.log(sigma_bulk) + self.bulk_flow ** 2 / 2. / sigma_bulk ** 2)
lkl += jnp.sum(0.5 * jnp.log(2 * np.pi) + jnp.log(self.sigma_bulk) + self.bulk_flow ** 2 / 2. / self.sigma_bulk ** 2)
# lkl = jnp.clip(lkl, -self.bignum, self.bignum)
lkl = lax.cond(

View file

@ -523,13 +523,19 @@ class TransformedBlackJaxBiasSampler(borg.samplers.PyBaseSampler):
adaptation_info_fn=get_info
)
rng_key, warmup_key, sample_key = jax.random.split(self.rng_key, 3)
(state, self.parameters), info = warmup.run(warmup_key, self.y, num_steps=self.warmup_nsteps)
(bj_state, self.parameters), info = warmup.run(warmup_key, self.y, num_steps=self.warmup_nsteps)
# x = info.state.position
self.y[:] = state.position
self.y[:] = bj_state.position
# Save results to the borg state object
state["inverse_mass_matrix"][:] = self.parameters['inverse_mass_matrix']
state["step_size"] = float(self.parameters['step_size'])
# Make and save summary figures
dirname = os.getcwd()
trace_plot(self.attributes, info.state.position, f'{dirname}/tuning_params.png', ncol=4)
trace_plot(self.attributes, info.adaptation_state.inverse_mass_matrix, f'{dirname}/tuning_inverse_mass_matrix.png', ncol=4)
trace_plot(['stepsize'], info.adaptation_state.step_size, f'{dirname}/tuning_step_size.png', ncol=4)
nuts = blackjax.nuts(logdensity_fn, **self.parameters)

View file

@ -29,7 +29,7 @@ ares_heat = 1.0
[mcmc]
number_to_generate = 15000
warmup_model = 1000
warmup_model = 500
warmup_cosmo = 0
random_ic = false
init_random_scaling = 0.1
@ -42,12 +42,12 @@ max_timesteps = 50
mixing = 1
[sampling]
algorithm = BlackJax
algorithm = TransformedBlackJax
epsilon = 0.005
Nsteps = 20
refresh = 0.1
rng_seed = 1
warmup_nsteps = 200
warmup_nsteps = 400
warmup_target_acceptance_rate = 0.7
mvlam_mua = 0.01
mvlam_alpha = 0.5

Binary file not shown.

Before

Width:  |  Height:  |  Size: 24 KiB

After

Width:  |  Height:  |  Size: 24 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 170 KiB

After

Width:  |  Height:  |  Size: 172 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 100 KiB

After

Width:  |  Height:  |  Size: 100 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 53 KiB

After

Width:  |  Height:  |  Size: 49 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 94 KiB

After

Width:  |  Height:  |  Size: 84 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 82 KiB

After

Width:  |  Height:  |  Size: 80 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 359 KiB

After

Width:  |  Height:  |  Size: 334 KiB

File diff suppressed because one or more lines are too long

View file

@ -16,7 +16,7 @@ set -e
# Path variables
BORG=/data101/bartlett/build_borg/tools/hades_python/hades_python
RUN_DIR=/data101/bartlett/fsigma8/borg_velocity/blackjax_model_ic_v2
RUN_DIR=/data101/bartlett/fsigma8/borg_velocity/blackjax_model_ic_bf_prior
mkdir -p $RUN_DIR
cd $RUN_DIR
@ -30,6 +30,6 @@ set -x
# Just ICs
INI_FILE=/home/bartlett/fsigma8/borg_velocity/conf/supranta_ini.ini
cp $INI_FILE ini_file.ini
$BORG INIT ini_file.ini
# $BORG RESUME ini_file.ini
# cp $INI_FILE ini_file.ini
# $BORG INIT ini_file.ini
$BORG RESUME ini_file.ini

View file

@ -1,5 +1,5 @@
#!/bin/bash
#SBATCH --job-name=blackjax_model_ic_v2
#SBATCH --job-name=blackjax_model_ic_bf_prior
#SBATCH --nodes=1
#SBATCH --exclusive
#SBATCH --ntasks=40
@ -28,7 +28,7 @@ set -e
# Path variables
BORG=/data101/bartlett/build_borg/tools/hades_python/hades_python
RUN_DIR=/data101/bartlett/fsigma8/borg_velocity/blackjax_model_ic_v2
RUN_DIR=/data101/bartlett/fsigma8/borg_velocity/blackjax_model_ic_bf_prior
mkdir -p $RUN_DIR
cd $RUN_DIR
@ -42,9 +42,9 @@ set -x
# Run BORG
INI_FILE=/home/bartlett/fsigma8/borg_velocity/conf/supranta_ini.ini
cp $INI_FILE ini_file.ini
$BORG INIT ini_file.ini
# $BORG RESUME ini_file.ini
# cp $INI_FILE ini_file.ini
# $BORG INIT ini_file.ini
$BORG RESUME ini_file.ini
conda deactivate