{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Masked calibration likelihood\n",
    "\n",
    "Scratch checks of a calibration log-likelihood in which each galaxy may have a calibration distance modulus from zero or more calibrators, with missing entries stored as `NaN`. Catalogue data are read from `PV_compilation.hdf5`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (C) 2024 Richard Stiskalek\n",
    "# This program is free software; you can redistribute it and/or modify it\n",
    "# under the terms of the GNU General Public License as published by the\n",
    "# Free Software Foundation; either version 3 of the License, or (at your\n",
    "# option) any later version.\n",
    "#\n",
    "# This program is distributed in the hope that it will be useful, but\n",
    "# WITHOUT ANY WARRANTY; without even the implied warranty of\n",
    "# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General\n",
    "# Public License for more details.\n",
    "#\n",
    "# You should have received a copy of the GNU General Public License along\n",
    "# with this program; if not, write to the Free Software Foundation, Inc.,\n",
    "# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import trange\n",
    "from joblib import dump\n",
    "from h5py import File\n",
    "from jax import numpy as jnp\n",
    "from jax.scipy.special import logsumexp\n",
    "\n",
    "import csiborgtools\n",
    "\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fpath = \"/mnt/extraspace/rstiskalek/catalogs/PV_compilation.hdf5\"\n",
    "\n",
    "with File(fpath, 'r') as f:\n",
    "    RA_2MTF = f[\"2MTF/RA\"][...]\n",
    "    DEC_2MTF = f[\"2MTF/DEC\"][...]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inspect the mock calibration distance moduli\n",
    "\n",
    "**FIXME:** the cell that generated the `mock` dictionary was lost from this notebook. The cells below read `mock[\"mu_true\"]`, `mock[\"mu_calibration\"]` and `mock[\"e_mu_calibration\"]`; recreate `mock` before running them. `mock[\"mu_calibration\"]` has been used both as a 1D array and with a leading calibrator axis, so the cells below index it through `np.atleast_2d`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First calibrator row; galaxies without a calibration measurement are NaN.\n",
    "mu_cal_row0 = np.atleast_2d(mock[\"mu_calibration\"])[0]\n",
    "is_calibrated = np.isfinite(mu_cal_row0)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(mu_cal_row0[is_calibrated], bins=\"auto\")\n",
    "ax.set(xlabel=r\"$\\mu_{\\rm calibration}$\", ylabel=\"Counts\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of calibrated galaxies per calibrator row.\n",
    "np.sum(np.isfinite(np.atleast_2d(mock[\"mu_calibration\"])), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "ax.scatter(mock[\"mu_true\"], mu_cal_row0)\n",
    "ax.set(xlabel=r\"$\\mu_{\\rm true}$\", ylabel=r\"$\\mu_{\\rm calibration}$\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.percentile(mock[\"mu_true\"], 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "ax.hist(mock[\"mu_true\"], bins=\"auto\")\n",
    "ax.set(xlabel=r\"$\\mu_{\\rm true}$\", ylabel=\"Counts\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Masked calibration log-likelihood\n",
    "\n",
    "`mu_calibration` and `e_mu_calibration` have shape `(n_calibrators, n_galaxies)`, with `NaN` where a calibrator has no measurement of a galaxy. `calibration_loglike` averages the per-calibrator likelihoods over the valid entries,\n",
    "\n",
    "$$\\log \\mathcal{L}_i = \\log \\left[ \\frac{1}{n_i} \\sum_{k \\in V_i} \\mathcal{N}\\left(\\mu^{\\rm cal}_{ki} \\mid \\mu^{\\rm true}_i, \\sigma_{ki}\\right) \\right],$$\n",
    "\n",
    "where $V_i$ is the set of $n_i$ valid calibrators of galaxy $i$, and sets $\\log \\mathcal{L}_i = 0$ when $n_i = 0$. Missing entries must enter the log-sum-exp as $-\\infty$ (contributing $e^{-\\infty} = 0$); replacing them with $0$ would add $e^{0} = 1$ per missing calibrator.\n",
    "\n",
    "**NOTE (review):** if the calibrators are independent measurements of the same galaxy, the joint likelihood is the *product* of the per-calibrator terms (the sum over calibrators of `calibration_logpdf_masked` below), not their average. Confirm which is intended."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def normal_logpdf(x, loc, scale):\n",
    "    \"\"\"Log of the normal probability density function.\n",
    "\n",
    "    Broadcasts over `x`, `loc` and `scale`; `NaN` inputs give `NaN`.\n",
    "    \"\"\"\n",
    "    return (-0.5 * ((x - loc) / scale)**2\n",
    "            - jnp.log(scale) - 0.5 * jnp.log(2 * jnp.pi))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic calibrator set-up built from the first calibrator row of `mock`:\n",
    "# calibrator `k` keeps at most the first `N_KEEP[k]` galaxies and is NaN\n",
    "# elsewhere. Arrays have shape `(n_calibrators, n_galaxies)`.\n",
    "N_KEEP = (100, 50)\n",
    "\n",
    "mu_true = np.copy(mock[\"mu_true\"])\n",
    "mu_cal_base = np.copy(np.atleast_2d(mock[\"mu_calibration\"])[0])\n",
    "e_mu = np.copy(np.atleast_2d(mock[\"e_mu_calibration\"])[0])\n",
    "\n",
    "# `np.stack` copies, so the masking below does not touch `mock`.\n",
    "mu_calibration = np.stack([mu_cal_base] * len(N_KEEP))\n",
    "e_mu_calibration = np.stack([e_mu] * len(N_KEEP))\n",
    "\n",
    "for k, n_keep in enumerate(N_KEEP):\n",
    "    mu_calibration[k, n_keep:] = np.nan\n",
    "    e_mu_calibration[k, n_keep:] = np.nan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calibration_loglike(mu_calibration, mu_true, e_mu_calibration):\n",
    "    \"\"\"Calibration log-likelihood of each galaxy, averaged over calibrators.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    mu_calibration, e_mu_calibration : array of shape `(n_calibrators, n_galaxies)`\n",
    "        Calibration distance moduli and their uncertainties; `NaN` marks a\n",
    "        galaxy without a measurement from that calibrator.\n",
    "    mu_true : array of shape `(n_galaxies,)`\n",
    "        Model distance moduli.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    array of shape `(n_galaxies,)`\n",
    "        Log of the mean normal likelihood over the valid calibrators, or 0\n",
    "        for galaxies without any valid calibrator.\n",
    "    \"\"\"\n",
    "    ll = normal_logpdf(mu_calibration, mu_true[None, :], e_mu_calibration)\n",
    "    is_valid = ~jnp.isnan(ll)\n",
    "    # Missing entries must contribute exp(-inf) = 0 to the sum; replacing\n",
    "    # them with 0.0 would add exp(0) = 1 per missing calibrator.\n",
    "    ll = jnp.where(is_valid, ll, -jnp.inf)\n",
    "    counts = jnp.sum(is_valid, axis=0)\n",
    "    # Clamp `counts` so the discarded branch below does not evaluate\n",
    "    # -inf - log(0) = NaN for galaxies without calibrators.\n",
    "    ll_mean = logsumexp(ll, axis=0) - jnp.log(jnp.maximum(counts, 1))\n",
    "    # Zero log-likelihood, i.e. no calibration contribution.\n",
    "    return jnp.where(counts > 0, ll_mean, 0.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "h = 0.7\n",
    "\n",
    "# The rest of the model (outside the calibration likelihood) uses the\n",
    "# distance modulus in units of h; not used further in this notebook.\n",
    "mu_true_h = mu_true + 5 * jnp.log10(h)\n",
    "\n",
    "# The calibration likelihood itself uses `mu_true`, not `mu_true_h`.\n",
    "ll_calibration_sum = calibration_loglike(mu_calibration, mu_true, e_mu_calibration)\n",
    "ll_calibration_sum[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check. With `N_KEEP` decreasing, every calibrator row is a truncated\n",
    "# copy of row 0, so each galaxy has either no valid calibrator (log-likelihood\n",
    "# 0) or one or more identical valid entries, whose average equals row 0's\n",
    "# log-density.\n",
    "ll_raw = np.asarray(normal_logpdf(mu_calibration, mu_true[None, :], e_mu_calibration))\n",
    "has_cal0 = np.isfinite(ll_raw[0])\n",
    "ll_sum_np = np.asarray(ll_calibration_sum)\n",
    "\n",
    "np.testing.assert_allclose(ll_sum_np[has_cal0], ll_raw[0, has_cal0], rtol=1e-5, atol=1e-5)\n",
    "assert np.all(ll_sum_np[~has_cal0] == 0.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of valid (non-NaN) calibration log-densities per calibrator.\n",
    "np.sum(~np.isnan(ll_raw), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calibration_logpdf_masked(mu_calibration, mu_true, e_mu):\n",
    "    \"\"\"Element-wise normal log-density of `mu_calibration`, 0 where it is not finite.\n",
    "\n",
    "    Summing the result over the calibrator axis gives the log of the product of\n",
    "    the per-calibrator likelihoods over the valid entries.\n",
    "    \"\"\"\n",
    "    return jnp.where(\n",
    "        jnp.isfinite(mu_calibration),\n",
    "        normal_logpdf(mu_calibration, mu_true, e_mu),\n",
    "        0.0,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "calibration_logpdf_masked(mu_calibration, mu_true[None, :], e_mu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mu_calibration[0, :5], mu_true[:5]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv_csiborg",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}