Control over Malmquist

This commit is contained in:
rstiskalek 2024-10-08 00:14:20 +01:00
parent 42f4044796
commit 5c9a8fc5a6

View file

@ -356,7 +356,8 @@ def e2_distmod_TFR(e2_mag, e2_eta, eta, b, c, e_mu_intrinsic):
def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std, def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std,
c_mean, c_std, alpha_min, alpha_max, sample_alpha, c_mean, c_std, alpha_min, alpha_max, sample_alpha,
a_dipole_mean, a_dipole_std, sample_a_dipole, name): a_dipole_mean, a_dipole_std, sample_a_dipole,
sample_curvature, name):
"""Sample Tully-Fisher calibration parameters.""" """Sample Tully-Fisher calibration parameters."""
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max)) e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
a = sample(f"a_{name}", Normal(a_mean, a_std)) a = sample(f"a_{name}", Normal(a_mean, a_std))
@ -367,7 +368,11 @@ def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std,
ax, ay, az = 0.0, 0.0, 0.0 ax, ay, az = 0.0, 0.0, 0.0
b = sample(f"b_{name}", Normal(b_mean, b_std)) b = sample(f"b_{name}", Normal(b_mean, b_std))
c = sample(f"c_{name}", Normal(c_mean, c_std))
if sample_curvature:
c = sample(f"c_{name}", Normal(c_mean, c_std))
else:
c = 0.
alpha = sample_alpha_bias(name, alpha_min, alpha_max, sample_alpha) alpha = sample_alpha_bias(name, alpha_min, alpha_max, sample_alpha)
@ -513,12 +518,18 @@ class PV_LogLikelihood(BaseFlowValidationModel):
Whether to directly sample the distance without numerical Whether to directly sample the distance without numerical
marginalisation. in which case the tracers can be coupled by a marginalisation. in which case the tracers can be coupled by a
covariance matrix. By default `False`. covariance matrix. By default `False`.
with_homogeneous_malmquist : bool, optional
Whether to include the homogeneous Malmquist bias. By default `True`.
with_inhomogeneous_malmquist : bool, optional
Whether to include the inhomogeneous Malmquist bias. By default `True`.
""" """
def __init__(self, los_density, los_velocity, RA, dec, z_obs, e_zobs, def __init__(self, los_density, los_velocity, RA, dec, z_obs, e_zobs,
calibration_params, abs_calibration_params, mag_selection, calibration_params, abs_calibration_params, mag_selection,
r_xrange, Omega_m, kind, name, void_kwargs=None, r_xrange, Omega_m, kind, name, void_kwargs=None,
wo_num_dist_marginalisation=False): wo_num_dist_marginalisation=False,
with_homogeneous_malmquist=True,
with_inhomogeneous_malmquist=True):
if e_zobs is not None: if e_zobs is not None:
e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2) e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2)
else: else:
@ -557,6 +568,8 @@ class PV_LogLikelihood(BaseFlowValidationModel):
self.name = name self.name = name
self.Omega_m = Omega_m self.Omega_m = Omega_m
self.wo_num_dist_marginalisation = wo_num_dist_marginalisation self.wo_num_dist_marginalisation = wo_num_dist_marginalisation
self.with_homogeneous_malmquist = with_homogeneous_malmquist
self.with_inhomogeneous_malmquist = with_inhomogeneous_malmquist
self.norm = - self.ndata * jnp.log(self.num_sims) self.norm = - self.ndata * jnp.log(self.num_sims)
if mag_selection is not None: if mag_selection is not None:
@ -771,9 +784,14 @@ class PV_LogLikelihood(BaseFlowValidationModel):
"marginalising the true distance.") "marginalising the true distance.")
# Calculate p(r) (Malmquist bias). Shape is (ndata, nxrange) # Calculate p(r) (Malmquist bias). Shape is (ndata, nxrange)
log_ptilde = log_ptilde_wo_bias( if self.with_homogeneous_malmquist:
self.mu_xrange[None, :], mu[:, None], e2_mu[:, None], log_ptilde = log_ptilde_wo_bias(
self.log_r2_xrange[None, :]) self.mu_xrange[None, :], mu[:, None], e2_mu[:, None],
self.log_r2_xrange[None, :])
else:
log_ptilde = log_ptilde_wo_bias(
self.mu_xrange[None, :], mu[:, None], e2_mu[:, None],
0.)
if self.is_void_data: if self.is_void_data:
rLG = field_calibration_params["rLG"] rLG = field_calibration_params["rLG"]
@ -785,7 +803,9 @@ class PV_LogLikelihood(BaseFlowValidationModel):
# Inhomogeneous Malmquist bias. Shape: (nsims, ndata, nxrange) # Inhomogeneous Malmquist bias. Shape: (nsims, ndata, nxrange)
alpha = distmod_params["alpha"] alpha = distmod_params["alpha"]
log_ptilde = log_ptilde[None, ...] + alpha * log_los_density log_ptilde = log_ptilde[None, ...]
if self.with_inhomogeneous_malmquist:
log_ptilde += alpha * log_los_density
ptilde = jnp.exp(log_ptilde) ptilde = jnp.exp(log_ptilde)
# Normalization of p(r). Shape: (nsims, ndata) # Normalization of p(r). Shape: (nsims, ndata)
@ -840,11 +860,13 @@ class PV_LogLikelihood(BaseFlowValidationModel):
alpha = distmod_params["alpha"] alpha = distmod_params["alpha"]
# Normalisation of p(mu), shape is `(n_sims, n_data, n_rad)` # Normalisation of p(mu), shape is `(n_sims, n_data, n_rad)`
pnorm = ( pnorm = normal_logpdf(
+ self.log_r2_xrange[None, None, :] self.mu_xrange[None, :], mu[:, None], e_mu[:, None])[None, ...]
+ alpha * log_los_density_grid if self.with_homogeneous_malmquist:
+ normal_logpdf( pnorm += self.log_r2_xrange[None, None, :]
self.mu_xrange[None, :], mu[:, None], e_mu[:, None])[None, ...]) # noqa if self.with_inhomogeneous_malmquist:
pnorm += alpha * log_los_density_grid
pnorm = jnp.exp(pnorm) pnorm = jnp.exp(pnorm)
# Now integrate over the radial steps. Shape is `(nsims, ndata)`. # Now integrate over the radial steps. Shape is `(nsims, ndata)`.
# No Jacobian here because I integrate over distance, not the # No Jacobian here because I integrate over distance, not the
@ -855,11 +877,12 @@ class PV_LogLikelihood(BaseFlowValidationModel):
jac = jnp.abs(distmod2dist_gradient(mu_true, self.Omega_m)) jac = jnp.abs(distmod2dist_gradient(mu_true, self.Omega_m))
# Calculate unnormalized log p(mu). Shape is (nsims, ndata) # Calculate unnormalized log p(mu). Shape is (nsims, ndata)
ll = ( ll = 0.
+ jnp.log(jac)[None, :] if self.with_homogeneous_malmquist:
+ (2 * jnp.log(r_true) - self.log_r2_xrange_mean)[None, :] ll = (+ jnp.log(jac)
+ alpha * log_density + (2 * jnp.log(r_true) - self.log_r2_xrange_mean))
) if self.with_inhomogeneous_malmquist:
ll += alpha * log_density
# Subtract the normalization. Shape remains (nsims, ndata) # Subtract the normalization. Shape remains (nsims, ndata)
ll -= jnp.log(pnorm) ll -= jnp.log(pnorm)