diff --git a/notebooks/cornerplot.ipynb b/notebooks/cornerplot.ipynb new file mode 100644 index 0000000..73c5773 --- /dev/null +++ b/notebooks/cornerplot.ipynb @@ -0,0 +1,263 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "cb44b7e3-e1d9-40c1-92c3-8e312ffd6ecc", + "metadata": {}, + "source": [ + "# Pretty cornerplots\n", + "Uses multiple plotting utilities to demo all of the options. Here I will display two posteriors, both trained using the same priors; one trained using the generative option for SBI, one trained using the pre-generated training set." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "f89eab45-407d-42c0-812d-3a0800370ab3", + "metadata": {}, + "outputs": [], + "source": [ + "from scripts import evaluate, io, plot\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib\n", + "# remove top and right axis from plots\n", + "matplotlib.rcParams[\"axes.spines.right\"] = False\n", + "matplotlib.rcParams[\"axes.spines.top\"] = False" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "9f8923fc-3abd-464e-9b19-fc26ed90437a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "../savedmodels/sbi/\n" + ] + } + ], + "source": [ + "# load up the generative model\n", + "modelloader = io.ModelLoader()\n", + "path = \"../savedmodels/sbi/\"\n", + "model_name = \"sbi_linear_generative\"\n", + "posterior_generative = modelloader.load_model_pkl(path, model_name)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ebbf1d55-24d4-487b-af49-da544cb3f668", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "../savedmodels/sbi/\n" + ] + } + ], + "source": [ + "# load up the generative model\n", + "modelloader = io.ModelLoader()\n", + "path = \"../savedmodels/sbi/\"\n", + "model_name = \"sbi_linear_from_data\"\n", + "posterior_static = modelloader.load_model_pkl(path, model_name)" + ] + }, + { + "cell_type": "markdown", + "id": "4f5f4249-7a82-408f-a2bd-165aeeb8d8bc", + "metadata": {}, + "source": [ + "In order to evaluate these, we need a validation set, which we'll load below." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "abaae20b-8028-43c1-aafb-8f4f2dfb4443", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "../saveddata/\n" + ] + } + ], + "source": [ + "dataloader = io.DataLoader()\n", + "path = \"../saveddata/\"\n", + "data_name = \"data_validation\"\n", + "validation = dataloader.load_data_pkl(data_name, path)\n", + "theta_true = validation['thetas'][0]\n", + "y_true = validation['xs'][0]" + ] + }, + { + "cell_type": "markdown", + "id": "f7ed9d9d-2b36-47b2-a098-84cf4a796f99", + "metadata": {}, + "source": [ + "Visualize the validation data." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "025b6d3a-0405-45f7-a885-a65c5e9942a7", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjsAAAGwCAYAAABPSaTdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA1s0lEQVR4nO3de3CV9Z3H8U8SIKCYpIkhCTkHD9tmxQj1hmK0qWGSKVbaDRrsQNGhlI0jJZgE61bXgLCt4tqtCWCFNbMrjIoul1iqo3UwEBqXyM3LoljMrlmJeBKQLDlYlcuTZ/+w55QTEjj3y3Per5nMyDnPefI7z4zy8ff7/r6/JNM0TQEAAFhUcrQHAAAAEE6EHQAAYGmEHQAAYGmEHQAAYGmEHQAAYGmEHQAAYGmEHQAAYGmEHUmmacrlcomWQwAAWA9hR9Lx48eVnp6u48ePR3soAAAgxAg7AADA0gg7AADA0gg7AADA0gg7AADA0gg7AADA0gg7AADA0gg7AADA0gg7AADA0gg7AADA0gg7AADA0gg7AADA0oZEewAAAMCaDMNQa2urnE6n8vLyVFxcrJSUlIiPg7ADAABCrqmpSdXV1frkk088r9lsNi1fvly33XZbRMeSZJqmGdHfGINcLpfS09PV29urtLS0aA8HAIC45J7J2bx5sxoaGs56PykpSZK0cePGiAYewo4IOwAABGugmZyBJCUlyWazqaOjI2JLWhQoAwCAoDQ1NWn69OnnDTqSZJqmOjs71draGoGRfY2aHQAA4Df3ktWhQ4dUW1srfxeKnE5nmEZ2NsIOAADwi69LVueSl5cXwhGdG2EHAAD4zL1kFWjJr7tmp7i4OMQjGxxhBwAAnJdhGGppaVFlZWVQQUeSGhoaItpvhwJlAABwTk1NTXI4HCorK1NPT0/A97HZbBHfdi4xswMAAM4h2GUrSaqpqVF5eTkdlAEAQGwxDEPV1dUBBx273a6GhoaIz+T0F9VlLMMwtGjRIo0dO1YjRozQN7/5Tf3yl7/0eqimaWrx4sXKy8vTiBEjVFZWpvb2dq/79PT0aNasWUpLS1NGRobmzp2rzz//PNJfBwAAS2ltbfV7x1V2draeffZZbdu2TR0dHVEPOlKUZ3b++Z//WatWrdLatWt1+eWXa8+ePZozZ47S09N1zz33SJIee+wxrVixQmvXrtXYsWO1aNEiTZkyRfv379fw4cMlSbNmzZLT6dSWLVt06tQpzZkzR3fddZfWrVsXza8HAEBccvfQ2bRpk8+fcRcfr169OiYCzpmielzED37wA+Xk5Ojf/u3fPK9VVFRoxIgRevbZZ2WapkaPHq17771XP//5zyVJvb29ysnJ0Zo1azRjxgx98MEHKiws1O7duzVx4kRJ0h/+8Afdcsst+uSTTzR69OjzjoPjIgAA+FqgPXRiZclqIFFdxrrhhhvU3NysDz/8UJL07rvv6o033tD3v/99SVJHR4e6urpUVlbm+Ux6eromTZqktrY2SVJbW5syMjI8QUeSysrKlJycrJ07dw74e0+cOCGXy+X1AwBAonJvK6+trVVFRYVfQSczM1Ovv/56zCxZDSSqy1j333+/XC6Xxo0bp5SUFBmGoYcfflizZs2SJHV1dUmScnJyvD6Xk5Pjea+rq0ujRo3yen/IkCHKzMz0XNPfsmXLtHTp0lB/HQAA4k6gMznuZavGxkaVlpaGY2ghE9WZnfXr1+u5557TunXr9NZbb2nt2rX6l3/5F61duzasv/eBBx5Qb2+v56ezszOsvw8AgFjkzwGe/UWrZ04gojqzc9999+n+++/XjBkzJEkTJkzQxx9/rGXLlmn27NnKzc2VJHV3d3udodHd3a0rr7xSkpSbm6vDhw973ff06dPq6enxfL6/1NRUpaamhuEbAQAQ24I9wLOqqkoVFRVR65kTiKjO7HzxxRdKTvYeQkpKivr6+iRJY8eOVW5urpqbmz3vu1wu7dy5U0VFRZKkoqIiHTt2THv37vVcs3XrVvX19WnSpEkR+BYAAMQHdyfkyZMn64477tCRI0f8vkdFRYVKSkriJuhIUZ7Z+eEPf6iHH35YY8aM0eWXX663335bjz/+uH76059K+no9sKamRr/61a9UUFDg2Xo+evRoTZs2TZJ02WWX6eabb1ZlZaVWr16tU6dOqaqqSjNmzPBpJxYAAIkgHg/wDJWobj0/fvy4Fi1apBdffFGHDx/W6NGjNXPmTC1evFjDhg2T9HVTwYceekhPPfWUjh07pu985zt68skn9bd/+7ee+/T09KiqqkovvfSSkpOTVVFRoRUrVmjkyJE+jYOt5wAAKzMMQw6HI6DaHOmvxcjxUqPTX1TDTqwg7AAArKylpUWTJ08O+POx3EPHF5yNBQCAxTmdzoA+F+0DPEOFsAMAgMWduaPZF/E+k9MfYQcAAAszDEOGYSgzM1M9PT2DXpedna36+nrl5+fH/UxOf4QdAAAswt1Dx+l0Ki8vT5999plqa2vPWZgcywd4hgphBwAACwj02AebzWapJauBEHYAAIhzgfTQyczM1Pr16+OuQWAgotpBGQAABMcwDFVXV/vdLLCnp0cpKSmWDzoSYQcAgLjW2toacLPAQLekxxuWsQAAiAP9i4/dO6aCCSz+bkmPV4QdAABi3EDFxzabTY8//ri6u7v9vl88n3MVCI6LEMdFAABiV7AHePYX7+dcBYKaHQAAYpBhGGpublZlZWXIgo709YxQIgUdiWUsAABiTqA9c/qz2+36zW9+o+zs7LNqfRIJYQcAgBgSqmWr+vp6LViwIOGCzUBYxgIAIEYE2jNnIDk5OQSdv2BmBwCAKHNvK29ubg566cotUbaV+4KwAwBAFIWqPsct0baV+4KwAwBAlIRrW3lDQwNLWGegZgcAgCgItD4nMzNTr7/+ujZs2CCbzeb1XiJuK/cFMzsAAESBv2dauWdtGhsbVVpaKkm69dZbBzxCAt4IOwAARIG/Z1rZbDY1NDR4zdqkpKSopKQkxCOzHsIOAABR4Otuqbq6OpWWljJrEwTOxhJnYwEAIs8wDDkcDh06dGjAuh33rqqOjg5CTpAoUAYAIApSUlK0fPlySX+tx3FjV1VoEXYAAIiS2267TRs3blR+fr7X6+yqCi2WscQyFgAgctzdks/cQSWJXVVhRIEyAABhdGa4aW9vV2Njo9eWc5vNpuXLlzOLE0bM7IiZHQBAePhyFIS7Podlq/Ah7IiwAwAIjf6zOEuWLPGpQzI7r8KLZSwAAEIgmAM9TdNUZ2enWltbaRIYBoQdAAAC5J7J2bx5sxoaGoK+n79dleEbwg4AAAEIZiZnML52VYZ/CDsAAPipqalJ06dP9/vE8sG4a3bc29ARWjQVBADAD4ZhqLq6OqRBR6JbcjgRdgAA8ENra2tIl67olhx+LGMBAOCHYIuIbTabKisrVVBQQLfkCCHsAADgB3+KiJOSkmSappYuXUq4iSLCDgAAA+h/htUNN9ygHTt26NChQ8rOztZnn3123rodm82mhoYGlqiijLADAEA/A20rT0lJkWEYPn2+pqZG5eXlzOLECMIOAABnGGxbuS9Bx263M5MTgwg7AAD8RSDbyrOzs1VfX6/8/HxmcmIUYQcAkPDc9TnNzc1+bys/cuSI8vPzOdMqhhF2AAAJLRTHPnCmVWwj7AAAElaojn3gTKvYRtgBACQcwzDU0tKiysrKoIIOZ1rFB46LAAAklKamJjkcDpWVlamnpyfg+3CmVfwg7AAAEoZ72SqQ+pz+gYYzreIHy1gAgIQQ6GnldXV1Ki0t9XRQdndUZpt5/EgyQ3VGfRxzuVxKT09Xb2+v0tLSoj0cAEAYtLS0aPLkyT5f767H6ejoINTEOWZ2AACW5u6hs2nTJp8/Qz2OtRB2AACWFWgPHQ7wtBbCDgDAkgLpoZOZman169erpKSEGR0LIewAACzH32Jk97JVY2OjSktLwzk0RAFbzwEAlmIYhlauXOnX0hXbyK2NmR0AgGX4W6NTVVWliooKtpFbHGEHAGAJgdToVFRUcFp5AqDPjuizAwDxzjAMORwOn2d06KGTWJjZAQDELXcPnebmZr+CjkQPnURC2AEAxCV66MBXhB0AQNxwz+Rs3rxZDQ0Nfn++vr5eCxYsYEYnwVCzI2p2ACAeBDqTI1Gjk+iY2QEAxCT3LI7T6VR7e7uWLFni94nlEjU6IOwAAGJE/3DT2NgY0CxOf9TogLADAIi6YJaoBlNXV6fS0lIaBoKaHYmaHQCIpkCaAZ4L9Tnoj7OxAABR4++BnedDfQ4GQtgBAERNa2trSJeuONATA6FmBwAQce5i5E2bNoXkfjU1NSovL6c+BwMi7AAAIiqUxch2u52dVjgvwg4AIGKCKUZOSkqSaZpaunSpCgoKlJeXx0wOfBL1mp1Dhw7pjjvuUFZWlkaMGKEJEyZoz549nvdN09TixYuVl5enESNGqKysTO3t7V736Onp0axZs5SWlqaMjAzNnTtXn3/+eaS/CgBgEIZhqLm5WZWVlQEXI9tsNm3atEmLFy/WzJkzVVJSQtCBT6I6s/N///d/uvHGGzV58mS9+uqrys7OVnt7u77xjW94rnnssce0YsUKrV27VmPHjtWiRYs0ZcoU7d+/X8OHD5ckzZo1S06nU1u2bNGpU6c0Z84c3XXXXVq3bl20vhoA4C+CObCzsrKSWRwELap9du6//37953/+p1pbWwd83zRNjR49Wvfee69+/vOfS5J6e3uVk5OjNWvWaMaMGfrggw9UWFio3bt3a+LEiZKkP/zhD7rlllv0ySefaPTo0ecdB312ACA8Alm2qqqqUkVFBeEGIRPVZazf//73mjhxom6//XaNGjVKV111lRobGz3vd3R0qKurS2VlZZ7X0tPTNWnSJLW1tUmS2tralJGR4Qk6klRWVqbk5GTt3LlzwN974sQJuVwurx8AQGgF2kOnoqKCJSqEVFTDzkcffaRVq1apoKBAr732mubNm6d77rlHa9eulSR1dXVJknJycrw+l5OT43mvq6tLo0aN8np/yJAhyszM9FzT37Jly5Senu75sdvtof5qAJDw/O2hk5SUJLvdruLi4jCOCokoqmGnr69PV199tR555BFdddVVuuuuu1RZWanVq1eH9fc+8MAD6u3t9fx0dnaG9fcBQCIxDEMtLS1+9dCh8zHCKaphJy8vT4WFhV6vXXbZZTp48KAkKTc3V5LU3d3tdU13d7fnvdzcXB0+fNjr/dOnT6unp8dzTX+pqalKS0vz+gEABK+pqUkOh0OTJ0/WE0884fPn6HyMcIpq2Lnxxht14MABr9c+/PBDXXLJJZKksWPHKjc3V83NzZ73XS6Xdu7cqaKiIklSUVGRjh07pr1793qu2bp1q/r6+jRp0qQIfAsAgPTXYmR/lq4yMzP1+uuvq6Ojg6CDsInq1vPa2lrdcMMNeuSRR/SjH/1Iu3bt0lNPPaWnnnpK0tfTmjU1NfrVr36lgoICz9bz0aNHa9q0aZK+ngm6+eabPctfp06dUlVVlWbMmOHTTiwAQHDcy1b+9NBxL1s1NjaqtLQ0nMMDJDPKXnrpJXP8+PFmamqqOW7cOPOpp57yer+vr89ctGiRmZOTY6amppqlpaXmgQMHvK45evSoOXPmTHPkyJFmWlqaOWfOHPP48eM+j6G3t9eUZPb29obkOwFAoti0aZNps9lMSX792O12c9OmTdEePhJEVPvsxAr67ACA/+ihg3hB2BFhBwD8ZRiGHA6H312Rt23bppKSkvAMChgEB4ECAPwWSA8dm81GDx1EBWEHADAowzDU2toqp9PpdT6V0+n0+R700EG0EXYAAAMa6ADP/Px83XXXXTp16pTP97HZbGpoaGBrOaKGmh1RswMA/QVSfNxfZmam1q9fzzlXiDpmdgAAXgI9wNONHjqINVHtoAwAiD3+Fh/3x9EPiDXM7AAAvPhTfHymuro6lZaW0kMHMYewAwCQ9NedV/v37w/o84WFhfTQQUwi7ABAgjpzW3l7e7saGxuDWr7Ky8sL4eiA0CHsAEACGmhbeaBoGIhYR9gBgAQTim3lbjQMRDxgNxYAJJBAt5XX1dVp6dKlstlsXq+z8wrxgJkdAEgggW4rLyws1MyZM/Xggw8OeHwEEMsIOwCQANzFyJs2bQro8+7i45SUFHZcIe4QdgDAgkK104riY1gBYQcALCZUO60oPoZVUKAMABbi3mkVii3lFB/DKpjZAQCLCPYAT5vNpsrKShUUFFB8DEsh7ACARQS606qqqkoVFRWEG1gWYQcA4lywO60qKirYYQVLI+wAQBwLphiZnVZIFIQdAIhTwRz7wE4rJBJ2YwFAHApFMTI7rZAomNkBgDhjGIZWrlzp19IVO62QyAg7ABBH/K3RYacVQNgBgLgRSI0OO60AKckMdMHXQlwul9LT09Xb26u0tLRoDwcAzmIYhhwOh88zOu6dVh0dHczoIOFRoAwAccCfhoHstAK8sYwFAFFy5snkgxUNB9Iw0GazqaGhgZ1WwF8QdgAgCgYqNLbZbHr88ceVnZ0tp9Op9vZ2NTY2+rXrqr6+XgsWLGBGBzgDNTuiZgdAZAXTDHAw1OgAg6NmBwAiKNhmgAOhRgc4N8IOAERQoCeTnwvdkIFzo2YHACLI6XSG7F40DAR8Q9gBgAgxDEPd3d0hux8NAwHfEHYAIAL8PebhXNzFyMXFxSEYGWB9hB0ACLNQ7r6iGBnwHwXKABBGod59RTEy4D9mdgAgjHzdfVVfXy+bzaba2tqzGg1WVlaqoKBg0C7LAM6NsAMAYeDvMQ85OTmaPn26br311vMeIQHAP4QdAAixQIqR8/LyJEkpKSnssAJCjLADACHkbzEyO6uA8CPsAEAIGIahlpYWVVZW+hV0JHZWAeHGbiwACFJTU5McDofKysrU09Pj8+fYWQVEBjM7ABCEQHrocMwDEFlJZiiP3o1TLpdL6enp6u3tVVpaWrSHAyBOGIYhh8Phd1fkbdu2UYQMRBAzOwAQIH9PMKcYGYgOwg4A+MnfHjoSxchANBF2AMAPgR7oabPZ1NDQQDEyEAWEHQDwUSDFyJmZmVq/fr1KSkqY0QGihLADAOcRTA+dxsZGlZaWhnN4AM6DPjsAcA700AHiHzM7AHAGd/Gx0+lUe3u7lixZQg8dIM4RdgDgLwItPj5TRUUFPXSAGEPYAQAFVnx8JnroALGLmh0ACc8wDFVXVwcVdCR66ACxirADIOH52wm5P4qRgdjGMhaAhHFm8XFeXp5nyam5uTmg+9FDB4gPhB0ACWGg4uOsrCxJ0tGjR/26Fz10gPhC2AFgeYMVH/sbctw4+gGIL4QdAJYWiuJj0zS1dOlSFRQUeJa/WLYC4gdhB4ClhaL4mFkcIL6xGwuApTmdzoA/W1dXp46ODoIOEOcIOwAsLS8vL+DPlpaWslwFWABhB4BlGYYhwzCUmZnp1+eSkpJkt9vphgxYBDU7ACyhfw+dzz77TLW1tX7X69ANGbAev8PO7NmzNXfuXH33u98Nx3gAwG+BHuA5UJ8dCpIB6/E77PT29qqsrEyXXHKJ5syZo9mzZys/Pz8cYwOA8wrkAM8zOx9LOqurMjM6gLUkmQE0nzhy5IieeeYZrV27Vvv371dZWZnmzp2r8vJyDR06NBzjDCuXy6X09HT19vYqLS0t2sMB4CPDMORwOALaWr5t2zZP2AFgbQEVKGdnZ2vhwoV69913tXPnTn3rW9/SnXfeqdGjR6u2tlbt7e1+3/PRRx9VUlKSampqPK999dVXmj9/vrKysjRy5EhVVFSou7vb63MHDx7U1KlTdcEFF2jUqFG67777dPr06UC+FoA4E0wPnWC2pAOIL0HtxnI6ndqyZYu2bNmilJQU3XLLLdq3b58KCwtVX1/v8312796tf/3Xf9W3v/1tr9dra2v10ksvacOGDdq+fbs+/fRTr3V0wzA0depUnTx5Ujt27NDatWu1Zs0aLV68OJivBSBOBBNYgtmSDiDOmH46efKkuXHjRnPq1Knm0KFDzWuuucZctWqV2dvb67mmqanJzMjI8Ol+x48fNwsKCswtW7aYN910k1ldXW2apmkeO3bMHDp0qLlhwwbPtR988IEpyWxrazNN0zRfeeUVMzk52ezq6vJcs2rVKjMtLc08ceKEz9+pt7fXlOT1HQDEttOnT5v19fWmJL9+kpKSTLvdbp4+fTraXwFAhPg9s5OXl6fKykpdcskl2rVrl/bs2aO7777bq9Zl8uTJysjI8Ol+8+fP19SpU1VWVub1+t69e3Xq1Cmv18eNG6cxY8aora1NktTW1qYJEyYoJyfHc82UKVPkcrn0/vvvD/o7T5w4IZfL5fUDIH40NTXJ4XCotrbWr8+xrRxITH7vxqqvr9ftt9+u4cOHD3pNRkaGOjo6znuvF154QW+99ZZ279591ntdXV0aNmzYWaEpJydHXV1dnmvODDru993vDWbZsmVaunTpeccHIPYEsvvKjW3lQGLyO+zceeedIfnFnZ2dqq6u1pYtW84ZnMLhgQce0MKFCz1/drlcstvtER0DAP/5c4K53W7Xb37zG2VnZ7OtHEhwUeugvHfvXh0+fFhXX3215zXDMPTHP/5RTzzxhF577TWdPHlSx44d85rd6e7uVm5uriQpNzdXu3bt8rqve7eW+5qBpKamKjU1NYTfBkAk+Lr7qr6+XgsWLCDYAJAUxbOxSktLtW/fPr3zzjuen4kTJ2rWrFmefx46dKiam5s9nzlw4IAOHjyooqIiSVJRUZH27dunw4cPe67ZsmWL0tLSVFhYGPHvBCA8DMNQS0uLNm3a5NP1OTk5BB0AHlGb2bnooos0fvx4r9cuvPBCZWVleV6fO3euFi5cqMzMTKWlpWnBggUqKirS9ddfL0n63ve+p8LCQt1555167LHH1NXVpbq6Os2fP5+ZG8AiAjkKgm3lAM4U0weB1tfXKzk5WRUVFTpx4oSmTJmiJ5980vN+SkqKXn75Zc2bN09FRUW68MILNXv2bP3TP/1TFEcNIFjuQz03b96shoYGnz+XlJQkm83GaeUAvAR0XITVcFwEEDsCPdTTva1848aN7LYC4CVqNTsA0J97W3kgR0DYbDaCDoABxfQyFgDrcy9ZHTp0SLW1tX73z6mqqlJFRQXbygEMirADIGoCXbI6U0VFBaeXAzgnwg6AqAimE7JEMTIA31GzAyDi/OmEPBDOuALgD8IOgIgyDEMrV64MaumKYmQA/mAZC0DEBFujU1NTo/LycoqRAfiFsAMgIoKp0bHb7ZxWDiBghB0AYRdIjU52drbq6+uVn5/PTA6AoBB2AISNu4dOc3Ozz0tX7uLj1atXM5MDICQIOwBCxh1unE6n2tvb1djY6Hd9js1mY8kKQEgRdgCERCgaBNbX12vBggUsWQEIKQ4CFQeBAsEKVYPAjo4Ogg6AkGNmB0BAgj3Tyo0GgQDCjbADwG+hWLJyo0YHQLgRdgD4JdglK7e6ujqVlpayrRxA2FGzI2p2AF8ZhiGHwxHUjA71OQAijbOxAPistbU16KAjUZ8DILIIOwB8YhiGmpubg7oHB3gCiAZqdgCcV6AFyTabTZWVlSooKFBeXh71OQCigrAD4Jz8LUjmTCsAsYawA+AsgfTQ4UwrALGKsAPASzBLVvTLARCLCDsAPALtoVNXV6clS5awZAUgJrEbC4Ckr5euqqurA2oWWFpaStABELOY2QESjLsex+l0enZISdLKlSv9XrpyNwh03wMAYhFhB0ggA9XjZGVlSZKOHj3q171oEAggXhB2gAQxWD2OvyHHjYJkAPGCsAMkgGDqcc5EDx0A8YiwA1icYRgB1eOciR46AOIZu7EAC2tqapLD4VBtbW1Q9+FMKwDxjJkdwKIC7ZnTX319vRYsWMCSFYC4RdgBLCSQYx4G495WTtABEO8IO4BFBHrMw0DYVg7ASqjZASzAvWQVSNDJysry9Npxo0YHgJUwswPEiYE6H6ekpAS1rdxdjyNpwHsDgBUkmcFWL1qAy+VSenq6ent7lZaWFu3hAGcZaInKZrNp+fLlyszM1OTJk/26n7sep6Ojg1ADwPKY2QFi3GC7qg4dOqSKigpNnz7dr/tRjwMg0TCzI2Z2ELsMw5DD4QhJ0bGb3W7nmAcACYWZHSCGtba2hiTocMwDgERG2AFimNPpDOrzHPMAAGw9B2JaXl5eUJ9nCzkAEHaAmFZcXCybzeaZofFHXV2dOjo6CDoAEh5hB4hhKSkpWr58uST5HXhKS0upzQEAEXaAmHfbbbdp48aNys/P9+n6pKQk2e12FRcXh3lkABAfKFAGYlD/bsnl5eUqLy/3vNbe3q4lS5ZIklf/HXroAMDZCDtAjDlXt+Qz62/Gjx8/4HX00AEAbzQVFE0FEX3umZzNmzeroaHhrPfdMzb9d1YNdl4WAOCvCDsi7CC6BprJGQjnWQFAYFjGAiLszNkYd+2NL//PYZqmOjs71draqpKSkvAPFAAsgrADRJCvszjnEmxXZQBINIQdIEIGO73cX8F2VQaAREPYAcLMMAy1tLSosrIyqKDjrtmhfw4A+IemgkAYNTU1yeFwqKysTD09PQHfh/45ABA4ZnaAEAq0+Ph86J8DAIEj7AAhEori4/5qampUXl5O/xwACAJhBwiBUBUfu9ntdmZyACBECDtAkAzDUHV1dcBBJykpSaZpaunSpSooKKATMgCEGGEHCFJra2tQS1fU4wBAeBF2gCAF2uQvMzNT69evV0lJCbM4ABBGhB0gSP42+XNvI29sbFRpaWk4hgQAOAN9doAgFRcXy2azeULM+dhstrNOLwcAhA8zO0AQ3H11pk+froaGBk+xsRvFxwAQfYQdIEAD9dVJTk6WYRieP1N8DADRR9gBAjBYXx130KEZIADEjiQzVF3Q4pjL5VJ6erp6e3uVlpYW7eEgxhmGIYfDMeh2c/eBnR0dHQQdAIgBFCgDPnKfXr5kyZJz9tUxTVOdnZ1qbW2N4OgAAINhGQvwQSDnXgXafwcAEFqEHeA8Aj33yt/+OwCA8KBmR9TsYHDnq88ZCDU7ABBbqNkBBmEYhlauXOl30JGkhoYGgg4AxIiohp1ly5bp2muv1UUXXaRRo0Zp2rRpOnDggNc1X331lebPn6+srCyNHDlSFRUV6u7u9rrm4MGDmjp1qi644AKNGjVK9913n06fPh3Jr4I44y42fv7559XS0uLVG0f6eunK4XCotrbWr/vSHRkAYk9Uw8727ds1f/58vfnmm9qyZYtOnTql733ve/rzn//suaa2tlYvvfSSNmzYoO3bt+vTTz/1+ovEMAxNnTpVJ0+e1I4dO7R27VqtWbNGixcvjsZXQhxwB5nJkyfrxz/+sSZPniyHw6GmpibP+9OnT/drRqeurk7btm1TR0cHQQcAYkxM1ewcOXJEo0aN0vbt2/Xd735Xvb29ys7O1rp16zR9+nRJ0p/+9Cdddtllamtr0/XXX69XX31VP/jBD/Tpp58qJydHkrR69Wr94he/0JEjRzRs2LCzfs+JEyd04sQJz59dLpfsdjs1OwngfMXG99xzj55//nkdOXLEp/tRnwMAsS+manZ6e3slSZmZmZKkvXv36tSpUyorK/NcM27cOI0ZM0ZtbW2SpLa2Nk2YMMETdCRpypQpcrlcev/99wf8PcuWLVN6errnx263h+srIYYYhqHq6upz7qpasWKFX0FHoj4HAGJdzISdvr4+1dTU6MYbb9T48eMlSV1dXRo2bJgyMjK8rs3JyVFXV5fnmjODjvt993sDeeCBB9Tb2+v56ezsDPG3QSxqbW31a2nqfKjPAYD4EDN9dubPn6/33ntPb7zxRth/V2pqqlJTU8P+exBbQtnkr76+XgsWLGBGBwDiQEzM7FRVVenll1/Wtm3bZLPZPK/n5ubq5MmTOnbsmNf13d3dys3N9VzTf3eW+8/uawApNE3+kpKSZLfbCToAEEeiGnZM01RVVZVefPFFbd26VWPHjvV6/5prrtHQoUPV3Nzsee3AgQM6ePCgioqKJElFRUXat2+fDh8+7Llmy5YtSktLU2FhYWS+CGKae5v5oUOHlJ2d7am18Rc1OgAQn6K6G+tnP/uZ1q1bp82bN+vSSy/1vJ6enq4RI0ZIkubNm6dXXnlFa9asUVpamhYsWCBJ2rFjh6Sv/yK78sorNXr0aD322GPq6urSnXfeqb//+7/XI4884tM46KBsXYGcaTUYu92uhoYGanQAIM5ENewM9n/YTz/9tH7yk59I+rqp4L333qvnn39eJ06c0JQpU/Tkk096LVF9/PHHmjdvnlpaWnThhRdq9uzZevTRRzVkiG8lSYQdawr0TKszZWdnq76+Xvn5+SouLmZGBwDiUEz12YkWwo61uJetfvSjH6mnp2fQ67Kzs/XjH/9Yy5cvV1JSklcocgdxdlsBQPyLiQJlIFTc3ZHLysrOGXSkr5tYTps2TZs2bVJ+fr7Xe2wrBwDriJmt50AgDMNQa2urnE6n2tvbtWTJEr+WrZxOp2bOnKny8nLPffLy8liyAgALIewgboWi+Ni9HT0lJUUlJSUhGhkAIJYQdhCXgi0+dp9pVVxcHOKRAQBiDTU7iDu+nHF1LvTLAYDEQthBXDEMQytXrgxq6YriYwBILCxjIW4EW6OTmZmp9evXq6SkhBkdAEgghB3ErGB3Wrm5l60aGxtVWloa6mECAGIcYQcxKZTHPNhsNo55AIAERthBzAlmp5W7E/LSpUtVUFBAzxwAAGEHsSXYnVbM4gAA+iPsIKa0trYGvHRVX1+vBQsWMIsDAPBC2EFMcTqdfn/G3SCQoAMAGAh9dhBT3Mc3+IoGgQCA8yHsIKYUFxfLZrN5Qsz50CAQAHA+LGMhpqSkpGj58uWaPn26Z2eVGzutAACBSDID3fZiIS6XS+np6ert7VVaWlq0h5Ow+jcRbGxs9CpWttvt7LQCAPiNmR3EhIGaCObn5zOLAwAIGjM7YmYnGnw5CsJdt0NNDgAgGIQdEXYizZ+jINzbyjs6OpjVAQAEhN1YiAjDMNTS0qLa2lpVVFT43DjQNE11dnaqtbU1zCMEAFgVNTsIu1Ac6hlIs0EAACTCDsIsmEM9z+Rvs0EAANyo2RE1O+FiGIYcDkdQMzrU7AAAgsXMDkLOvdOqubk56KAjcRQEACA4hB2EVCjqc9xsNhtNBAEAQSPsIGjumZzNmzeroaEhoHtwFAQAIFwIOwhKqGZymMUBAIQLYQcBC8VOq5qaGpWXlzOLAwAIG8IO/OJesjp06JBqa2sDDjoc6gkAiBTCDs7pfCeR+6uurk6lpaXM5AAAIoY+O6LPzmBCubOKfjkAgGhhZgcDClXnY4l+OQCA6OIgUJzFMAxVV1eHJOhIX++02rhxI/U5AICoYGYHZ2ltbQ3J0hU7rQAAsYCwAw93MfKmTZuCug87rQAAsYSwk8BCudMqOztb9fX1ys/PZyYHABBTCDsJKlQ7rdzFx6tXr2YmBwAQkyhQTkDunVahOqyT4mMAQCxjZifBBLvTymazqbKyksM6AQBxg7CTYALdaVVVVaWKigrCDQAg7hB2EkSwO60qKipUUlIS2kEBABABhJ0EEEwxsvuYh+Li4jCMDACA8CPsWFwwxz5wzAMAwArYjWVRhmGoublZlZWVQRUjs9MKABDvmNmxoECXrdhpBQCwIsKOxQSybMVOKwCAlSWZoTraOo65XC6lp6ert7dXaWlp0R5OwAzDkMPh8HtGZ9u2bey0AgBYFjM7FuJvDx12WgEAEgEFyhbidDp9vpadVgCAREHYsQDDMNTS0qL9+/f7/Bl2WgEAEgXLWHHO351XmZmZWr9+vUpKSpjRAQAkBMJOHPNn55V72aqxsVGlpaXhHhoAADGDZaw4FEjDQJatAACJipmdOOPvslVdXZ1KS0vpoQMASFiEnTgSSMPAwsJCeugAABIay1hxwjAMVVdX+33OVV5eXphGBABAfGBmJ07QMBAAgMAQdmKcYRhqbW3Vpk2bfP4MDQMBAPgrwk4MC+b08oaGBnZeAQAgwk7Mcc/kbN68WQ0NDX59loaBAACcjbATQwKdyaFhIAAAg2M3Voxwbyv3N+hINAwEAOBcmNmJAYFuK6+qqlJFRQUNAwEAOAfCTgzwd1u5W0VFBQ0DAQA4D8JODHA6nX5dTw8dAAB8R9iJIvfOq/379/v8GXroAADgH8JOlNBDBwCAyCDsREEgB3rW1NSovLycYmQAAPxE2AkT9xKV0+lUXl6ep76mpaVFlZWVPgcdu93OTA4AAEGwTNj57W9/q1//+tfq6urSFVdcoZUrV+q6666LylgGWqLKysqSJB09etSne9TV1am0tJSZHAAAgmSJpoL/8R//oYULF+qhhx7SW2+9pSuuuEJTpkzR4cOHIz6WwZoDHj161OegI0mFhYUc+wAAQAgkmf52sotBkyZN0rXXXqsnnnhCktTX1ye73a4FCxbo/vvvP+v6EydO6MSJE54/u1wu2e129fb2Ki0tLeBxGIYhh8MRUM+c/rZt20YPHQAAQiDuZ3ZOnjypvXv3qqyszPNacnKyysrK1NbWNuBnli1bpvT0dM+P3W4PyVgCbQ54pqSkJNntdnroAAAQInEfdj777DMZhqGcnByv13NyctTV1TXgZx544AH19vZ6fjo7O0MyFn+bA/ZHDx0AAELPMgXK/khNTVVqamrI75uXlxfU5+mhAwBA6MV92Ln44ouVkpKi7u5ur9e7u7uVm5sb0bEUFxfLZrPp0KFDfvXQyczM1Pr16ylIBgAgDOJ+GWvYsGG65ppr1Nzc7Hmtr69Pzc3NKioqiuhYUlJStHz5ckl/XZI6l6SkJCUlJamxsVGlpaUEHQAAwiDuw44kLVy4UI2NjVq7dq0++OADzZs3T3/+8581Z86ciI/ltttu08aNG5Wfn+/1elZWlqfXjpvNZtPGjRtZtgIAIIwssfVckp544glPU8Err7xSK1as0KRJk3z6rMvlUnp6etBbz880WAfl/q8xmwMAQHhZJuwEIxxhBwAAxAZLLGMBAAAMhrADAAAsjbADAAAsjbADAAAsjbADAAAsjbADAAAsjbADAAAsjbADAAAsjbADAAAsLe5PPQ8FdxNpl8sV5ZEAAAB/XXTRRec8gJuwI+n48eOSJLvdHuWRAAAAf53vuCfOxpLU19enTz/99LzJ0F8ul0t2u12dnZ2cuRVGPOfI4VlHBs85MnjOkRGJ58zMjg+Sk5Nls9nCdv+0tDT+RYoAnnPk8Kwjg+ccGTznyIjmc6ZAGQAAWBphBwAAWBphJ4xSU1P10EMPKTU1NdpDsTSec+TwrCOD5xwZPOfIiIXnTIEyAACwNGZ2AACApRF2AACApRF2AACApRF2AACApRF2wui3v/2tHA6Hhg8frkmTJmnXrl3RHlJcW7Zsma699lpddNFFGjVqlKZNm6YDBw54XfPVV19p/vz5ysrK0siRI1VRUaHu7u4ojTj+Pfroo0pKSlJNTY3nNZ5x6Bw6dEh33HGHsrKyNGLECE2YMEF79uzxvG+aphYvXqy8vDyNGDFCZWVlam9vj+KI449hGFq0aJHGjh2rESNG6Jvf/KZ++ctf6sy9OTxn//3xj3/UD3/4Q40ePVpJSUn63e9+5/W+L8+0p6dHs2bNUlpamjIyMjR37lx9/vnn4RmwibB44YUXzGHDhpn//u//br7//vtmZWWlmZGRYXZ3d0d7aHFrypQp5tNPP22+99575jvvvGPecsst5pgxY8zPP//cc83dd99t2u12s7m52dyzZ495/fXXmzfccEMURx2/du3aZTocDvPb3/62WV1d7XmdZxwaPT095iWXXGL+5Cc/MXfu3Gl+9NFH5muvvWb+93//t+eaRx991ExPTzd/97vfme+++675d3/3d+bYsWPNL7/8Moojjy8PP/ywmZWVZb788stmR0eHuWHDBnPkyJHm8uXLPdfwnP33yiuvmA8++KDZ1NRkSjJffPFFr/d9eaY333yzecUVV5hvvvmm2draan7rW98yZ86cGZbxEnbC5LrrrjPnz5/v+bNhGObo0aPNZcuWRXFU1nL48GFTkrl9+3bTNE3z2LFj5tChQ80NGzZ4rvnggw9MSWZbW1u0hhmXjh8/bhYUFJhbtmwxb7rpJk/Y4RmHzi9+8QvzO9/5zqDv9/X1mbm5ueavf/1rz2vHjh0zU1NTzeeffz4SQ7SEqVOnmj/96U+9XrvtttvMWbNmmabJcw6F/mHHl2e6f/9+U5K5e/duzzWvvvqqmZSUZB46dCjkY2QZKwxOnjypvXv3qqyszPNacnKyysrK1NbWFsWRWUtvb68kKTMzU5K0d+9enTp1yuu5jxs3TmPGjOG5+2n+/PmaOnWq17OUeMah9Pvf/14TJ07U7bffrlGjRumqq65SY2Oj5/2Ojg51dXV5Pev09HRNmjSJZ+2HG264Qc3Nzfrwww8lSe+++67eeOMNff/735fEcw4HX55pW1ubMjIyNHHiRM81ZWVlSk5O1s6dO0M+Jg4CDYPPPvtMhmEoJyfH6/WcnBz96U9/itKorKWvr081NTW68cYbNX78eElSV1eXhg0bpoyMDK9rc3Jy1NXVFYVRxqcXXnhBb731lnbv3n3Wezzj0Pnoo4+0atUqLVy4UP/4j/+o3bt365577tGwYcM0e/Zsz/Mc6L8jPGvf3X///XK5XBo3bpxSUlJkGIYefvhhzZo1S5J4zmHgyzPt6urSqFGjvN4fMmSIMjMzw/LcCTuIS/Pnz9d7772nN954I9pDsZTOzk5VV1dry5YtGj58eLSHY2l9fX2aOHGiHnnkEUnSVVddpffee0+rV6/W7Nmzozw661i/fr2ee+45rVu3Tpdffrneeecd1dTUaPTo0TznBMIyVhhcfPHFSklJOWuHSnd3t3Jzc6M0KuuoqqrSyy+/rG3btslms3lez83N1cmTJ3Xs2DGv63nuvtu7d68OHz6sq6++WkOGDNGQIUO0fft2rVixQkOGDFFOTg7POETy8vJUWFjo9dpll12mgwcPSpLnefLfkeDcd999uv/++zVjxgxNmDBBd955p2pra7Vs2TJJPOdw8OWZ5ubm6vDhw17vnz59Wj09PWF57oSdMBg2bJiuueYaNTc3e17r6+tTc3OzioqKojiy+GaapqqqqvTiiy9q69atGjt2rNf711xzjYYOHer13A8cOKCDBw/y3H1UWlqqffv26Z133vH8TJw4UbNmzfL8M884NG688cazWid8+OGHuuSSSyRJY8eOVW5urtezdrlc2rlzJ8/aD1988YWSk73/qktJSVFfX58knnM4+PJMi4qKdOzYMe3du9dzzdatW9XX16dJkyaFflAhL3mGaZpfbz1PTU0116xZY+7fv9+86667zIyMDLOrqyvaQ4tb8+bNM9PT082WlhbT6XR6fr744gvPNXfffbc5ZswYc+vWreaePXvMoqIis6ioKIqjjn9n7sYyTZ5xqOzatcscMmSI+fDDD5vt7e3mc889Z15wwQXms88+67nm0UcfNTMyMszNmzeb//Vf/2WWl5ezJdpPs2fPNvPz8z1bz5uamsyLL77Y/Id/+AfPNTxn/x0/ftx8++23zbffftuUZD7++OPm22+/bX788cemafr2TG+++WbzqquuMnfu3Gm+8cYbZkFBAVvP49HKlSvNMWPGmMOGDTOvu+46880334z2kOKapAF/nn76ac81X375pfmzn/3M/MY3vmFecMEF5q233mo6nc7oDdoC+ocdnnHovPTSS+b48ePN1NRUc9y4ceZTTz3l9X5fX5+5aNEiMycnx0xNTTVLS0vNAwcORGm08cnlcpnV1dXmmDFjzOHDh5t/8zd/Yz744IPmiRMnPNfwnP23bdu2Af97PHv2bNM0fXumR48eNWfOnGmOHDnSTEtLM+fMmWMeP348LONNMs0z2kgCAABYDDU7AADA0gg7AADA0gg7AADA0gg7AADA0gg7AADA0gg7AADA0gg7AADA0gg7AADA0gg7AADA0gg7AADA0gg7AADA0gg7ACznyJEjys3N1SOPPOJ5bceOHRo2bJiam5ujODIA0cBBoAAs6ZVXXtG0adO0Y8cOXXrppbryyitVXl6uxx9/PNpDAxBhhB0AljV//ny9/vrrmjhxovbt26fdu3crNTU12sMCEGGEHQCW9eWXX2r8+PHq7OzU3r17NWHChGgPCUAUULMDwLL+53/+R59++qn6+vr0v//7v9EeDoAoYWYHgCWdPHlS1113na688kpdeumlamho0L59+zRq1KhoDw1AhBF2AFjSfffdp40bN+rdd9/VyJEjddNNNyk9PV0vv/xytIcGIMJYxgJgOS0tLWpoaNAzzzyjtLQ0JScn65lnnlFra6tWrVoV7eEBiDBmdgAAgKUxswMAACyNsAMAACyNsAMAACyNsAMAACyNsAMAACyNsAMAACyNsAMAACyNsAMAACyNsAMAACyNsAMAACyNsAMAACzt/wGgmYM4MbtW9gAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "x = np.linspace(0, 100, 101)\n", + "plt.clf()\n", + "plt.scatter(x, y_true, color = 'black')\n", + "plt.xlabel('x')\n", + "plt.ylabel('y')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "2dc3ee21-3d8a-4c21-8cc2-488f1149ca13", + "metadata": {}, + "source": [ + "Let's draw from the posterior and display the results in a pairplot from mackelab. First for the static results." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "63b30f5a-c2e0-4804-85a4-ee244899da11", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c107f18e51e74383919cef4843d84f71", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Drawing 10000 posterior samples: 0%| | 0/10000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display = plot.Display()\n", + "posterior_samples = posterior_static.sample((10000,), x = y_true)\n", + "display.mackelab_corner_plot(posterior_samples,\n", + " labels_list = ['$m$','$b$'],\n", + " truth_list = theta_true,\n", + " truth_color = 'orange')" + ] + }, + { + "cell_type": "markdown", + "id": "470ec484-1c12-4049-836b-392c5af8bae9", + "metadata": {}, + "source": [ + "Now for the generative model." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "0ce170e0-2290-4029-a66e-4706793e0ca9", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4ec59b22718843659b7dd3ae2489de2a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Drawing 10000 posterior samples: 0%| | 0/10000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display = plot.Display()\n", + "posterior_samples = posterior_generative.sample((10000,), x = y_true)\n", + "display.mackelab_corner_plot(posterior_samples,\n", + " labels_list = ['$m$','$b$'],\n", + " truth_list = theta_true,\n", + " truth_color = 'orange')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7bdce131-6cb5-4be9-aad5-e02ed35efbcd", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}