diff --git a/examples/contrib/epidemiology/sir.py b/examples/contrib/epidemiology/sir.py index 25417f4c94..fdd1a30ddf 100644 --- a/examples/contrib/epidemiology/sir.py +++ b/examples/contrib/epidemiology/sir.py @@ -91,6 +91,7 @@ def hook_fn(kernel, *unused): heuristic_ess_threshold=args.ess_threshold, warmup_steps=args.warmup_steps, num_samples=args.num_samples, + num_chains=args.num_chains, max_tree_depth=args.max_tree_depth, arrowhead_mass=args.arrowhead_mass, num_quant_bins=args.num_bins, @@ -293,6 +294,7 @@ def main(args): parser.add_argument("-np", "--num-particles", default=1024, type=int) parser.add_argument("-ess", "--ess-threshold", default=0.5, type=float) parser.add_argument("-w", "--warmup-steps", type=int) + parser.add_argument("-c", "--num-chains", default=1, type=int) parser.add_argument("-t", "--max-tree-depth", default=5, type=int) parser.add_argument("-a", "--arrowhead-mass", action="store_true") parser.add_argument("-r", "--rng-seed", default=0, type=int) diff --git a/tests/test_examples.py b/tests/test_examples.py index 459266a43e..18dab90083 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -34,6 +34,7 @@ 'contrib/autoname/tree_data.py --num-epochs=1', 'contrib/cevae/synthetic.py --num-epochs=1', 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2', + 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -c=2', 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2', 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -k=1', 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2 -k=1', diff --git a/tutorial/source/epi_phylogeny.ipynb b/tutorial/source/epi_phylogeny.ipynb new file mode 100644 index 0000000000..bbf9030616 --- /dev/null +++ b/tutorial/source/epi_phylogeny.ipynb @@ -0,0 +1,383 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Epidemiology: Phylogen + Aggregate observations\n", + "\n", + "\n", + "This notebook demonstrates how to use the [pyro.contrib.epidemiology](http://docs.pyro.ai/en/latest/contrib.epidemiology.html) to infer epidemiological parameters based on both aggregate infection data and phylogenetic data from sequenced viral genomes. We will generally follow the analysis of [(Li & Ayscue 2020)](https://www.medrxiv.org/content/10.1101/2020.05.05.20092098v1)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import re\n", + "import datetime\n", + "import urllib.request\n", + "\n", + "import torch\n", + "from Bio import Phylo\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "import pyro\n", + "import pyro.distributions as dist\n", + "from pyro.contrib.epidemiology import CompartmentalModel, binomial_dist, infection_dist\n", + "from pyro.contrib.epidemiology.models import SuperspreadingSEIRModel\n", + "\n", + "%matplotlib inline\n", + "pyro.enable_validation(True)\n", + "torch.set_default_dtype(torch.double)\n", + "torch.multiprocessing.set_start_method(\"spawn\")\n", + "torch.set_printoptions(precision=2)\n", + "print(torch.__version__)\n", + "print(pyro.__version__)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "region = \"california\"\n", + "url_phylogeny = \"https://github.com/czbiohub/EpiGen-COVID19/raw/master/files/{}_timetree.nexus\"\n", + "urllib.request.urlretrieve(url_phylogeny.format(region), \"timetree.nexus\");" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"timetree.nexus\") as f:\n", + " for phylogeny in Phylo.parse(f, \"nexus\"):\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Fix a parsing error for whereby internal nodes interpret .name as .confidence\n", + "for clade in phylogeny.find_clades():\n", + " if clade.confidence:\n", + " clade.name = clade.confidence\n", + " clade.confidence = None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Phylo.draw(phylogeny, do_show=False)\n", + "plt.gcf().set_figwidth(10)\n", + "plt.gcf().set_figheight(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can load the timeseries file using Pandas." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "url_timeseries = (\"https://github.com/czbiohub/EpiGen-COVID19/raw/master/files/\"\n", + " \"{}_timeseries.txt\")\n", + "urllib.request.urlretrieve(url_timeseries.format(region), \"timeseries.txt\");" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.read_csv(\"timeseries.txt\", sep=\"\\t\")\n", + "print(len(df))\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.plot(df[\"new_cases\"])\n", + "plt.xlabel(\"day after {}\".format(df[\"date\"][0]))\n", + "plt.ylabel(\"# new infections\")\n", + "plt.title(\"New infections in {}\".format(region.capitalize()));" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Converting data inputs to PyTorch tensors\n", + "\n", + "We'll need to convert the timeseres from pandas dataframe to torch tensor." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "start_date = datetime.datetime.strptime(df.date[0], \"%Y-%m-%d\")\n", + "start_date" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "start_days_after_2020 = (start_date - datetime.datetime(2020, 1, 1, 0, 0)).days\n", + "start_days_after_2020" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "new_cases = list(df[\"new_cases\"])\n", + "if new_cases[-1] == 0:\n", + " new_cases.pop(-1) # ignore a final empty observation\n", + "new_cases = torch.tensor(new_cases, dtype=torch.double)\n", + "print(new_cases.sum())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "leaf_times = []\n", + "coal_times = []\n", + "for clade in phylogeny.find_clades():\n", + " date_string = re.search(r\"date=(\\d\\d\\d\\d\\.\\d\\d)\", clade.comment).group(1)\n", + " days_after_2020 = (float(date_string) - 2020) * 365.25\n", + " time = days_after_2020 - start_days_after_2020\n", + "\n", + " num_children = len(clade)\n", + " if num_children == 0:\n", + " leaf_times.append(time)\n", + " else:\n", + " # Pyro expects binary coalescent events, so we split n-ary events\n", + " # into n-1 separate binary events.\n", + " for _ in range(num_children - 1):\n", + " coal_times.append(time)\n", + "assert len(leaf_times) == 1 + len(coal_times)\n", + "\n", + "leaf_times = torch.tensor(leaf_times, dtype=torch.double)\n", + "coal_times = torch.tensor(coal_times, dtype=torch.double)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's visualize this aggregate data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "times = torch.cat([coal_times, leaf_times])\n", + "signs = torch.cat([-torch.ones_like(coal_times), torch.ones_like(leaf_times)])\n", + "times, index = times.sort(0)\n", + "signs = signs[index]\n", + "lineages = signs.flip([0]).cumsum(0).flip([0])\n", + "\n", + "plt.plot(times, lineages)\n", + "plt.xlabel(\"time\")\n", + "plt.ylabel(\"# lineages\")\n", + "plt.title(\"Phylogeny width of {}\".format(region.capitalize()));" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fitting a model\n", + "\n", + "The [pyro.contrib.epidemiology](http://docs.pyro.ai/en/latest/contrib.epidemiology.html) module provides a number of simple example models. We can start by training one of those, say an [SuperspreadingSEIRModel](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#pyro.contrib.epidemiology.seir.SuperspreadingSEIRModel)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "model = SuperspreadingSEIRModel(population=int(39e6),\n", + " incubation_time=5.5,\n", + " recovery_time=14.,\n", + " data=new_cases,\n", + " leaf_times=leaf_times,\n", + " coal_times=coal_times)\n", + "\n", + "pyro.set_rng_seed(20200615)\n", + "mcmc = model.fit(warmup_steps=200, num_samples=800, haar_full_mass=7,\n", + " num_chains=4, jit_compile=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mcmc.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_globals(samples):\n", + " names = sorted(k for k, v in model.samples.items() if v.shape[1:].numel() == 1)\n", + "\n", + " # Plot individual histograms.\n", + " fig, axes = plt.subplots(len(names), 1, figsize=(5, 2.5 * len(names)))\n", + " for ax, name in zip(axes, names):\n", + " mean = samples[name].mean().item()\n", + " std = samples[name].std().item()\n", + " ax.set_title(\"{} = {:0.3g} \\u00B1 {:0.3g}\".format(name, mean, std))\n", + " sns.distplot(samples[name].reshape(-1), ax=ax)\n", + " ax.set_yticks(())\n", + " plt.tight_layout()\n", + " \n", + " # Plot pairwise joint distributions for selected variables.\n", + " covariates = [(name, samples[name]) for name in names]\n", + " N = len(covariates)\n", + " fig, axes = plt.subplots(N, N, figsize=(6, 6), sharex=\"col\", sharey=\"row\")\n", + " for i in range(N):\n", + " axes[i][0].set_ylabel(covariates[i][0])\n", + " axes[0][i].set_xlabel(covariates[i][0])\n", + " axes[0][i].xaxis.set_label_position(\"top\")\n", + " for j in range(N):\n", + " ax = axes[i][j]\n", + " ax.set_xticks(())\n", + " ax.set_yticks(())\n", + " ax.scatter(covariates[j][1], -covariates[i][1],\n", + " lw=0, color=\"darkblue\", alpha=0.3)\n", + " plt.tight_layout()\n", + " plt.subplots_adjust(wspace=0, hspace=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "plot_globals(model.samples)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_series(model, forecast=14):\n", + " samples = model.predict(forecast=forecast)\n", + "\n", + " obs = model.data\n", + " S2E = samples[\"S2E\"]\n", + " num_samples = len(S2E)\n", + " median = S2E.median(dim=0).values\n", + " pred = samples[\"obs\"].median(dim=0).values.squeeze()[model.duration:]\n", + " print(\"Median prediction of new infections (starting on day 0):\\n{}\"\n", + " .format(\" \".join(map(str, map(int, median)))))\n", + "\n", + " plt.figure()\n", + " time = torch.arange(model.duration + forecast)\n", + " p05 = S2E.kthvalue(int(round(0.5 + 0.05 * num_samples)), dim=0).values\n", + " p95 = S2E.kthvalue(int(round(0.5 + 0.95 * num_samples)), dim=0).values\n", + " plt.fill_between(time, p05, p95, color=\"red\", alpha=0.3, label=\"90% CI\")\n", + " plt.plot(time, median, \"r-\", label=\"median\")\n", + " plt.plot(time[:model.duration], obs, \"k.\", label=\"observed\")\n", + " plt.plot(time[model.duration:], pred, \"k+\", label=\"predicted\")\n", + " plt.axvline(model.duration - 0.5, color=\"gray\", lw=1)\n", + " plt.xlim(0, len(time) - 1)\n", + " plt.ylim(1, None)\n", + " plt.yscale(\"log\")\n", + " plt.xlabel(\"day after first infection\")\n", + " plt.ylabel(\"new infections per day\")\n", + " plt.title(\"Predicted new infections (absent intervention)\")\n", + " plt.legend(loc=\"upper left\")\n", + " plt.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_series(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}