mirror of
https://github.com/Richard-Sti/csiborgtools_public.git
synced 2025-05-13 14:11:11 +00:00
* Rename nb * Add Carrick 2MTF mocks * Add more 2MTF mock support * Add mocks generator * Add mock gen nb * Control over MAlmquist * Control over Malmquist * Update imports * Update script * Add H0 to TFR mocks * Clear up saving * Add h sampling * Update saving * Update mocks * Add calibration to mocks * More support to read in absmag * Add absmag * Update script
73 KiB
73 KiB
In [2]:
# Copyright (C) 2024 Richard Stiskalek
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import numpy as np
import matplotlib.pyplot as plt
from tqdm import trange
from joblib import dump
from h5py import File
import csiborgtools
# IPython magics: draw matplotlib figures inline, and reload all imported
# modules (e.g. an edited `csiborgtools`) before executing each cell.
%matplotlib inline
%load_ext autoreload
%autoreload 2
In [3]:
# Sky coordinates of the 2MTF galaxies from the peculiar-velocity compilation.
# They are passed to `csiborgtools.flow.mock_Carrick2MTF` when generating mocks.
fpath = "/mnt/extraspace/rstiskalek/catalogs/PV_compilation.hdf5"
with File(fpath, 'r') as catalogue:
    RA_2MTF = catalogue["2MTF/RA"][...]
    DEC_2MTF = catalogue["2MTF/DEC"][...]
In [123]:
# Generate Carrick+2015 2MTF mocks and save each one (HDF5) together with its
# ground truths (joblib pickle).
import os
from glob import glob

dumpdir = "/mnt/extraspace/rstiskalek/csiborg_postprocessing/flow_mock"
# Start from a clean slate so stale mocks from an earlier run (e.g. one with a
# larger `kmax`) are not mixed with the new ones. Unlike the former
# `!rm <dumpdir>/*` shell escape, this is plain Python, creates the directory
# if needed, and only removes files this cell itself writes.
os.makedirs(dumpdir, exist_ok=True)
for stale in glob(os.path.join(dumpdir, "Carrick2MTFmock_seed*")):
    os.remove(stale)

reader = csiborgtools.read.Carrick2015Field()
velocity = reader.velocity_field()
boxsize = csiborgtools.simname2boxsize("Carrick2015")

# Mock generation parameters.
mean_mag = 10.31
std_mag = 0.83
h = 0.7
nrepeat_calibration = 1
kmax = 1
seed = 53

for k in trange(kmax):
    # Mock `k` uses seed `seed + k`. Note that the output file is indexed by
    # `k`, not by the seed value.
    mock, truths = csiborgtools.flow.mock_Carrick2MTF(
        velocity, boxsize, RA_2MTF, DEC_2MTF,
        mean_mag=mean_mag, std_mag=std_mag, h=h,
        nrepeat_calibration=nrepeat_calibration,
        calibration_max_percentile=10, calibration_rand_fraction=1.,
        seed=seed + k, verbose=True)

    fname = os.path.join(dumpdir, f"Carrick2MTFmock_seed{k}.hdf5")
    print("Saving to ", fname)
    with File(fname, 'w') as f:
        for key, value in mock.items():
            f.create_dataset(key, data=value)

    fname_truths = fname.replace(".hdf5", "_truths.pkl")
    dump(truths, fname_truths)
In [124]:
# Boolean mask of finite entries in the first row of `mu_calibration`.
m = np.isfinite(mock["mu_calibration"][0, :])
In [125]:
# Histogram of the finite calibration distance moduli in the first row.
fig, ax = plt.subplots()
ax.hist(mock["mu_calibration"][0, m], bins="auto")
plt.show()
In [126]:
# True vs. calibration distance modulus (first row of `mu_calibration`).
# NaN (missing) calibration entries are skipped by `scatter`. The dashed
# 1:1 line makes any bias or scatter of the calibration visible.
fig, ax = plt.subplots()
ax.scatter(mock["mu_true"], mock["mu_calibration"][0])
lims = [np.nanmin(mock["mu_true"]), np.nanmax(mock["mu_true"])]
ax.plot(lims, lims, c="red", ls="--", label="1:1")
ax.set_xlabel(r"$\mu_{\rm true}$")
ax.set_ylabel(r"$\mu_{\rm calibration}$")
ax.legend()
plt.show()
In [ ]:
In [127]:
# Number of finite calibration distance moduli in each row.
np.isfinite(mock["mu_calibration"]).sum(axis=1)
Out[127]:
In [10]:
# Histogram of every finite calibration distance modulus (boolean-mask
# indexing flattens all rows into a single 1D array).
m = np.isfinite(mock["mu_calibration"])
fig, ax = plt.subplots()
ax.hist(mock["mu_calibration"][m], bins="auto")
plt.show()
In [101]:
# 10th percentile of the true distance moduli (same as percentile 10).
np.quantile(mock["mu_true"], 0.1)
Out[101]:
In [75]:
# Distribution of the true distance moduli of the mock.
fig, ax = plt.subplots()
ax.hist(mock["mu_true"], bins="auto")
plt.show()
In [57]:
from jax import numpy as jnp
from jax.scipy.special import logsumexp
def normal_logpdf(x, loc, scale):
    """
    Log of the normal probability density function, evaluated element-wise.

    Parameters
    ----------
    x, loc, scale : array-like
        Points, means and standard deviations; broadcast against each other.

    Returns
    -------
    jax array
        log N(x | loc, scale).
    """
    z = (x - loc) / scale
    return -0.5 * z**2 - jnp.log(scale) - 0.5 * jnp.log(2 * jnp.pi)
