diff --git a/tests/TFR_tests.ipynb b/tests/TFR_tests.ipynb index 8476acd..ef71488 100644 --- a/tests/TFR_tests.ipynb +++ b/tests/TFR_tests.ipynb @@ -518,10 +518,57 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "67759f3d-0cb3-4771-bb07-f3d5bfc6848c", "metadata": {}, "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9c0b73c6-70d7-4ba7-9c6f-e98a53a1358a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([107., 113., 94., 108., 100., 93., 110., 95., 86., 94.]),\n", + " array([-1.99562025, -1.59607887, -1.19653749, -0.79699612, -0.39745474,\n", + " 0.00208664, 0.40162802, 0.8011694 , 1.20071077, 1.60025215,\n", + " 1.99979353]),\n", + " )" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGdCAYAAAA44ojeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAfM0lEQVR4nO3df3ST5f3/8VekEFrXRoGRtLNCdZ24VR1WBSpKFakiMh1n/oJxcFMPyg9XccMi28fiObbQae0ZnTA8HmRjqOdMcO6gjm5C0VOcBetUmLDNAlXoOl2XVmCt0Ov7h9/mGFoKwaR5pz4f5+Qcc+fKzXVxqX2eO0njcc45AQAAGHJKvCcAAABwNAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5iTFewIno6OjQ/v27VNqaqo8Hk+8pwMAAE6Ac06tra3KyMjQKaf0fI0kIQNl3759yszMjPc0AADASWhoaNAZZ5zR45iEDJTU1FRJny0wLS0tzrMBAAAnoqWlRZmZmaGf4z1JyEDpfFknLS2NQAEAIMGcyNszeJMsAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYE5SvCeA6BhetD7eU4jY7sWT4j0FAIBRXEEBAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADm8JtkASAKEvG3OUv8RmfYxRUUAABgDoECAADMIVAAAIA5BAoAADCHQAEAAObwKR4gAon4SQ0+pQEgEXEFBQAAmEOgAAAAc3iJpxuJeBkfAIC+hCsoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAnIgDZfPmzZo8ebIyMjLk8Xj0/PPPhz3unFNxcbEyMjKUnJys/Px8bd++PWxMW1ub5s6dqyFDhujUU0/Vd77zHX3wwQdfaCEAAKDviDhQDhw4oAsuuECVlZXdPl5WVqby8nJVVlaqtrZWgUBAEyZMUGtra2hMYWGh1q1bp2eeeUavvfaaPvnkE1133XU6cuTIya8EAAD0GUmRPmHixImaOHFit48551RRUaGFCxdqypQpkqRVq1bJ7/drzZo1mjlzpoLBoJ588kn95je/0VVXXSVJWr16tTIzM/WnP/1JV1999RdYDgAA6Aui+h6U+vp6NTY2qqCgIHTM6/Vq3LhxqqmpkSRt27ZNn376adiYjIwM5eTkhMYcra2tTS0tLWE3AADQd0V8BaUnjY2NkiS/3x923O/3a8+ePaExAwYM0Omnn95lTOfzj1ZaWqpFixZFc6rAl8bwovXxnkLEdi+eFO8pAIizmHyKx+PxhN13znU5drSexixYsEDBYDB0a2hoiNpcAQCAPVENlEAgIEldroQ0NTWFrqoEAgG1t7erubn5mGOO5vV6lZaWFnYDAAB9V1QDJSsrS4FAQFVVVaFj7e3tqq6uVl5eniQpNzdX/fv3Dxuzf/9+vfvuu6ExAADgyy3i96B88skn+sc//hG6X19fr7feekuDBg3SmWeeqcLCQpWUlCg7O1vZ2dkqKSlRSkqKpk6dKkny+Xy6/fbbdd9992nw4MEaNGiQfvzjH+u8884LfaoHAAB8uUUcKFu3btUVV1wRuj9v3jxJ0owZM/TUU09p/vz5OnTokGbNmqXm5maNGjVKGzZsUGpqaug5jz32mJKSknTTTTfp0KFDGj9+vJ566in169cvCksCAACJLuJAyc/Pl3PumI97PB4VFxeruLj4mGMGDhyopUuXaunSpZH+8QAA4EuA7+IBAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwJ+Jf1AYA6DuGF62P9xQitnvxpHhPAb2AKygAAMAcAgUAAJjDSzyIm0S8tAwAJyMR/38X75fSuIICAADMIVAAAIA5BAoAADCH96AAMCcRX68HEF1cQQEAAOYQKAAAwBwCBQAAmEOgAAAAcwgUAABgDoECAADMIVAAAIA5BAoAADCHQAEAAOYQKAAAwBwCBQAAmEOgAAAAcwgUAABgDoECAADMIVAAAIA5BAoAADCHQAEAAOYQKAAAwBwCBQAAmEOgAAAAc5LiPQEAACIxvGh9vKeAXsAVFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMCcqAfK4cOH9dOf/lRZWVlKTk7WWWedpYceekgdHR2hMc45FRcXKyMjQ8nJycrPz9f27dujPRUAAJCgoh4oS5Ys0fLly1VZWam//e1vKisr089//nMtXbo0NKasrEzl5eWqrKxUbW2tAoGAJkyYoNbW1mhPBwAAJKCoB8qWLVt0/fXXa9KkSRo+fLi+973vqaCgQFu3bpX02dWTiooKLVy4UFOmTFFOTo5WrVqlgwcPas2aNdGeDgAASEBRD5SxY8fqz3/+s3bt2iVJ+utf/6rXXntN1157rSSpvr5ejY2NKigoCD3H6/Vq3Lhxqqmp6facbW1tamlpCbsBAIC+KynaJ7z//vsVDAY1YsQI9evXT0eOHNHDDz+sW2+9VZLU2NgoSfL7/WHP8/v92rNnT7fnLC0t1aJFi6I9VQAAYFTUr6A8++yzWr16tdasWaM333xTq1at0iOPPKJVq1aFjfN4PGH3nXNdjnVasGCBgsFg6NbQ0BDtaQMAAEOifgXlJz/5iYqKinTLLbdIks477zzt2bNHpaWlmjFjhgKBgKTPrqSkp6eHntfU1NTlqkonr9crr9cb7akCAACjon4F5eDBgzrllPDT9uvXL/Qx46ysLAUCAVVVVYUeb29vV3V1tfLy8qI9HQAAkICifgVl8uTJevjhh3XmmWfqW9/6lurq6lReXq4f/vCHkj57aaewsFAlJSXKzs5Wdna2SkpKlJKSoqlTp0Z7OgAAIAFFPVCWLl2qn/3sZ5o1a5aampqUkZGhmTNn6v/+7/9CY+bPn69Dhw5p1qxZam5u1qhRo7RhwwalpqZGezoAACABeZxzLt6TiFRLS4t8Pp+CwaDS0tKifv7hReujfk4AABLJ7sWTon7OSH5+8108AADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzYhIoH374ob7//e9r8ODBSklJ0be//W1t27Yt9LhzTsXFxcrIyFBycrLy8/O1ffv2WEwFAAAkoKgHSnNzsy699FL1799fL730knbs2KFHH31Up512WmhMWVmZysvLVVlZqdraWgUCAU2YMEGtra3Rng4AAEhASdE+4ZIlS5SZmamVK1eGjg0fPjz0z845VVRUaOHChZoyZYokadWqVfL7/VqzZo1mzpwZ7SkBAIAEE/UrKC+88IIuuugi3XjjjRo6dKhGjhypJ554IvR4fX29GhsbVVBQEDrm9Xo1btw41dTUdHvOtrY2tbS0hN0AAEDfFfVAef/997Vs2TJlZ2frj3/8o+666y7dc889+vWvfy1JamxslCT5/f6w5/n9/tBjRystLZXP5wvdMjMzoz1tAABgSNQDpaOjQxdeeKFKSko0cuRIzZw5U3feeaeWLVsWNs7j8YTdd851OdZpwYIFCgaDoVtDQ0O0pw0AAAyJeqCkp6frm9/8Ztixc889V3v37pUkBQIBSepytaSpqanLVZVOXq9XaWlpYTcAANB3RT1QLr30Uu3cuTPs2K5duzRs2DBJUlZWlgKBgKqqqkKPt7e3q7q6Wnl5edGeDgAASEBR/xTPvffeq7y8PJWUlOimm27SG2+8oRUrVmjFihWSPntpp7CwUCUlJcrOzlZ2drZKSkqUkpKiqVOnRns6AAAgAUU9UC6++GKtW7dOCxYs0EMPPaSsrCxVVFRo2rRpoTHz58/XoUOHNGvWLDU3N2vUqFHasGGDUlNToz0dAACQgDzOORfvSUSqpaVFPp9PwWAwJu9HGV60PurnBAAgkexePCnq54zk5zffxQMAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMCfmgVJaWiqPx6PCwsLQMeeciouLlZGRoeTkZOXn52v79u2xngoAAEgQMQ2U2tparVixQueff37Y8bKyMpWXl6uyslK1tbUKBAKaMGGCWltbYzkdAACQIGIWKJ988ommTZumJ554QqeffnrouHNOFRUVWrhwoaZMmaKcnBytWrVKBw8e1Jo1a2I1HQAAkEBiFiizZ8/WpEmTdNVVV4Udr6+vV2NjowoKCkLHvF6vxo0bp5qamlhNBwAAJJCkWJz0mWee0Ztvvqna2toujzU2NkqS/H5/2HG/3689e/Z0e762tja1tbWF7re0tERxtgAAwJqoX0FpaGjQj370I61evVoDBw485jiPxxN23znX5Vin0tJS+Xy+0C0zMzOqcwYAALZEPVC2bdumpqYm5ebmKikpSUlJSaqurtYvfvELJSUlha6cdF5J6dTU1NTlqkqnBQsWKBgMhm4NDQ3RnjYAADAk6i/xjB8/Xu+8807YsR/84AcaMWKE7r//fp111lkKBAKqqqrSyJEjJUnt7e2qrq7WkiVLuj2n1+uV1+uN9lQBAIBRUQ+U1NRU5eTkhB079dRTNXjw4NDxwsJClZSUKDs7W9nZ2SopKVFKSoqmTp0a7ekAAIAEFJM3yR7P/PnzdejQIc2aNUvNzc0aNWqUNmzYoNTU1HhMBwAAGONxzrl4TyJSLS0t8vl8CgaDSktLi/r5hxetj/o5AQBIJLsXT4r6OSP5+c138QAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOVEPlNLSUl188cVKTU3V0KFDdcMNN2jnzp1hY5xzKi4uVkZGhpKTk5Wfn6/t27dHeyoAACBBRT1QqqurNXv2bL3++uuqqqrS4cOHVVBQoAMHDoTGlJWVqby8XJWVlaqtrVUgENCECRPU2toa7ekAAIAElBTtE7788sth91euXKmhQ4dq27Ztuvzyy+WcU0VFhRYuXKgpU6ZIklatWiW/3681a9Zo5syZ0Z4SAABIMDF/D0owGJQkDRo0SJJUX1+vxsZGFRQUhMZ4vV6NGzdONTU13Z6jra1NLS0tYTcAANB3xTRQnHOaN2+exo4dq5ycHElSY2OjJMnv94eN9fv9oceOVlpaKp/PF7plZmbGctoAACDOYhooc+bM0dtvv62nn366y2MejyfsvnOuy7FOCxYsUDAYDN0aGhpiMl8AAGBD1N+D0mnu3Ll64YUXtHnzZp1xxhmh44FAQNJnV1LS09NDx5uamrpcVenk9Xrl9XpjNVUAAGBM1K+gOOc0Z84crV27Vq+88oqysrLCHs/KylIgEFBVVVXoWHt7u6qrq5WXlxft6QAAgAQU9Ssos2fP1po1a/T73/9eqampofeV+Hw+JScny+PxqLCwUCUlJcrOzlZ2drZKSkqUkpKiqVOnRns6AAAgAUU9UJYtWyZJys/PDzu+cuVK3XbbbZKk+fPn69ChQ5o1a5aam5s1atQobdiwQampqdGeDgAASEBRDxTn3HHHeDweFRcXq7i4ONp/PAAA6AP4Lh4AAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDlxDZTHH39cWVlZGjhwoHJzc/Xqq6/GczoAAMCIuAXKs88+q8LCQi1cuFB1dXW67LLLNHHiRO3duzdeUwIAAEbELVDKy8t1++2364477tC5556riooKZWZmatmyZfGaEgAAMCIpHn9oe3u7tm3bpqKiorDjBQUFqqmp6TK+ra1NbW1tofvBYFCS1NLSEpP5dbQdjMl5AQBIFLH4Gdt5TufcccfGJVA++ugjHTlyRH6/P+y43+9XY2Njl/GlpaVatGhRl+OZmZkxmyMAAF9mvorYnbu1tVU+n6/HMXEJlE4ejyfsvnOuyzFJWrBggebNmxe639HRof/85z8aPHhwt+NPVktLizIzM9XQ0KC0tLSondeSvr7Gvr4+qe+vsa+vT+r7a2R9iS9Wa3TOqbW1VRkZGccdG5dAGTJkiPr169flaklTU1OXqyqS5PV65fV6w46ddtppMZtfWlpan/2XrlNfX2NfX5/U99fY19cn9f01sr7EF4s1Hu/KSae4vEl2wIABys3NVVVVVdjxqqoq5eXlxWNKAADAkLi9xDNv3jxNnz5dF110kcaMGaMVK1Zo7969uuuuu+I1JQAAYETcAuXmm2/Wxx9/rIceekj79+9XTk6OXnzxRQ0bNixeU5LX69WDDz7Y5eWkvqSvr7Gvr0/q+2vs6+uT+v4aWV/is7BGjzuRz/oAAAD0Ir6LBwAAmEOgAAAAcwgUAABgDoECAADM+VIHyu7du3X77bcrKytLycnJOvvss/Xggw+qvb29x+c551RcXKyMjAwlJycrPz9f27dv76VZR+bhhx9WXl6eUlJSTviX2912223yeDxht9GjR8d2ol/AyawxkfawublZ06dPl8/nk8/n0/Tp0/Xf//63x+dY38PHH39cWVlZGjhwoHJzc/Xqq6/2OL66ulq5ubkaOHCgzjrrLC1fvryXZnpyIlnfpk2buuyVx+PRe++914szPnGbN2/W5MmTlZGRIY/Ho+eff/64z0m0/Yt0jYm2h6Wlpbr44ouVmpqqoUOH6oYbbtDOnTuP+7ze3scvdaC899576ujo0K9+9Stt375djz32mJYvX64HHnigx+eVlZWpvLxclZWVqq2tVSAQ0IQJE9Ta2tpLMz9x7e3tuvHGG3X33XdH9LxrrrlG+/fvD91efPHFGM3wizuZNSbSHk6dOlVvvfWWXn75Zb388st66623NH369OM+z+oePvvssyosLNTChQtVV1enyy67TBMnTtTevXu7HV9fX69rr71Wl112merq6vTAAw/onnvu0XPPPdfLMz8xka6v086dO8P2Kzs7u5dmHJkDBw7oggsuUGVl5QmNT7T9kyJfY6dE2cPq6mrNnj1br7/+uqqqqnT48GEVFBTowIEDx3xOXPbRIUxZWZnLyso65uMdHR0uEAi4xYsXh47973//cz6fzy1fvrw3pnhSVq5c6Xw+3wmNnTFjhrv++utjOp9YONE1JtIe7tixw0lyr7/+eujYli1bnCT33nvvHfN5lvfwkksucXfddVfYsREjRriioqJux8+fP9+NGDEi7NjMmTPd6NGjYzbHLyLS9W3cuNFJcs3Nzb0wu+iS5NatW9fjmETbv6OdyBoTeQ+dc66pqclJctXV1cccE499/FJfQelOMBjUoEGDjvl4fX29GhsbVVBQEDrm9Xo1btw41dTU9MYUe8WmTZs0dOhQfeMb39Cdd96ppqameE8pahJpD7ds2SKfz6dRo0aFjo0ePVo+n++4c7W4h+3t7dq2bVvY370kFRQUHHM9W7Zs6TL+6quv1tatW/Xpp5/GbK4n42TW12nkyJFKT0/X+PHjtXHjxlhOs1cl0v59UYm6h8FgUJJ6/NkXj30kUD7nn//8p5YuXdrjr9vv/ILDo7/U0O/3d/nyw0Q1ceJE/fa3v9Urr7yiRx99VLW1tbryyivV1tYW76lFRSLtYWNjo4YOHdrl+NChQ3ucq9U9/Oijj3TkyJGI/u4bGxu7HX/48GF99NFHMZvryTiZ9aWnp2vFihV67rnntHbtWp1zzjkaP368Nm/e3BtTjrlE2r+Tlch76JzTvHnzNHbsWOXk5BxzXDz2sU8GSnFxcbdvWPr8bevWrWHP2bdvn6655hrdeOONuuOOO477Z3g8nrD7zrkux2LlZNYXiZtvvlmTJk1STk6OJk+erJdeekm7du3S+vXro7iKnsV6jVLi7GF3czreXC3sYU8i/bvvbnx3x62IZH3nnHOO7rzzTl144YUaM2aMHn/8cU2aNEmPPPJIb0y1VyTa/kUqkfdwzpw5evvtt/X0008fd2xv72PcvosnlubMmaNbbrmlxzHDhw8P/fO+fft0xRVXhL60sCeBQEDSZzWZnp4eOt7U1NSlLmMl0vV9Uenp6Ro2bJj+/ve/R+2cxxPLNSbSHr799tv617/+1eWxf//73xHNNR572J0hQ4aoX79+Xa4m9PR3HwgEuh2flJSkwYMHx2yuJ+Nk1ted0aNHa/Xq1dGeXlwk0v5FUyLs4dy5c/XCCy9o8+bNOuOMM3ocG4997JOBMmTIEA0ZMuSExn744Ye64oorlJubq5UrV+qUU3q+qJSVlaVAIKCqqiqNHDlS0mevO1dXV2vJkiVfeO4nIpL1RcPHH3+shoaGsB/msRbLNSbSHo4ZM0bBYFBvvPGGLrnkEknSX/7yFwWDQeXl5Z3wnxePPezOgAEDlJubq6qqKn33u98NHa+qqtL111/f7XPGjBmjP/zhD2HHNmzYoIsuukj9+/eP6XwjdTLr605dXV3c9ypaEmn/osnyHjrnNHfuXK1bt06bNm1SVlbWcZ8Tl32M2dtvE8CHH37ovv71r7srr7zSffDBB27//v2h2+edc845bu3ataH7ixcvdj6fz61du9a988477tZbb3Xp6emupaWlt5dwXHv27HF1dXVu0aJF7itf+Yqrq6tzdXV1rrW1NTTm8+trbW119913n6upqXH19fVu48aNbsyYMe5rX/uayfU5F/kanUusPbzmmmvc+eef77Zs2eK2bNnizjvvPHfdddeFjUmkPXzmmWdc//793ZNPPul27NjhCgsL3amnnup2797tnHOuqKjITZ8+PTT+/fffdykpKe7ee+91O3bscE8++aTr37+/+93vfhevJfQo0vU99thjbt26dW7Xrl3u3XffdUVFRU6Se+655+K1hB61traG/huT5MrLy11dXZ3bs2ePcy7x98+5yNeYaHt49913O5/P5zZt2hT2c+/gwYOhMRb28UsdKCtXrnSSur19niS3cuXK0P2Ojg734IMPukAg4Lxer7v88svdO++808uzPzEzZszodn0bN24Mjfn8+g4ePOgKCgrcV7/6Vde/f3935plnuhkzZri9e/fGZwEnINI1OpdYe/jxxx+7adOmudTUVJeamuqmTZvW5eOMibaHv/zlL92wYcPcgAED3IUXXhj28cYZM2a4cePGhY3ftGmTGzlypBswYIAbPny4W7ZsWS/PODKRrG/JkiXu7LPPdgMHDnSnn366Gzt2rFu/fn0cZn1iOj9Se/RtxowZzrm+sX+RrjHR9vBYP/c+//9IC/vo+f+TBQAAMKNPfooHAAAkNgIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGDO/wO5Pal7plnwCQAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "a = np.loadtxt('a.txt')\n", + "plt.hist(a)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aeb7135d-0fc7-43c1-81c9-1fd32e5f05d6", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/tests/corner.png b/tests/corner.png index 0bb7ee4..06eb473 100644 Binary files a/tests/corner.png and b/tests/corner.png differ diff --git a/tests/likelihood_scan_a_TFR.png b/tests/likelihood_scan_a_TFR.png index 47ef585..7374fdf 100644 Binary files a/tests/likelihood_scan_a_TFR.png and b/tests/likelihood_scan_a_TFR.png differ diff --git a/tests/likelihood_scan_alpha.png b/tests/likelihood_scan_alpha.png index cf9f26d..78b99f8 100644 Binary files a/tests/likelihood_scan_alpha.png and b/tests/likelihood_scan_alpha.png differ diff --git a/tests/likelihood_scan_b_TFR.png b/tests/likelihood_scan_b_TFR.png index 75cdd4b..96190d7 100644 Binary files a/tests/likelihood_scan_b_TFR.png and b/tests/likelihood_scan_b_TFR.png differ diff --git a/tests/likelihood_scan_sigma_TFR.png b/tests/likelihood_scan_sigma_TFR.png index 8eda8fd..1ea69ef 100644 Binary files a/tests/likelihood_scan_sigma_TFR.png and b/tests/likelihood_scan_sigma_TFR.png differ diff --git a/tests/likelihood_scan_sigma_v.png b/tests/likelihood_scan_sigma_v.png index d73ddba..17d983d 100644 Binary files a/tests/likelihood_scan_sigma_v.png and b/tests/likelihood_scan_sigma_v.png differ diff --git a/tests/tfr_inference.py b/tests/tfr_inference.py index aaa16ff..49d1158 100644 --- a/tests/tfr_inference.py +++ b/tests/tfr_inference.py @@ -165,6 +165,42 @@ def get_fields(L, N, xmin, gravity='lpt', velmodel_name='CICModel'): def create_mock(Nt, L, xmin, cpar, dens, vel, Rmax, alpha, mthresh, a_TFR, b_TFR, sigma_TFR, sigma_m, sigma_eta, hyper_eta_mu, hyper_eta_sigma, sigma_v, interp_order=1, bias_epsilon=1e-7): + """ + Create mock TFR catalogue from a density and velocity field + + Args: + - Nt (int): Number of tracers to produce + - L (float): Box length (Mpc/h) + - xmin (float): Coordinate of corner of the box (Mpc/h) + - cpar (borg.cosmo.CosmologicalParameters): Cosmological parameters to use + - dens (np.ndarray): Over-density field (shape = (N, N, N)) + - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N)) + - Rmax (float): Maximum allowed comoving radius of a tracer (Mpc/h) + - alpha (float): Exponent for bias model + - mthresh (float): Threshold absolute magnitude in selection + - a_TFR (float): TFR relation intercept + - b_TFR (float): TFR relation slope + - sigma_TFR (float): Intrinsic scatter in the TFR + - sigma_v (float): Uncertainty on the velocity field (km/s) + - sigma_m (float): Uncertainty on the apparent magnitude measurements + - sigma_eta (float): Uncertainty on the linewidth measurements + - hyper_eta_mu (float): Mean of the Gaussian hyper-prior for the true eta values + - hyper_eta_sigma (float): Std deviation of the Gaussian hyper-prior for the true eta values + - sigma_v (float): Uncertainty on the velocity field (km/s) + - interp_order (int, default=1): Order of interpolation from grid points to the line of sight + - bias_epsilon (float, default=1e-7): Small number to add to 1 + delta to prevent 0^# + + Returns: + - all_RA (np.ndarrary): Right Ascension (degrees) of the tracers (shape = (Nt,)) + - all_Dec (np.ndarrary): Dec (np.ndarray): Delination (degrees) of the tracers (shape = (Nt,)) + - czCMB (np.ndarrary): Observed redshifts (km/s) of the tracers (shape = (Nt,)) + - all_mtrue (np.ndarrary): True apparent magnitudes of the tracers (shape = (Nt,)) + - all_etatrue (np.ndarrary): True linewidths of the tracers (shape = (Nt,)) + - all_mobs (np.ndarrary): Observed apparent magnitudes of the tracers (shape = (Nt,)) + - all_etaobs (np.ndarrary): Observed linewidths of the tracers (shape = (Nt,)) + - all_xtrue (np.ndarrary): True comoving coordinates of the tracers (Mpc/h) (shape = (3, Nt)) + + """ # Initialize lists to store valid positions and corresponding sig_mu values all_xtrue = np.empty((3, Nt)) @@ -226,7 +262,8 @@ def create_mock(Nt, L, xmin, cpar, dens, vel, Rmax, alpha, mthresh, etaobs = etatrue + sigma_eta * np.random.randn(Nt) # Apply apparement magnitude cut - m = mobs <= mthresh + # m = mobs <= mthresh + m = np.ones(mobs.shape, dtype=bool) mtrue = mtrue[m] etatrue = etatrue[m] mobs = mobs[m] @@ -283,8 +320,10 @@ def create_mock(Nt, L, xmin, cpar, dens, vel, Rmax, alpha, mthresh, def estimate_data_parameters(): - """ + Using the 2MASS catalogue, estimate some parameters to use in mock generation. + The file contains the following columns: + ID 2MASS XSC ID name (HHMMSSss+DDMMSSs) RAdeg Right ascension (J2000) DEdeg Declination (J2000) @@ -297,6 +336,13 @@ def estimate_data_parameters(): e_Jmag Error of the NIR magnitudes in J band from the (mag) WHIc Corrected HI width (km/s) e_WHIc Error of corrected HI width (km/s) + + Returns: + - sigma_m (float): Estimate of the uncertainty on the apparent magnitude measurements + - sigma_eta (float): Estimate of the uncertainty on the linewidth measurements + - hyper_eta_mu (float): Estimate of the mean of the Gaussian hyper-prior for the true eta values + - hyper_eta_sigma (float): Estimate of the std deviation of the Gaussian hyper-prior for the true eta values + - hyper_m_sigma (float): Estimate of the std deviation of the Gaussian hyper-prior for the true m values """ columns = ['ID', 'RAdeg', 'DEdeg', 'cz2mrs', 'Kmag', 'Hmag', 'Jmag', 'e_Kmag', 'e_Hmah', 'e_Jmag', 'WHIc', 'e_WHIc'] @@ -311,7 +357,9 @@ def estimate_data_parameters(): hyper_eta_mu = np.median(eta) hyper_eta_sigma = (np.percentile(eta, 84) - np.percentile(eta, 16)) / 2 - return sigma_m, sigma_eta, hyper_eta_mu, hyper_eta_sigma + hyper_m_sigma = np.amax(df['Kmag']) - np.percentile(df['Kmag'], 16) + + return sigma_m, sigma_eta, hyper_eta_mu, hyper_eta_sigma, hyper_m_sigma def generateMBData(RA, Dec, cz_obs, L, N, R_lim, Nsig, Nint_points, sigma_v, frac_sigma_r): @@ -363,11 +411,11 @@ def generateMBData(RA, Dec, cz_obs, L, N, R_lim, Nsig, Nint_points, sigma_v, fra return MB_pos -def likelihood(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true, +def likelihood_vel(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true, dens, vel, omega_m, h, L, xmin, interp_order, bias_epsilon, - cz_obs, m_obs, eta_obs, sigma_m, sigma_eta, MB_pos, mthresh): + cz_obs, MB_pos, mthresh): """ - Evaluate the likelihood for TFR sample + Evaluate the terms in the likelihood from the velocity and malmquist bias Args: - alpha (float): Exponent for bias model @@ -386,10 +434,6 @@ def likelihood(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true, - interp_order (int): Order of interpolation from grid points to the line of sight - bias_epsilon (float): Small number to add to 1 + delta to prevent 0^# - cz_obs (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,)) - - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,)) - - eta_obs (np.ndarray): Observed linewidths of the tracers (shape = (Nt,)) - - sigma_m (float): Uncertainty on the apparent magnitude measurements - - sigma_eta (float): Uncertainty on the apparent linewidth measurements - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h). The shape is (3, Nt, Nsig) - mthresh (float): Threshold absolute magnitude in selection @@ -397,7 +441,7 @@ def likelihood(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true, Returns: - loglike (float): The log-likelihood of the data """ - + # Comoving radii of integration points (Mpc/h) r = jnp.sqrt(jnp.sum(MB_pos ** 2, axis=0)) @@ -444,26 +488,102 @@ def likelihood(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true, # Integrate to get likelihood p_cz = jnp.trapezoid(jnp.exp(-0.5 * d2) * p_r / p_r_norm, r, axis=1) lkl_ind = jnp.log(p_cz) - scale / 2 - 0.5 * jnp.log(2 * np.pi * sigma_v**2) - loglike_vel = - lkl_ind.sum() + loglike = lkl_ind.sum() + + return loglike + + +def likelihood_m(m_true, m_obs, sigma_m, mthresh): + """ + Evaluate the terms in the likelihood from apparent magnitude + + Args: + - m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,)) + - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,)) + - sigma_m (float): Uncertainty on the apparent magnitude measurements + - mthresh (float): Threshold absolute magnitude in selection + + Returns: + - loglike (float): The log-likelihood of the data + """ Nt = m_obs.shape[0] - - # Apparent magnitude terms - norm = 0.5 * (1 + jax.scipy.special.erf((mthresh - m_true) / (jnp.sqrt(2) * sigma_m))) - loglike_m = ( + # norm = 2 / (1 + jax.scipy.special.erf((mthresh - m_true) / (jnp.sqrt(2) * sigma_m))) / jnp.sqrt(2 * jnp.pi * sigma_m ** 2) + norm = jnp.sqrt(2 * jnp.pi * sigma_m ** 2) * jnp.ones(Nt) + loglike = - ( 0.5 * jnp.sum((m_obs - m_true) ** 2 / sigma_m ** 2) + jnp.sum(jnp.log(norm)) + Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2) ) + + return loglike + + +def likelihood_eta(eta_true, eta_obs, sigma_eta): + """ + Evaluate the terms in the likelihood from linewidth + + Args: + - eta_true (np.ndarray): True linewidths of the tracers (shape = (Nt,)) + - eta_obs (np.ndarray): Observed linewidths of the tracers (shape = (Nt,)) + - sigma_eta (float): Uncertainty on the linewidth measurements - # Linewidth terms - loglike_eta = ( + Returns: + - loglike (float): The log-likelihood of the data + """ + + Nt = eta_obs.shape[0] + loglike = - ( 0.5 * jnp.sum((eta_obs - eta_true) ** 2 / sigma_eta ** 2) + Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_eta ** 2) ) - # loglike = - (loglike_vel + loglike_m + loglike_eta) - loglike = - (loglike_eta + loglike_m) + return loglike + + +def likelihood(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true, + dens, vel, omega_m, h, L, xmin, interp_order, bias_epsilon, + cz_obs, m_obs, eta_obs, sigma_m, sigma_eta, MB_pos, mthresh): + """ + Evaluate the likelihood for TFR sample + + Args: + - alpha (float): Exponent for bias model + - a_TFR (float): TFR relation intercept + - b_TFR (float): TFR relation slope + - sigma_TFR (float): Intrinsic scatter in the TFR + - sigma_v (float): Uncertainty on the velocity field (km/s) + - m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,)) + - eta_true (np.ndarray): True linewidths of the tracers (shape = (Nt,)) + - dens (np.ndarray): Over-density field (shape = (N, N, N)) + - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N)) + - omega_m (float): Matter density parameter Om + - h (float): Hubble constant H0 = 100 h km/s/Mpc + - L (float): Comoving box size (Mpc/h) + - xmin (float): Coordinate of corner of the box (Mpc/h) + - interp_order (int): Order of interpolation from grid points to the line of sight + - bias_epsilon (float): Small number to add to 1 + delta to prevent 0^# + - cz_obs (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,)) + - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,)) + - eta_obs (np.ndarray): Observed linewidths of the tracers (shape = (Nt,)) + - sigma_m (float): Uncertainty on the apparent magnitude measurements + - sigma_eta (float): Uncertainty on the linewidth measurements + - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h). + The shape is (3, Nt, Nsig) + - mthresh (float): Threshold absolute magnitude in selection + + Returns: + - loglike (float): The log-likelihood of the data + """ + + + loglike_vel = likelihood_vel(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true, + dens, vel, omega_m, h, L, xmin, interp_order, bias_epsilon, + cz_obs, MB_pos, mthresh) + loglike_m = likelihood_m(m_true, m_obs, sigma_m, mthresh) + loglike_eta = likelihood_eta(eta_true, eta_obs, sigma_eta) + + loglike = (loglike_vel + loglike_m + loglike_eta) return loglike @@ -507,7 +627,7 @@ def test_likelihood_scan(prior, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, pars = [alpha, a_TFR, b_TFR, sigma_TFR, sigma_v] par_names = ['alpha', 'a_TFR', 'b_TFR', 'sigma_TFR', 'sigma_v'] - orig_ll = likelihood(*pars, m_true, eta_true, + orig_ll = - likelihood(*pars, m_true, eta_true, dens, vel, omega_m, h, L, xmin, interp_order, bias_epsilon, czCMB, m_obs, eta_obs, sigma_m, sigma_eta, MB_pos, mthresh) @@ -526,7 +646,7 @@ def test_likelihood_scan(prior, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, orig_x = pars[i] for j, xx in enumerate(x): pars[i] = xx - all_ll[j] = likelihood(*pars, m_true, eta_true, + all_ll[j] = - likelihood(*pars, m_true, eta_true, dens, vel, omega_m, h, L, xmin, interp_order, bias_epsilon, czCMB, m_obs, eta_obs, sigma_m, sigma_eta, MB_pos, mthresh) pars[i] = orig_x @@ -546,8 +666,7 @@ def test_likelihood_scan(prior, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, def run_mcmc(num_warmup, num_samples, prior, initial, dens, vel, omega_m, h, L, xmin, interp_order, bias_epsilon, - czCMB, m_obs, eta_obs, sigma_m, sigma_eta, MB_pos, mthresh, - m_true): + czCMB, m_obs, eta_obs, sigma_m, sigma_eta, MB_pos, mthresh,): """ Run MCMC over the model parameters @@ -573,6 +692,9 @@ def run_mcmc(num_warmup, num_samples, prior, initial, dens, vel, omega_m, h, L, The shape is (3, Nt, Nsig) - mthresh (float): Threshold absolute magnitude in selection + Returns: + - mcmc + """ Nt = eta_obs.shape[0] @@ -583,50 +705,46 @@ def run_mcmc(num_warmup, num_samples, prior, initial, dens, vel, omega_m, h, L, a_TFR = numpyro.sample("a_TFR", dist.Uniform(*prior['a_TFR'])) b_TFR = numpyro.sample("b_TFR", dist.Uniform(*prior['b_TFR'])) sigma_TFR = numpyro.sample("sigma_TFR", dist.HalfCauchy(1.0)) - sigma_v = numpyro.sample("sigma_v", dist.HalfCauchy(1.0)) - -# # Sample the means with a uniform prior -# hyper_mean_m = numpyro.sample("hyper_mean_m", dist.Uniform(*prior['hyper_mean_m'])) -# hyper_mean_eta = numpyro.sample("hyper_mean_eta", dist.Uniform(*prior['hyper_mean_eta'])) -# hyper_mean = jnp.array([hyper_mean_m, hyper_mean_eta]) - -# # Sample standard deviations with a 1/sigma prior (Jeffreys prior approximation) -# hyper_sigma_m = numpyro.sample("hyper_sigma_m", dist.HalfCauchy(1.0)) # Equivalent to 1/sigma prior -# hyper_sigma_eta = numpyro.sample("hyper_sigma_eta", dist.HalfCauchy(1.0)) -# hyper_sigma = jnp.array([hyper_sigma_m, hyper_sigma_eta]) - -# # Sample correlation matrix using LKJ prior -# L_corr = numpyro.sample("L_corr", dist.LKJCholesky(2, concentration=1.0)) # Cholesky factor of correlation matrix -# corr_matrix = L_corr @ L_corr.T # Convert to full correlation matrix - -# # Construct full covariance matrix: Σ = D * Corr * D -# hyper_cov = jnp.diag(hyper_sigma) @ corr_matrix @ jnp.diag(hyper_sigma) - -# # Sample the true eta and m -# x = numpyro.sample("x", dist.MultivariateNormal(hyper_mean, hyper_cov), sample_shape=(Nt,)) -# m_true = numpyro.deterministic("m_true", x[:, 0]) -# eta_true = numpyro.deterministic("eta_true", x[:, 1]) + # sigma_v = numpyro.sample("sigma_v", dist.HalfCauchy(1.0)) + sigma_v = numpyro.sample("sigma_v", dist.Uniform(*prior['sigma_v'])) + hyper_mean_m = numpyro.sample("hyper_mean_m", dist.Uniform(*prior['hyper_mean_m'])) + hyper_sigma_m = numpyro.sample("hyper_sigma_m", dist.HalfCauchy(1.0)) # Equivalent to 1/sigma prior hyper_mean_eta = numpyro.sample("hyper_mean_eta", dist.Uniform(*prior['hyper_mean_eta'])) hyper_sigma_eta = numpyro.sample("hyper_sigma_eta", dist.HalfCauchy(1.0)) # Equivalent to 1/sigma prior - eta_true = numpyro.sample("eta_true", dist.Normal(hyper_mean_eta, hyper_sigma_eta), sample_shape=(Nt,)) - + + # Sample correlation matrix using LKJ prior + L_corr = numpyro.sample("L_corr", dist.LKJCholesky(2, concentration=1.0)) # Cholesky factor of correlation matrix + corr_matrix = L_corr @ L_corr.T # Convert to full correlation matrix + + # Construct full covariance matrix: Σ = D * Corr * D + hyper_mean = jnp.array([hyper_mean_m, hyper_mean_eta]) + hyper_sigma = jnp.array([hyper_sigma_m, hyper_sigma_eta]) + hyper_cov = jnp.diag(hyper_sigma) @ corr_matrix @ jnp.diag(hyper_sigma) + + # Sample m_true and eta_true + x = numpyro.sample("true_vars", dist.MultivariateNormal(hyper_mean, hyper_cov), sample_shape=(Nt,)) + m_true = numpyro.deterministic("m_true", x[:, 0]) + eta_true = numpyro.deterministic("eta_true", x[:, 1]) + # Evaluate the likelihood - numpyro.sample("obs", TFRLikelihood(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, eta_true), obs=jnp.array([m_obs, eta_obs])) + numpyro.sample("obs", TFRLikelihood(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, hyper_mean_eta, hyper_sigma_eta, m_true, eta_true), obs=jnp.array([m_obs, eta_obs])) class TFRLikelihood(dist.Distribution): support = dist.constraints.real - def __init__(self, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, eta_true): - self.alpha, self.a_TFR, self.b_TFR, self.sigma_TFR, self.sigma_v, self.eta_true = dist.util.promote_shapes(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, eta_true) + def __init__(self, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, hyper_mean_eta, hyper_sigma_eta, m_true, eta_true): + self.alpha, self.a_TFR, self.b_TFR, self.sigma_TFR, self.sigma_v, self.hyper_mean_eta, self.hyper_sigma_eta, self.m_true, self.eta_true = dist.util.promote_shapes(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, hyper_mean_eta, hyper_sigma_eta, m_true, eta_true) batch_shape = lax.broadcast_shapes( jnp.shape(alpha), jnp.shape(a_TFR), jnp.shape(b_TFR), jnp.shape(sigma_TFR), jnp.shape(sigma_v), - # jnp.shape(m_true), + jnp.shape(hyper_mean_eta), + jnp.shape(hyper_sigma_eta), + jnp.shape(m_true), jnp.shape(eta_true), ) super(TFRLikelihood, self).__init__(batch_shape = batch_shape) @@ -636,7 +754,7 @@ def run_mcmc(num_warmup, num_samples, prior, initial, dens, vel, omega_m, h, L, def log_prob(self, value): loglike = likelihood(self.alpha, self.a_TFR, self.b_TFR, self.sigma_TFR, self.sigma_v, - m_true, self.eta_true, + self.m_true, self.eta_true, dens, vel, omega_m, h, L, xmin, interp_order, bias_epsilon, czCMB, m_obs, eta_obs, sigma_m, sigma_eta, MB_pos, mthresh) return loglike @@ -644,6 +762,8 @@ def run_mcmc(num_warmup, num_samples, prior, initial, dens, vel, omega_m, h, L, rng_key = random.PRNGKey(6) rng_key, rng_key_ = random.split(rng_key) values = initial + values['true_vars'] = jnp.array([m_obs, eta_obs]).T + values['L_corr'] = jnp.identity(2) myprint('Preparing MCMC kernel') kernel = numpyro.infer.NUTS(tfr_model, init_strategy=numpyro.infer.initialization.init_to_value(values=initial) @@ -656,14 +776,28 @@ def run_mcmc(num_warmup, num_samples, prior, initial, dens, vel, omega_m, h, L, return mcmc -def process_mcmc_run(mcmc, param_labels, truths, obs): +def process_mcmc_run(mcmc, param_labels, truths, true_vars): + """ + Make summary plots from the MCMC and save these to file + + Args: + - mcmc + - param_labels (list[str]): Names of the parameters to plot + - truths (list[float]): True values of the parameters to plot. If unknown, then entry is None + - true_vars (dict): True values of the observables to compare against inferred ones + """ # Convert samples into a single array samples = mcmc.get_samples() samps = jnp.empty((len(samples[param_labels[0]]), len(param_labels))) for i, p in enumerate(param_labels): - samps = samps.at[:,i].set(samples[p]) + if p == 'hyper_corr': + L_corr = samples['L_corr'] + corr_matrix = jnp.matmul(L_corr, jnp.transpose(L_corr, (0, 2, 1))) + samps = samps.at[:,i].set(corr_matrix[:,0,1]) + else: + samps = samps.at[:,i].set(samples[p]) # Trace plot of non-redshift quantities fig1, axs1 = plt.subplots(samps.shape[1], 1, figsize=(6,3*samps.shape[1]), sharex=True) @@ -671,13 +805,14 @@ def process_mcmc_run(mcmc, param_labels, truths, obs): for i in range(samps.shape[1]): axs1[i].plot(samps[:,i]) axs1[i].set_ylabel(param_labels[i]) - axs1[i].axhline(truths[i], color='k') + if truths[i] is not None: + axs1[i].axhline(truths[i], color='k') axs1[-1].set_xlabel('Step Number') fig1.tight_layout() fig1.savefig('trace.png') # Corner plot - fig2, axs2 = plt.subplots(samps.shape[1], samps.shape[1], figsize=(10,10)) + fig2, axs2 = plt.subplots(samps.shape[1], samps.shape[1], figsize=(15,15)) corner.corner( np.array(samps), labels=param_labels, @@ -690,7 +825,7 @@ def process_mcmc_run(mcmc, param_labels, truths, obs): for var in ['eta', 'm']: vname = var + '_true' if vname in samples.keys(): - xtrue = obs[var] + xtrue = true_vars[var] xpred_median = np.median(samples[vname], axis=0) xpred_plus = np.percentile(samples[vname], 84, axis=0) - xpred_median xpred_minus = xpred_median - np.percentile(samples[vname], 16, axis=0) @@ -723,7 +858,7 @@ def main(): myprint('Beginning') # Get some parameters from the data - sigma_m, sigma_eta, hyper_eta_mu, hyper_eta_sigma = estimate_data_parameters() + sigma_m, sigma_eta, hyper_eta_mu, hyper_eta_sigma, hyper_m_sigma = estimate_data_parameters() # Other parameters to use L = 500.0 @@ -751,7 +886,8 @@ def main(): 'a_TFR': [-25, -20], 'b_TFR': [-10, -5], 'hyper_mean_eta': [hyper_eta_mu - 0.5, hyper_eta_mu + 0.5], - # 'hyper_mean_m':[mthresh - 5, mthresh + 5] + 'hyper_mean_m':[mthresh - 5, mthresh + 5], + 'sigma_v': [50, 300], } initial = { 'alpha': alpha, @@ -759,7 +895,8 @@ def main(): 'b_TFR': b_TFR, 'hyper_mean_eta': hyper_eta_mu, 'hyper_sigma_eta': hyper_eta_sigma, - # 'hyper_mean_m': mthresh, + 'hyper_mean_m': mthresh, + 'hyper_sigma_m': hyper_m_sigma, 'sigma_TFR': sigma_TFR, 'sigma_v': sigma_v, } @@ -787,15 +924,12 @@ def main(): # Run a MCMC mcmc = run_mcmc(num_warmup, num_samples, prior, initial, dens, vel, cpar.omega_m, cpar.h, L, xmin, interp_order, bias_epsilon, - czCMB, m_obs, eta_obs, sigma_m, sigma_eta, MB_pos, mthresh, - m_true) + czCMB, m_obs, eta_obs, sigma_m, sigma_eta, MB_pos, mthresh,) - param_labels = ['alpha', 'a_TFR', 'b_TFR', 'sigma_TFR', 'sigma_v', 'hyper_mean_eta', 'hyper_sigma_eta'] - truths = [alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, hyper_eta_mu, hyper_eta_sigma] - param_labels = ['hyper_mean_eta', 'hyper_sigma_eta'] - truths = [hyper_eta_mu, hyper_eta_sigma] - obs = {'m':m_obs, 'eta':eta_obs} - process_mcmc_run(mcmc, param_labels, truths, obs) + param_labels = ['alpha', 'a_TFR', 'b_TFR', 'sigma_TFR', 'sigma_v', 'hyper_mean_m', 'hyper_sigma_m', 'hyper_mean_eta', 'hyper_sigma_eta', 'hyper_corr'] + truths = [alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, None, None, hyper_eta_mu, hyper_eta_sigma, None] + true_vars = {'m':m_true, 'eta':eta_true} + process_mcmc_run(mcmc, param_labels, truths, true_vars) if __name__ == "__main__": main() @@ -803,8 +937,7 @@ if __name__ == "__main__": """ TO DO -- Fails to initialise currently when loglike includes the BORG term -- Runs MCMC with this likelihood +- Reinsert magnitude cut - Add bulk velocity - Deal with case where sigma_eta and sigma_m could be floats vs arrays diff --git a/tests/trace.png b/tests/trace.png index 4a29ba6..0b1b7d3 100644 Binary files a/tests/trace.png and b/tests/trace.png differ diff --git a/tests/true_predicted_eta.png b/tests/true_predicted_eta.png index 2bcd5da..a3b7bcf 100644 Binary files a/tests/true_predicted_eta.png and b/tests/true_predicted_eta.png differ diff --git a/tests/true_predicted_m.png b/tests/true_predicted_m.png new file mode 100644 index 0000000..58a1642 Binary files /dev/null and b/tests/true_predicted_m.png differ