{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
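    "# Using a calibrated flow model to predict $z_{\\rm cosmo}$\n",
    "\n",
    "For the CSiBORG2 realisations, load a mock catalogue together with the flow-model calibration samples, stack the resulting posterior $p(z_{\\rm cosmo})$ of a single object over the realisations, and compare it with that object's catalogue $z_{\\rm cosmo}$ and observed redshift."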
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "# Copyright (C) 2024 Richard Stiskalek\n",
    "# This program is free software; you can redistribute it and/or modify it\n",
    "# under the terms of the GNU General Public License as published by the\n",
    "# Free Software Foundation; either version 3 of the License, or (at your\n",
    "# option) any later version.\n",
    "#\n",
    "# This program is distributed in the hope that it will be useful, but\n",
    "# WITHOUT ANY WARRANTY; without even the implied warranty of\n",
    "# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General\n",
    "# Public License for more details.\n",
    "#\n",
    "# You should have received a copy of the GNU General Public License along\n",
    "# with this program; if not, write to the Free Software Foundation, Inc.,\n",
    "# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from h5py import File\n",
    "from tqdm import tqdm\n",
    "\n",
    "import csiborgtools\n",
    "\n",
    "# Re-import edited modules (e.g. `csiborgtools`) before each cell is run.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "\n",
    "# Resolve data locations with the `paths_glamdring` configuration.\n",
    "paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_calibration(catalogue, simname, nsim, ksmooth, thin=10,\n",
    "                     fdir=\"/mnt/extraspace/rstiskalek/csiborg_postprocessing/peculiar_velocity\"):  # noqa\n",
    "    \"\"\"\n",
    "    Load the flow-model calibration samples of a single simulation.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    catalogue : str\n",
    "        Catalogue name, used to build the file name.\n",
    "    simname : str\n",
    "        Simulation name, used to build the file name.\n",
    "    nsim : int\n",
    "        IC index of the simulation; selects the `sim_{nsim}` group.\n",
    "    ksmooth : int\n",
    "        Smoothing index, used to build the file name.\n",
    "    thin : int, optional\n",
    "        Keep only every `thin`-th sample of each parameter.\n",
    "    fdir : str, optional\n",
    "        Directory containing the `flow_samples_*.hdf5` files.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    calibration_samples : dict\n",
    "        Thinned samples of `Vext_x`, `Vext_y`, `Vext_z`, `alpha`, `beta`\n",
    "        and `sigma_v`, keyed by parameter name.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `thin` is not positive.\n",
    "    KeyError\n",
    "        If the file has no `sim_{nsim}` group.\n",
    "    \"\"\"\n",
    "    if thin < 1:\n",
    "        raise ValueError(f\"`thin` must be a positive integer, got {thin}.\")\n",
    "\n",
    "    fname = f\"{fdir}/flow_samples_{catalogue}_{simname}_smooth_{ksmooth}.hdf5\"\n",
    "    keys = [\"Vext_x\", \"Vext_y\", \"Vext_z\", \"alpha\", \"beta\", \"sigma_v\"]\n",
    "\n",
    "    calibration_samples = {}\n",
    "    with File(fname, 'r') as f:\n",
    "        group = f\"sim_{nsim}\"\n",
    "        if group not in f:\n",
    "            raise KeyError(f\"Group `{group}` not found in `{fname}`.\")\n",
    "\n",
    "        for key in keys:\n",
    "            # Read all samples into memory, then thin them.\n",
    "            calibration_samples[key] = f[f\"{group}/{key}\"][:][::thin]\n",
    "\n",
    "    return calibration_samples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the mock catalogue and the calibrated flow model for each realisation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/19 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10:32:19: reading the catalogue.\n",
      "10:32:19: reading the interpolated field.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/users/rstiskalek/csiborgtools/csiborgtools/flow/flow_model.py:113: UserWarning: The number of radial steps is even. Skipping the first step at 0.0 because Simpson's rule requires an odd number of steps.\n",
      "  warn(f\"The number of radial steps is even. Skipping the first \"\n",
      "  5%|▌         | 1/19 [00:00<00:05,  3.38it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10:32:19: calculating the radial velocity.\n",
      "10:32:20: reading the catalogue.\n",
      "10:32:20: reading the interpolated field.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 11%|█         | 2/19 [00:00<00:04,  3.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10:32:20: calculating the radial velocity.\n",
      "10:32:20: reading the catalogue.\n",
      "10:32:20: reading the interpolated field.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█▌        | 3/19 [00:00<00:04,  3.47it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10:32:20: calculating the radial velocity.\n",
      "10:32:20: reading the catalogue.\n",
      "10:32:20: reading the interpolated field.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 21%|██        | 4/19 [00:01<00:04,  3.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10:32:20: calculating the radial velocity.\n",
      "10:32:20: reading the catalogue.\n",
      "10:32:20: reading the interpolated field.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 26%|██▋       | 5/19 [00:01<00:04,  3.41it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10:32:21: calculating the radial velocity.\n",
      "10:32:21: reading the catalogue.\n",
      "10:32:21: reading the interpolated field.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 32%|███▏      | 6/19 [00:01<00:03,  3.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10:32:21: calculating the radial velocity.\n",
      "10:32:21: reading the catalogue.\n",
      "10:32:21: reading the interpolated field.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 37%|███▋      | 7/19 [00:02<00:03,  3.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10:32:21: calculating the radial velocity.\n",
      "10:32:21: reading the catalogue.\n",
      "10:32:21: reading the interpolated field.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 42%|████▏     | 8/19 [00:02<00:03,  3.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10:32:22: calculating the radial velocity.\n",
      "10:32:22: reading the catalogue.\n",
      "10:32:22: reading the interpolated field.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 47%|████▋     | 9/19 [00:02<00:03,  3.14it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10:32:22: calculating the radial velocity.\n",
      "10:32:22: reading the catalogue.\n",
      "10:32:22: reading the interpolated field.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 53%|█████▎    | 10/19 [00:03<00:02,  3.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10:32:22: calculating the radial velocity.\n",
      "10:32:22: reading the catalogue.\n",
      "10:32:22: reading the interpolated field.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 58%|█████▊    | 11/19 [00:03<00:02,  3.19it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10:32:22: calculating the radial velocity.\n",
      "10:32:23: reading the catalogue.\n",
      "10:32:23: reading the interpolated field.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 63%|██████▎   | 12/19 [00:03<00:02,  3.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10:32:23: calculating the radial velocity.\n",
      "10:32:23: reading the catalogue.\n",
      "10:32:23: reading the interpolated field.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 68%|██████▊   | 13/19 [00:03<00:01,  3.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10:32:23: calculating the radial velocity.\n",
      "10:32:23: reading the catalogue.\n",
      "10:32:23: reading the interpolated field.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 74%|███████▎  | 14/19 [00:04<00:01,  3.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10:32:23: calculating the radial velocity.\n",
      "10:32:23: reading the catalogue.\n",
      "10:32:23: reading the interpolated field.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 79%|███████▉  | 15/19 [00:04<00:01,  3.21it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10:32:24: calculating the radial velocity.\n",
      "10:32:24: reading the catalogue.\n",
      "10:32:24: reading the interpolated field.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 84%|████████▍ | 16/19 [00:04<00:00,  3.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10:32:24: calculating the radial velocity.\n",
      "10:32:24: reading the catalogue.\n",
      "10:32:24: reading the interpolated field.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 89%|████████▉ | 17/19 [00:05<00:00,  3.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10:32:24: calculating the radial velocity.\n",
      "10:32:24: reading the catalogue.\n",
      "10:32:24: reading the interpolated field.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 95%|█████████▍| 18/19 [00:05<00:00,  3.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10:32:25: calculating the radial velocity.\n",
      "10:32:25: reading the catalogue.\n",
      "10:32:25: reading the interpolated field.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 19/19 [00:05<00:00,  3.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10:32:25: calculating the radial velocity.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Mock catalogue. To use the `PV_compilation_Supranta2019` catalogue instead,\n",
    "# set `fpath_data` to\n",
    "# \"/mnt/extraspace/rstiskalek/catalogs/PV_compilation_Supranta2019.hdf5\".\n",
    "fpath_data = \"/mnt/extraspace/rstiskalek/catalogs/PV_mock_CB2_17417_large.hdf5\"\n",
    "\n",
    "simname = \"csiborg2_main\"\n",
    "catalogue = \"CB2_large\"\n",
    "ksmooth = 1\n",
    "\n",
    "# NOTE(review): the last realisation is excluded -- document why.\n",
    "nsims = paths.get_ics(simname)[:-1]\n",
    "if len(nsims) == 0:\n",
    "    raise ValueError(f\"No realisations available for `{simname}`.\")\n",
    "\n",
    "loaders = []\n",
    "models = []\n",
    "\n",
    "for i, nsim in enumerate(tqdm(nsims)):\n",
    "    # `DataLoader` is given the position `i` within `nsims`, whereas the\n",
    "    # calibration samples are looked up by the IC index `nsim`.\n",
    "    loader = csiborgtools.flow.DataLoader(simname, i, catalogue, fpath_data, paths, ksmooth=ksmooth)\n",
    "    calibration_samples = load_calibration(catalogue, simname, nsim, ksmooth)\n",
    "    # NOTE(review): relies on the private `DataLoader._Omega_m` attribute.\n",
    "    model = csiborgtools.flow.Observed2CosmologicalRedshift(calibration_samples, loader.rdist, loader._Omega_m)\n",
    "\n",
    "    loaders.append(loader)\n",
    "    models.append(model)\n",
    "\n",
    "# Reference redshifts of the catalogue objects, taken from the first loader\n",
    "# (assumed identical across realisations). `zcosmo_mean` holds the\n",
    "# catalogue's `zcosmo` column.\n",
    "zcosmo_mean = loaders[0].cat[\"zcosmo\"]\n",
    "zobs = loaders[0].cat[\"zobs\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Stacking:   0%|          | 0/19 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Stacking: 100%|██████████| 19/19 [00:06<00:00,  3.06it/s]\n"
     ]
    }
   ],
   "source": [
    "# Index of the catalogue object whose p(z_cosmo) is stacked over all\n",
    "# realisations; the plotting cell below shows the same object.\n",
    "n = 400\n",
    "zcosmo, pzcosmo = csiborgtools.flow.stack_pzosmo_over_realizations(\n",
    "    n, models, loaders, \"zobs\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 144,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Stacked posterior p(z_cosmo) of object `n`, compared with that object's\n",
    "# catalogue `zcosmo` and `zobs`. Only the region where the density exceeds\n",
    "# `pmin` is drawn.\n",
    "pmin = 1e-5\n",
    "fname_fig = None  # e.g. \"../plots/zcosmo_posterior_mock_example_B.png\"\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "mask = pzcosmo > pmin\n",
    "ax.plot(zcosmo[mask], pzcosmo[mask], color=\"black\", label=r\"$p(z_{\\rm cosmo})$\")\n",
    "ax.set_ylim(0)\n",
    "ax.axvline(zcosmo_mean[n], color=\"green\", label=r\"$z_{\\rm cosmo}$\")\n",
    "ax.axvline(zobs[n], color=\"red\", label=r\"$z_{\\rm CMB}$\")\n",
    "\n",
    "ax.set_xlabel(r\"$z$\")\n",
    "ax.set_ylabel(r\"$p(z)$\")\n",
    "ax.set_title(f\"Object {n}\")\n",
    "ax.legend()\n",
    "fig.tight_layout()\n",
    "\n",
    "if fname_fig is not None:\n",
    "    fig.savefig(fname_fig, dpi=450)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv_csiborg",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}