In [44]:
# Toy calibration set built from the mock: two stacked copies of the
# calibration arrays, where row 0 keeps indices < 100 along axis 1 and row 1
# keeps indices < 50. The rest is set to NaN to exercise missing calibrators.
mu_true = np.copy(mock["mu_true"])
e_mu = np.copy(mock["e_mu_calibration"])
mu_calibration = np.stack([np.copy(mock["mu_calibration"])] * 2)
e_mu_calibration = np.stack([e_mu] * 2)
for row, ncut in enumerate([100, 50]):
    mu_calibration[row, ncut:] = np.nan
    e_mu_calibration[row, ncut:] = np.nan
In [105]:
h = 0.7
# The rest of the model works with distance moduli in units of h
# (mu + 5 log10 h). The calibration likelihood below uses the physical
# `mu_true` instead, so `mu_true_h` is not used in this cell.
mu_true_h = mu_true + 5 * jnp.log10(h)

# Log-likelihood of each calibration entry, shape `(n_calibrators, n_data)`.
# It is NaN wherever a calibrator is missing.
# NOTE: this array shadows the `ll_calibration` function defined in a later
# cell; re-running this cell after that one overwrites the function.
ll_calibration = normal_logpdf(
    mu_calibration, mu_true[None, :],
    e_mu_calibration)
# Mask of valid (non-NaN) log-likelihoods.
calibration_mask = ~jnp.isnan(ll_calibration)
# Missing entries must be -inf (likelihood 0) so that they drop out of the
# logsumexp below. Filling them with 0.0 would add a spurious exp(0) = 1 term
# for every missing calibrator.
ll_calibration_clean = jnp.where(calibration_mask, ll_calibration, -jnp.inf)
# Number of valid calibrators for each galaxy.
counts = jnp.sum(calibration_mask, axis=0)
# Log of the mean likelihood over the valid calibrators. `jnp.maximum` avoids
# log(0) = -inf, and hence -inf - (-inf) = NaN, for galaxies without any
# calibrator. Those galaxies get a log-likelihood of 0, i.e. the calibration
# term is uninformative for them.
log_mean_like = (logsumexp(ll_calibration_clean, axis=0)
                 - jnp.log(jnp.maximum(counts, 1)))
ll_calibration_sum = jnp.where(counts > 0, log_mean_like, 0.0)
In [120]:
from jax.lax import cond
def ll_calibration(mu_calibration, mu_true, e_mu):
    """
    Element-wise Gaussian log-likelihood of calibration distance moduli.

    Entries where `mu_calibration` is not finite (missing calibrators) get a
    log-likelihood of 0.

    Parameters
    ----------
    mu_calibration : array
        Calibration distance moduli; non-finite values mark missing entries.
    mu_true : array
        True distance moduli, broadcastable against `mu_calibration`.
    e_mu : array
        Uncertainties of `mu_calibration`, broadcastable against it.

    Returns
    -------
    jax array
        Log-likelihoods with the broadcast shape of the inputs.
    """
    mask = jnp.isfinite(mu_calibration)
    # "Double where": first replace invalid entries with harmless values
    # (x = loc, scale = 1) so that `normal_logpdf` is finite everywhere. A
    # single outer `jnp.where` already gives the right values, but under
    # `jax.grad` the NaNs of the discarded branch still leak into the
    # gradient (0 * NaN = NaN).
    safe_mu_calibration = jnp.where(mask, mu_calibration, mu_true)
    safe_e_mu = jnp.where(mask, e_mu, 1.0)
    return jnp.where(
        mask,
        normal_logpdf(safe_mu_calibration, mu_true, safe_e_mu),
        0.0)
In [121]:
# Inspect the leading entry of the first calibration row.
mu_calibration[0][0]
Out[121]:
In [127]:
# Evaluate the masked calibration log-likelihood, with `mu_true` broadcast
# along a new leading axis.
ll_calibration(mu_calibration, mu_true[np.newaxis, :], e_mu)
Out[127]:
In [125]:
# Show the first calibration row next to the true distance moduli.
(mu_calibration[0], mu_true)
Out[125]:
In [ ]:
In [ ]:
In [69]:
# Step-by-step check of the masked log-mean-exp over calibrators. Each
# intermediate stage is printed, followed by a blank line.
ll = normal_logpdf(mu_calibration, mu_true[None, :], e_mu)
print(ll, end="\n\n")

# Missing calibrators (NaN) become -inf so that they drop out of logsumexp.
mask = jnp.logical_not(jnp.isnan(ll))
ll = jnp.where(mask, ll, -jnp.inf)
print(ll, end="\n\n")

# Average the likelihood over the valid calibrators of each galaxy. Galaxies
# with none get a log-likelihood of 0.
counts = mask.sum(axis=0)
log_mean = logsumexp(ll, axis=0) - jnp.log(counts)
ll = jnp.where(counts > 0, log_mean, 0.)
print(ll, end="\n\n")
In [68]:
# Total number of non-NaN entries in `ll`.
(~jnp.isnan(ll)).sum()
Out[68]:
In [ ]: