From e3af7976f90f2984dd42ff8ec5f72b1fd008112c Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 10 Dec 2021 12:02:50 -0500 Subject: [PATCH] Add new option for PyroModule parameters to be stored locally, rather than the global parameter store. Add backwards-compatible __call__ method to pyro.infer.ELBO that returns a Module bound to a specific model and guide, allowing direct use of the PyTorch JIT API. Fork Bayesian regression tutorial into a PyTorch API usage tutorial to illustrate a PyTorch-native programming style facilitated by these changes and PyroModule. --- pyro/__init__.py | 4 + pyro/infer/elbo.py | 16 + pyro/nn/module.py | 70 +- pyro/poutine/runtime.py | 12 + pyro/primitives.py | 15 +- .../source/bayesian_regression_module.ipynb | 854 ++++++++++++++++++ tutorial/source/index.rst | 1 + 7 files changed, 947 insertions(+), 25 deletions(-) create mode 100644 tutorial/source/bayesian_regression_module.ipynb diff --git a/pyro/__init__.py b/pyro/__init__.py index 2658f03ebf..b68601fa81 100644 --- a/pyro/__init__.py +++ b/pyro/__init__.py @@ -5,6 +5,7 @@ from pyro.infer.inspect import render_model from pyro.logger import log from pyro.poutine import condition, do, markov +from pyro.poutine.runtime import enable_module_local_param from pyro.primitives import ( barrier, clear_param_store, @@ -21,6 +22,7 @@ random_module, sample, subsample, + use_param_store, validation_enabled, ) from pyro.util import set_rng_seed @@ -42,6 +44,7 @@ "deterministic", "do", "enable_validation", + "enable_module_local_param", "factor", "get_param_store", "iarange", @@ -59,5 +62,6 @@ "sample", "set_rng_seed", "subsample", + "use_param_store", "validation_enabled", ] diff --git a/pyro/infer/elbo.py b/pyro/infer/elbo.py index 3abe07d748..991d278350 100644 --- a/pyro/infer/elbo.py +++ b/pyro/infer/elbo.py @@ -5,6 +5,8 @@ import warnings from abc import ABCMeta, abstractmethod +import torch + import pyro import pyro.poutine as poutine from pyro.infer.util import is_validation_enabled @@ -12,6 +14,17 @@ from pyro.util import check_site_shape +class _ELBOModule(torch.nn.Module): + def __init__(self, model, guide, elbo): + super().__init__() + self.model = model + self.guide = guide + self.elbo = elbo + + def forward(self, *args, **kwargs): + return self.elbo.differentiable_loss(self.model, self.guide, *args, **kwargs) + + class ELBO(object, metaclass=ABCMeta): """ :class:`ELBO` is the top-level interface for stochastic variational @@ -86,6 +99,9 @@ def __init__( self.jit_options = jit_options self.tail_adaptive_beta = tail_adaptive_beta + def __call__(self, model, guide): + return _ELBOModule(model, guide, self) + def _guess_max_plate_nesting(self, model, guide, args, kwargs): """ Guesses max_plate_nesting by running the (model,guide) pair once diff --git a/pyro/nn/module.py b/pyro/nn/module.py index ad11567537..25d2a370bf 100644 --- a/pyro/nn/module.py +++ b/pyro/nn/module.py @@ -20,7 +20,8 @@ from torch.distributions import constraints, transform_to import pyro -from pyro.poutine.runtime import _PYRO_PARAM_STORE +from pyro.params.param_store import ParamStoreDict +from pyro.poutine.runtime import _PYRO_PARAM_STORE, _module_local_param_enabled class PyroParam(namedtuple("PyroParam", ("init_value", "constraint", "event_dim"))): @@ -380,6 +381,8 @@ def __init__(self, name=""): self._pyro_context = _Context() # shared among sub-PyroModules self._pyro_params = OrderedDict() self._pyro_samples = OrderedDict() + if _module_local_param_enabled(): + self._pyro_param_store = ParamStoreDict() super().__init__() def add_module(self, name, module): @@ -407,6 +410,12 @@ def named_pyro_params(self, prefix="", recurse=True): for elem in gen: yield elem + def _pyro_param_local(self, *args, **kwargs): + with pyro.use_param_store( + getattr(self, "_pyro_param_store", _PYRO_PARAM_STORE) + ): + return pyro.param(*args, **kwargs) + def _pyro_set_supermodule(self, name, context): self._pyro_name = name self._pyro_context = context @@ -434,22 +443,24 @@ def __getattr__(self, name): unconstrained_value = getattr(self, name + "_unconstrained") if self._pyro_context.active: fullname = self._pyro_get_fullname(name) - if fullname in _PYRO_PARAM_STORE: + if fullname in self._pyro_param_store: if ( - _PYRO_PARAM_STORE._params[fullname] + self._pyro_param_store._params[fullname] is not unconstrained_value ): # Update PyroModule <--- ParamStore. - unconstrained_value = _PYRO_PARAM_STORE._params[fullname] + unconstrained_value = self._pyro_param_store._params[ + fullname + ] if not isinstance(unconstrained_value, torch.nn.Parameter): # Update PyroModule ---> ParamStore (type only; data is preserved). unconstrained_value = torch.nn.Parameter( unconstrained_value ) - _PYRO_PARAM_STORE._params[ + self._pyro_param_store._params[ fullname ] = unconstrained_value - _PYRO_PARAM_STORE._param_to_name[ + self._pyro_param_store._param_to_name[ unconstrained_value ] = fullname super().__setattr__( @@ -457,10 +468,12 @@ def __getattr__(self, name): ) else: # Update PyroModule ---> ParamStore. - _PYRO_PARAM_STORE._constraints[fullname] = constraint - _PYRO_PARAM_STORE._params[fullname] = unconstrained_value - _PYRO_PARAM_STORE._param_to_name[unconstrained_value] = fullname - return pyro.param(fullname, event_dim=event_dim) + self._pyro_param_store._constraints[fullname] = constraint + self._pyro_param_store._params[fullname] = unconstrained_value + self._pyro_param_store._param_to_name[ + unconstrained_value + ] = fullname + return self._pyro_param_local(fullname, event_dim=event_dim) else: # Cannot determine supermodule and hence cannot compute fullname. return transform_to(constraint)(unconstrained_value) @@ -491,7 +504,7 @@ def __getattr__(self, name): "_unconstrained" ): if self._pyro_context.active: - pyro.param(self._pyro_get_fullname(name), result) + self._pyro_param_local(self._pyro_get_fullname(name), result) if isinstance(result, torch.nn.Module): if isinstance(result, PyroModule): @@ -508,6 +521,11 @@ def __getattr__(self, name): return result def __setattr__(self, name, value): + + if isinstance(value, ParamStoreDict): + super().__setattr__(name, value) + return + if isinstance(value, PyroModule): # Create a new sub PyroModule, overwriting any old value. try: @@ -527,19 +545,21 @@ def __setattr__(self, name, value): self._pyro_params[name] = constraint, event_dim if self._pyro_context.active: fullname = self._pyro_get_fullname(name) - pyro.param( + self._pyro_param_local( fullname, constrained_value, constraint=constraint, event_dim=event_dim, ) - constrained_value = pyro.param(fullname) + constrained_value = self._pyro_param_local(fullname) unconstrained_value = constrained_value.unconstrained() if not isinstance(unconstrained_value, torch.nn.Parameter): # Update PyroModule ---> ParamStore (type only; data is preserved). unconstrained_value = torch.nn.Parameter(unconstrained_value) - _PYRO_PARAM_STORE._params[fullname] = unconstrained_value - _PYRO_PARAM_STORE._param_to_name[unconstrained_value] = fullname + self._pyro_param_store._params[fullname] = unconstrained_value + self._pyro_param_store._param_to_name[ + unconstrained_value + ] = fullname else: # Cannot determine supermodule and hence cannot compute fullname. unconstrained_value = _unconstrain(constrained_value, constraint) super().__setattr__(name + "_unconstrained", unconstrained_value) @@ -553,12 +573,12 @@ def __setattr__(self, name, value): pass if self._pyro_context.active: fullname = self._pyro_get_fullname(name) - value = pyro.param(fullname, value) + value = self._pyro_param_local(fullname, value) if not isinstance(value, torch.nn.Parameter): # Update PyroModule ---> ParamStore (type only; data is preserved). value = torch.nn.Parameter(value) - _PYRO_PARAM_STORE._params[fullname] = value - _PYRO_PARAM_STORE._param_to_name[value] = fullname + self._pyro_param_store._params[fullname] = value + self._pyro_param_store._param_to_name[value] = fullname super().__setattr__(name, value) return @@ -590,9 +610,9 @@ def __delattr__(self, name): del self._parameters[name] if self._pyro_context.used: fullname = self._pyro_get_fullname(name) - if fullname in _PYRO_PARAM_STORE: + if fullname in self._pyro_param_store: # Update PyroModule ---> ParamStore. - del _PYRO_PARAM_STORE[fullname] + del self._pyro_param_store[fullname] return if name in self._pyro_params: @@ -600,9 +620,9 @@ def __delattr__(self, name): del self._pyro_params[name] if self._pyro_context.used: fullname = self._pyro_get_fullname(name) - if fullname in _PYRO_PARAM_STORE: + if fullname in self._pyro_param_store: # Update PyroModule ---> ParamStore. - del _PYRO_PARAM_STORE[fullname] + del self._pyro_param_store[fullname] return if name in self._pyro_samples: @@ -613,9 +633,9 @@ def __delattr__(self, name): del self._modules[name] if self._pyro_context.used: fullname = self._pyro_get_fullname(name) - for p in list(_PYRO_PARAM_STORE.keys()): + for p in list(self._pyro_param_store.keys()): if p.startswith(fullname): - del _PYRO_PARAM_STORE[p] + del self._pyro_param_store[p] return super().__delattr__(name) @@ -699,6 +719,8 @@ def to_pyro_module_(m, recurse=True): m._pyro_context = _Context() m._pyro_params = OrderedDict() m._pyro_samples = OrderedDict() + if _module_local_param_enabled(): + m._pyro_param_store = ParamStoreDict() # Reregister parameters and submodules. for name, value in list(m._parameters.items()): diff --git a/pyro/poutine/runtime.py b/pyro/poutine/runtime.py index 59b27c8911..612d51fd9a 100644 --- a/pyro/poutine/runtime.py +++ b/pyro/poutine/runtime.py @@ -15,6 +15,18 @@ # the global ParamStore _PYRO_PARAM_STORE = ParamStoreDict() +# toggle usage of local param stores in PyroModules +_PYRO_MODULE_LOCAL_PARAM = False + + +def enable_module_local_param(flag: bool) -> None: + global _PYRO_MODULE_LOCAL_PARAM + _PYRO_MODULE_LOCAL_PARAM = flag + + +def _module_local_param_enabled(): + return _PYRO_MODULE_LOCAL_PARAM + class _DimAllocator: """ diff --git a/pyro/primitives.py b/pyro/primitives.py index 7ab6a97425..2a8ce70251 100644 --- a/pyro/primitives.py +++ b/pyro/primitives.py @@ -14,6 +14,7 @@ import pyro.poutine as poutine from pyro.distributions import constraints from pyro.params import param_with_module_name +from pyro.params.param_store import ParamStoreDict from pyro.poutine.plate_messenger import PlateMessenger from pyro.poutine.runtime import ( _MODULE_NAMESPACE_DIVIDER, @@ -45,7 +46,19 @@ def clear_param_store(): return _PYRO_PARAM_STORE.clear() -_param = effectful(_PYRO_PARAM_STORE.get_param, type="param") +@contextmanager +def use_param_store(param_store: ParamStoreDict): + try: + global _PYRO_PARAM_STORE + _PYRO_PARAM_STORE, prev_store = param_store, _PYRO_PARAM_STORE + yield param_store + finally: + _PYRO_PARAM_STORE = prev_store + + +@effectful(type="param") +def _param(*args, **kwargs): + return _PYRO_PARAM_STORE.get_param(*args, **kwargs) def param(name, init_tensor=None, constraint=constraints.real, event_dim=None): diff --git a/tutorial/source/bayesian_regression_module.ipynb b/tutorial/source/bayesian_regression_module.ipynb new file mode 100644 index 0000000000..0e61a3a853 --- /dev/null +++ b/tutorial/source/bayesian_regression_module.ipynb @@ -0,0 +1,854 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Bayesian Regression - Introduction (Part 1)\n", + "\n", + "Regression is one of the most common and basic supervised learning tasks in machine learning. Suppose we're given a dataset $\\mathcal{D}$ of the form\n", + "\n", + "$$ \\mathcal{D} = \\{ (X_i, y_i) \\} \\qquad \\text{for}\\qquad i=1,2,...,N$$\n", + "\n", + "The goal of linear regression is to fit a function to the data of the form:\n", + "\n", + "$$ y = w X + b + \\epsilon $$\n", + "\n", + "where $w$ and $b$ are learnable parameters and $\\epsilon$ represents observation noise. Specifically $w$ is a matrix of weights and $b$ is a bias vector.\n", + "\n", + "In this tutorial, we will first implement linear regression in PyTorch and learn point estimates for the parameters $w$ and $b$. Then we will see how to incorporate uncertainty into our estimates by using Pyro to implement Bayesian regression. Additionally, we will learn how to use the Pyro's utility functions to do predictions and serve our model using `TorchScript`.\n", + "\n", + "## Tutorial Outline\n", + "\n", + " - [Setup](#Setup)\n", + " - [Dataset](#Dataset)\n", + " - [Linear Regression](#Linear-Regression)\n", + " - [Training with PyTorch Optimizers](#Training-with-PyTorch-Optimizers)\n", + " - [Regression Fit](#Plotting-the-Regression-Fit)\n", + " - [Bayesian Regression with Pyro's SVI](#Bayesian-Regression-with-Pyro's-Stochastic-Variational-Inference-%28SVI%29)\n", + " - [Model](#Model)\n", + " - [Using an AutoGuide](#Using-an-AutoGuide)\n", + " - [Optimizing the Evidence Lower Bound](#Optimizing-the-Evidence-Lower-Bound)\n", + " - [Model Evaluation](#Model-Evaluation)\n", + " - [Serving the Model using TorchScript](#Model-Serving-via-TorchScript)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "Let's begin by importing the modules we'll need." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%reset -s -f" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from functools import partial\n", + "import torch\n", + "import numpy as np\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import pyro\n", + "import pyro.distributions as dist\n", + "\n", + "# for CI testing\n", + "smoke_test = ('CI' in os.environ)\n", + "assert pyro.__version__.startswith('1.7.0')\n", + "pyro.set_rng_seed(1)\n", + "pyro.enable_module_local_param(True)\n", + "\n", + "\n", + "# Set matplotlib settings\n", + "%matplotlib inline\n", + "plt.style.use('default')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Dataset \n", + "\n", + "The following example is adapted from \\[1\\]. We would like to explore the relationship between topographic heterogeneity of a nation as measured by the Terrain Ruggedness Index (variable *rugged* in the dataset) and its GDP per capita. In particular, it was noted by the authors in \\[2\\] that terrain ruggedness or bad geography is related to poorer economic performance outside of Africa, but rugged terrains have had a reverse effect on income for African nations. Let us look at the data and investigate this relationship. We will be focusing on three features from the dataset:\n", + "\n", + " - `rugged`: quantifies the Terrain Ruggedness Index\n", + " - `cont_africa`: whether the given nation is in Africa\n", + " - `rgdppc_2000`: Real GDP per capita for the year 2000" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "DATA_URL = \"https://d2hg8soec8ck9v.cloudfront.net/datasets/rugged_data.csv\"\n", + "data = pd.read_csv(DATA_URL, encoding=\"ISO-8859-1\")\n", + "df = data[[\"cont_africa\", \"rugged\", \"rgdppc_2000\"]]\n", + "df = df[np.isfinite(df.rgdppc_2000)]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We perform three simple preprocessing steps. First, the response variable GDP is highly skewed, so we will log-transform it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df[\"rgdppc_2000\"] = np.log(df[\"rgdppc_2000\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Second, we multiply the two features `rugged` and `cont_africa` to create a new feature, which lets us add an interaction term to our model to separately account for the effect of ruggedness on the GDP for nations within and outside Africa." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df[\"cont_africa_x_rugged\"] = df[\"cont_africa\"] * df[\"rugged\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we convert the data to a `torch.Tensor` so that we can use PyTorch and Pyro to write our models." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data = torch.tensor(df[[\"cont_africa\", \"rugged\", \"cont_africa_x_rugged\", \"rgdppc_2000\"]].values,\n", + " dtype=torch.float)\n", + "x_data, y_data = data[:, :-1], data[:, -1]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Linear Regression\n", + "\n", + "We would like to predict log GDP per capita of a nation as a function of two features from the dataset - whether the nation is in Africa, and its Terrain Ruggedness Index. We will create a trivial class called `PyroModule[nn.Linear]` that subclasses [PyroModule](http://docs.pyro.ai/en/dev/nn.html#module-pyro.nn.module) and `torch.nn.Linear`. `PyroModule` is very similar to PyTorch's `nn.Module`, but additionally supports [Pyro primitives](http://docs.pyro.ai/en/dev/primitives.html#primitives) as attributes that can be modified by Pyro's [effect handlers](http://pyro.ai/examples/effect_handlers.html) (see the [next section](#Model) on how we can have module attributes that are `pyro.sample` primitives). Some general notes:\n", + "\n", + " - Learnable parameters in PyTorch modules are instances of `nn.Parameter`, in this case the `weight` and `bias` parameters of the `nn.Linear` class. When declared inside a `PyroModule` as attributes, these are automatically registered in Pyro's param store. While this model does not require us to constrain the value of these parameters during optimization, this can also be easily achieved in `PyroModule` using the [PyroParam](http://docs.pyro.ai/en/dev/nn.html#pyro.nn.module.PyroParam) statement. \n", + " - Note that while the `forward` method of `PyroModule[nn.Linear]` inherits from `nn.Linear`, it can also be easily overridden. e.g. in the case of logistic regression, we apply a sigmoid transformation to the linear predictor." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from torch import nn\n", + "from pyro.nn import PyroModule\n", + "\n", + "assert issubclass(PyroModule[nn.Linear], nn.Linear)\n", + "assert issubclass(PyroModule[nn.Linear], PyroModule)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training with PyTorch Optimizers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We use the mean squared error (MSE) as our loss and Adam as our optimizer from the `torch.optim` module. We would like to optimize the parameters of our model, namely the `weight` and `bias` parameters of the network, which corresponds to our regression coefficents and the intercept." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[iteration 0050] loss: 3179.7847\n", + "[iteration 0100] loss: 1616.1371\n", + "[iteration 0150] loss: 1109.4115\n", + "[iteration 0200] loss: 833.7544\n", + "[iteration 0250] loss: 637.5822\n", + "[iteration 0300] loss: 488.2652\n", + "[iteration 0350] loss: 376.4650\n", + "[iteration 0400] loss: 296.0483\n", + "[iteration 0450] loss: 240.6140\n", + "[iteration 0500] loss: 203.9385\n", + "[iteration 0550] loss: 180.6171\n", + "[iteration 0600] loss: 166.3493\n", + "[iteration 0650] loss: 157.9457\n", + "[iteration 0700] loss: 153.1786\n", + "[iteration 0750] loss: 150.5735\n", + "[iteration 0800] loss: 149.2020\n", + "[iteration 0850] loss: 148.5065\n", + "[iteration 0900] loss: 148.1668\n", + "[iteration 0950] loss: 148.0070\n", + "[iteration 1000] loss: 147.9347\n", + "[iteration 1050] loss: 147.9032\n", + "[iteration 1100] loss: 147.8900\n", + "[iteration 1150] loss: 147.8847\n", + "[iteration 1200] loss: 147.8827\n", + "[iteration 1250] loss: 147.8819\n", + "[iteration 1300] loss: 147.8817\n", + "[iteration 1350] loss: 147.8816\n", + "[iteration 1400] loss: 147.8815\n", + "[iteration 1450] loss: 147.8815\n", + "[iteration 1500] loss: 147.8815\n", + "Learned parameters:\n", + "weight [[-1.9478593 -0.20278622 0.39330277]]\n", + "bias [9.22308]\n" + ] + } + ], + "source": [ + "# Regression model\n", + "linear_reg_model = PyroModule[nn.Linear](3, 1)\n", + "\n", + "# Define loss and optimize\n", + "loss_fn = torch.nn.MSELoss(reduction='sum')\n", + "optim = torch.optim.Adam(linear_reg_model.parameters(), lr=0.05)\n", + "num_iterations = 1500 if not smoke_test else 2\n", + "\n", + "for j in range(num_iterations):\n", + " # run the model forward on the data\n", + " y_pred = linear_reg_model(x_data).squeeze(-1)\n", + " # calculate the mse loss\n", + " loss = loss_fn(y_pred, y_data)\n", + " # initialize gradients to zero\n", + " optim.zero_grad()\n", + " # backpropagate\n", + " loss.backward()\n", + " # take a gradient step\n", + " optim.step()\n", + " if (j + 1) % 50 == 0:\n", + " print(\"[iteration %04d] loss: %.4f\" % (j + 1, loss.item()))\n", + "\n", + " \n", + "# Inspect learned parameters\n", + "print(\"Learned parameters:\")\n", + "for name, param in linear_reg_model.named_parameters():\n", + " print(name, param.data.numpy())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting the Regression Fit\n", + "\n", + "Let us plot the regression fit for our model, separately for countries outside and within Africa." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA+UAAAJJCAYAAADMaparAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAA9hAAAPYQGoP6dpAACZ0klEQVR4nOzdeXgT5drH8V9aaMtSyk7LvgoUFERkV0ThAAcQd0VRwO3IUcH1dRdwQ9z1qKjowaXiDgIuKIjgVgGtqLWKggUUWlGWtgIt0s77R08iadNmkk4yk+T7ua5cmsnM5Mmk5Jn7We7HZRiGIQAAAAAAEHZxdhcAAAAAAIBYRVAOAAAAAIBNCMoBAAAAALAJQTkAAAAAADYhKAcAAAAAwCYE5QAAAAAA2ISgHAAAAAAAmxCUAwAAAABgE4JyAAAAAABsQlAOAIhI7du3l8vl8nokJiaqdevWGj9+vN566y27ixixNm/eLJfLpfbt29tdlGpV/P59Pd58801Jf/+9bN682dYyAwBQUS27CwAAQE0MHjxYnTt3liQVFBToq6++0pIlS7RkyRJdeeWVeuCBB2wuIUJt5MiRSk1N9fla27Ztqz128uTJeu655zR//nxNnjw5BKUDAKB6BOUAgIh24YUXegVTBw8e1JVXXqlHH31UDz74oCZMmKCjjz7avgJGoFatWun7779X7dq17S6KKddff72OO+64avf54IMP9Ndff6lVq1bhKRQAACYxfB0AEFVq1aqle++9Vw0aNJAkLV261OYSRZ7atWurW7du6tSpk91FsUynTp3UrVu3iGloAADEDoJyAEDUSUpKUpcuXSRJv/32m899PvjgA51yyilKS0tTQkKCmjdvrpNPPlmZmZlVnjc7O1unnnqqmjZtqrp16+rwww/XQw89pLKysirnLB+6ffHixTr++OPVuHFjuVwurVq1yrPf7t27NWPGDPXu3VvJycme899xxx3at29fpbKUlZXpqaee0uDBg9WwYUPVrl1bzZs3V69evXT55ZdXKkdeXp6mT5+uww47TElJSapbt67atGmjE044Qffdd5/Xvv7mlP/666+6/PLL1aVLFyUlJSklJUWDBw/Wk08+qdLS0kr7P/vss3K5XJo8ebL27t2rG264QZ07d1ZiYqJSU1M1adIkbdu2rcrrboWK34/7Mz733HOSpClTpnjNRZ85c2ZIywMAgBvD1wEAUamwsFCS1KJFi0qvXXPNNbr//vsVFxenvn376phjjtHWrVu1ePFiLV26VPPmzdOUKVO8jlm9erVGjx6t/fv3q1OnThoxYoR27typ6667Tp9//rnf8tx///169NFH1bdvX40aNUrbt29XfHy8JCknJ0ejRo3SL7/8orS0NA0ZMkS1a9fW2rVrdcstt+iNN97QqlWrlJKS4jnfhRdeqPnz5yspKUlDhgxRs2bNtGvXLv3888969NFHdcIJJ3iC6vz8fPXt21fbt29X27ZtNWrUKCUlJWn79u1av369vvzyS11zzTWmruu6des0atQo7dq1S23bttVJJ52kgoICrVq1Sp999pkWLVqkJUuWKCEhodKxBQUFGjRokLZu3apjjjlGPXv2VGZmpp5//nmtXr1aX3/9tddnDKX69etr0qRJ+uSTT7Rp0yav3ASS1Lt377CUAwAAGQAARKB27doZkoz58+dXei0nJ8eIj483JBnr1q3zeu2pp54yJBmdO3c2vv76a6/XVq9ebSQnJxsJCQnGjz/+6Nm+b98+o1WrVoYk4+qrrzZKS0s9r3333XdGixYtDEmGJCM3N9dnOePj443FixdXKuu+ffuMTp06GZKMm2++2SgpKfG8tnfvXmPChAmGJGPKlCme7Vu2bDEkGa1btzby8vJ8fv4tW7Z4ns+aNcuQZFx88cVGWVmZ174HDhwwVqxY4bUtNzfXkGS0a9fOa3txcbHn81xyySXGgQMHPK9t2rTJaN++vSHJuPHGG72Omz9/vuf6jBw50igoKPC8tmvXLqN3796GJOOuu+6q9Fmq4z7nhx9+6Hdfd7krfj+TJk2q8u8IAIBwYPg6ACBqFBQU6P3339cpp5yi0tJS3Xzzzerbt6/n9bKyMs+w5JdffllHHHGE1/HHHnusbrnlFh04cEBPPvmkZ/vrr7+ubdu2qV27dpo9e7bi4v6uPtPT03XLLbf4LdukSZN04oknVtr+3HPPadOmTRo7dqxuv/12rx7munXr6qmnnlLz5s31wgsvaPfu3ZL+HpLfp08fn1nHu3fv7pV13L3/qFGj5HK5vPatXbu2TjjhBL/ll6TXXntNW7ZsUcuWLfXQQw95zc/u2LGjZxj8f/7zHxUXF1c6vl69epo/f75nvr8kNWrUSNdff70kacWKFabKUdGwYcN8LodGNnUAQCRg+DoAIKJNmTKl0lDz+Ph4ZWRk6JxzzvHa/tVXX2n79u3q1KmTjjrqKJ/nc2fx/uyzzzzbVq9eLUk6/fTTfSYKO+ecc3TZZZdVW87TTjvN5/a3335bknTmmWf6fL1+/frq27ev3nnnHa1bt07/+Mc/1K1bNyUnJ+udd97RnXfeqbPPPlsdOnSo8r379eunxx9/XNdff70Mw9A//vEP1a9fv9ry+uKeA3/WWWcpMTGx0uunnHKKGjVqpN27d+vLL7/U4MGDvV7v27ev0tLSKh3XvXt3SQp6XnlVS6INGTIkqPMBABBOBOUAgIh26Fzg33//XR9//LGKioo0depUdenSRf369fPs+/PPP0uSNm3aVKnHuKLff//d8/+//vqrJFWZ+Kxhw4ZKSUlRQUFBleer6lh3mc4991yde+65psqUnJys+fPna8qUKbr55pt18803Ky0tTQMGDNCoUaN09tlnewXd5557rpYvX64XX3xRp556quLj45Wenq4hQ4botNNO0/HHH1/t+7q5g+aqGgBcLpc6dOig3bt3+wywq1oz3N1z7qt33QwzS6IBAOBUBOUAgIhWcZ3ygoICnXzyyfrwww91xhlnKCcnR3Xr1pVUPnxdklJTUzVy5Mhqz9u0adNK26oL5P0F+XXq1PG53V2mUaNG+UxKd6h27dp5/v/UU0/V8OHDtWTJEn388cf69NNPtWjRIi1atEi33nqrli9frsMPP1ySFBcXp4yMDN144416++239emnn+rTTz/V3LlzNXfuXI0bN06LFi3yJJ4LlUOH/QMAgHIE5QCAqJKSkqJXXnlF3bp105YtW/TAAw/o5ptvliS1adNGktSkSRM9++yzps/ZqlUrSaq0zJhbQUGB9uzZE1R527Rpox9++EEXXHBBlUPcq5KSkuLVw/7LL7/o8ssv1+LFi3XZZZd5ht27paenKz09Xddee60Mw9DKlSt19tlna+nSpXr++ecrTQOoyH0d3L37vuTm5nrtCwAAqkeTNQAg6jRr1swTiN93332egPnoo49W06ZNlZOTo++++870+Y499lhJ5YnODh48WOn1BQsWBF3W0aNHS5JeffXVoM/h1qZNG82aNUuStH79+mr3dblcOuGEE3T22Web2l/6e779K6+84nOo+aJFi7R7924lJydXOWffadyJ9Xx9rwAAhANBOQAgKv373/9W27ZtVVBQoPvvv19SeabxGTNmyDAMnXzyyfrkk08qHVdaWqqVK1d6rT1++umnKy0tTZs3b9ZNN93kGXIuST/88INuu+22oMt58cUXq127dnrttdd03XXXqaioqNI++fn5mjdvnuf5V199pVdeeUX79++vtO/SpUsleQ91f/755/Xll19W2reoqMiTvO3Q/aty+umnq23bttq+fbuuuuoqr0A2NzdXV199tSTp8ssvV1JSkt/zOUHr1q0lKaBGGgAArMTwdQBAVEpMTNTMmTN1/vnn6+GHH9aVV16pxo0b67LLLtPWrVt177336phjjlGPHj3UuXNn1alTR/n5+Vq/fr327NmjuXPnasCAAZLKlybLyMjQmDFjdM8992jhwoXq27evdu3apVWrVmn8+PFas2aNtm7d6rWkmRn16tXT22+/rbFjx+qee+7RU089pSOOOEKtW7fWvn379OOPP+r7779X8+bNddFFF0mStmzZorPOOkt16tRRnz591KZNGx08eFDffvutNmzYoISEBN1zzz2e91i4cKEmTZqkli1bqnfv3p4M6Z9++qkKCgrUs2dPz7n9XdPXX39do0aN0ty5c/XOO+9owIABKioq0sqVK1VcXKyRI0dqxowZAV0DO5100kmaNWuWHnnkEWVnZ6tNmzaKi4vTiSee6HMJOwAArEZPOQAgap133nlKT09XUVGR7r33Xs/2e+65R59++qnOOecc/fnnn1q2bJnefvttbd++Xccdd5yefvrpSkuUHX/88VqzZo1OPvlk7dq1S2+++aZ+/fVX3XnnncrIyFB+fr7i4uLUuHHjgMvZo0cPffPNN7rnnnvUvXt3ffPNN3rttde0Zs0a1atXT9dcc40WLVrk2X/AgAG6++67NWzYMG3fvl1LlizR+++/r/j4eF166aX65ptvNGrUKM/+V199ta644gq1bt1aWVlZeu2115SVlaX09HT95z//0eeff67k5GRTZT366KO1fv16XXrppYqPj9eiRYv08ccf68gjj9TcuXP11ltvBdwwYacjjjhCb7zxhgYOHKg1a9bo2Wef1TPPPKOsrCy7iwYAiBEuwzAMuwsBAEAk++ijjzR06FAdfvjh+uabb+wuDgAAiCD0lAMAYMLvv//uySx+qOzsbM/Qb3/ZywEAACqipxwAABNWrVqlYcOGKT09XR07dlSdOnWUm5urrKwslZWVacSIEXrnnXdUqxbpWgAAgHkE5QAAmLB9+3bdddddWr16tbZt26aioiIlJyerR48eOvvss3XRRRcRkAMAgIARlAMAAAAAYBPmlAMAAAAAYBOCcgAAAAAAbEJQDgAAAACATQjKAQAAAACwCUE5AAAAAAA2ISgHAAAAAMAmBOUAAAAAANiEoBwAAAAAAJsQlAMAAAAAYBOCcgAAAAAAbEJQDgAAAACATQjKAQAAAACwCUE5AAAAAAA2ISgHAAAAAMAmBOUAAAAAANiEoBwAAAAAAJsQlAMAAAAAYBOCciDC/PTTT/rHP/6hlJQUuVwuvfnmm9XuP3PmTLlcrvAULkIcd9xxOu644+wuBgAgyrzwwgvq1q2bateurYYNG/rdv3379po8eXLIyxUpNm/eLJfLpWeffdbuogBhRVCOmPHss8/K5XIpKSlJ27Ztq/T6cccdp549e9pQsr/t2bNHSUlJcrlc+v77733uM2nSJH377be688479cILL6hv375hLqW12rdvL5fLpcsvv7zSa6tWrZLL5dLrr78e8HlzcnI0c+ZMbd682YJSAgBi3eOPPy6Xy6X+/fv7fP2HH37Q5MmT1alTJ82bN09PPfVUmEtoLXcd7HK59OWXX1Z6ffLkyapfv35Q516wYIEeeuihGpYQiB4E5Yg5JSUluvvuu+0uhk+vvfaaXC6XUlNT9eKLL1Z6ff/+/crMzNQFF1ygyy67TBMnTlTr1q2rPefNN9+s/fv3h6rIlpk3b562b99u2flycnI0a9Ysn0H5+++/r/fff9+y9wIARL8XX3xR7du319q1a7Vx48ZKr69atUplZWV6+OGHNXnyZJ1xxhl+z7lhwwbNmzcvFMW11MyZMy09X1VBebt27bR//36de+65lr4f4HQE5Yg5vXv3tjwAtEpGRob++c9/asKECVqwYEGl13///XdJMjUkbu/evZKkWrVqKSkpydJyWq1Hjx4qLS0NW2NJQkKCEhISwvJeAIDIl5ubq88++0wPPPCAmjVr5rPhfMeOHZL819GGYXgayxMTE1W7dm3Ly2ul3r1766233lJWVlbI38s9ojE+Pj7k7wU4CUE5Ys6NN95oOgA8ePCgbr/9dnXq1EmJiYlq3769brzxRpWUlHjt1759e40dO1affPKJ+vXrp6SkJHXs2FHPP/+86XJt3bpVH3/8sc466yydddZZnhsAt5kzZ6pdu3aSpGuvvVYul0vt27f3vOZyuZSTk6Ozzz5bjRo10pAhQ7xeqygjI0P9+vVT3bp11ahRIx177LFevceLFy/WmDFj1LJlSyUmJqpTp066/fbbVVpa6nUe97D/nJwcDRs2THXr1lWrVq10zz33mP7s7du313nnnWeqsWTLli3697//ra5du6pOnTpq0qSJTj/9dK8e8WeffVann366JGnYsGGe4XerVq3ylLninPIdO3boggsuUIsWLZSUlKRevXrpueee89rHPdftvvvu01NPPeX5uzj66KO1bt06r33z8/M1ZcoUtW7dWomJiUpLS9P48eMZTg8AEejFF19Uo0aNNGbMGJ122mmVgvL27dtrxowZkqRmzZrJ5XJ5epfd9wjvvfee+vbtqzp16ujJJ5/0vFZxTvmePXt05ZVXqn379kpMTFTr1q113nnn6Y8//pAkHThwQLfeequOOuoopaSkqF69ejrmmGP04Ycfep0nkDqrOpdffrkaNWpkqrfczL3Dcccdp7fffltbtmzx1M/u+5mq5pSvXLlSxxxzjOrVq6eGDRtq/Pjxlab5ue93Nm7cqMmTJ6thw4ZKSUnRlClTtG/fPq99ly9friFDhqhhw4aqX7++unbtqhtvvNH0NQGsVsvuAgDh1qFDB08AeP3116tly5ZV7nvhhRfqueee02mnnaarr75aa9as0ezZs/X9999r0aJFXvtu3LhRp512mi644AJNmjRJ//3vfzV58mQdddRR6tGjh99yvfTSS6pXr57Gjh2rOnXqqFOnTnrxxRc1aNAgSdIpp5yihg0b6sorr9SECRP0z3/+s9JcrtNPP11dunTRXXfdJcMwqnyvWbNmaebMmRo0aJBuu+02JSQkaM2aNVq5cqX+8Y9/SCoPbOvXr6+rrrpK9evX18qVK3XrrbeqsLBQ9957r9f5du/erVGjRumUU07RGWecoddff13XXXedDj/8cI0ePdrvZ5ekm266Sc8//7zuvvtuPfLII1Xut27dOn322Wc666yz1Lp1a23evFlz587Vcccdp5ycHNWtW1fHHnuspk2bpkceeUQ33nijunfvLkme/1a0f/9+HXfccdq4caMuu+wydejQQa+99pomT56sPXv2aPr06V77L1iwQEVFRfrXv/4ll8ule+65R6eccop+/vlnT4/Hqaeequ+++06XX3652rdvrx07dmj58uXaunWr5+YDABAZXnzxRZ1yyilKSEjQhAkTNHfuXK1bt05HH320JOmhhx7S888/r0WLFmnu3LmqX7++jjjiCM/xGzZs0IQJE/Svf/1LF110kbp27erzff78808dc8wx+v7773X++eerT58++uOPP7RkyRL9+uuvatq0qQoLC/X0009rwoQJuuiii1RUVKRnnnlGI0eO1Nq1a9W7d2+vc5qps6rToEEDXXnllbr11luVlZWlPn36VLmvmXuHm266SQUFBfr111/14IMPSlK1c9NXrFih0aNHq2PHjpo5c6b279+v//znPxo8eLCysrIq1alnnHGGOnTooNmzZysrK0tPP/20mjdvrjlz5kiSvvvuO40dO1ZHHHGEbrvtNiUmJmrjxo369NNP/V4LIGQMIEbMnz/fkGSsW7fO2LRpk1GrVi1j2rRpnteHDh1q9OjRw/N8/fr1hiTjwgsv9DrPNddcY0gyVq5c6dnWrl07Q5Lx0Ucfebbt2LHDSExMNK6++mpT5Tv88MONc845x/P8xhtvNJo2bWr89ddfnm25ubmGJOPee+/1OnbGjBmGJGPChAmVzut+ze2nn34y4uLijJNPPtkoLS312resrMzz//v27at0rn/9619G3bp1jeLiYs+2oUOHGpKM559/3rOtpKTESE1NNU499VS/n7tdu3bGmDFjDMMwjClTphhJSUnG9u3bDcMwjA8//NCQZLz22mvVliszM7NSGV577TVDkvHhhx9W2n/o0KHG0KFDPc8feughQ5KRkZHh2XbgwAFj4MCBRv369Y3CwkLDMP6+/k2aNDF27drl2Xfx4sWGJGPp0qWGYRjG7t27fX5PAIDI88UXXxiSjOXLlxuGUV5Xtm7d2pg+fbrXfu769vfff/fa7r5HWLZsWaVzt2vXzpg0aZLn+a233mpIMhYuXFhpX3cdffDgQaOkpMTrtd27dxstWrQwzj//fM82s3VWVQ6tg/fs2WM0atTIOPHEEz2vT5o0yahXr57XMWbvHcaMGWO0a9eu0r7uMs+fP9+zrXfv3kbz5s2NnTt3erZ9/fXXRlxcnHHeeed5trmv/6HXwDAM4+STTzaaNGnief7ggw/6/J4AOzF8HTGpY8eOOvfcc/XUU08pLy/P5z7vvPOOJOmqq67y2n711VdLkt5++22v7enp6TrmmGM8z5s1a6auXbvq559/9lueb775Rt9++60mTJjg2TZhwgT98ccfeu+998x9KEmXXHKJ333efPNNlZWV6dZbb1VcnPdPwKHD3OvUqeP5/6KiIv3xxx865phjtG/fPv3www9ex9WvX18TJ070PE9ISFC/fv1MffZD3XzzzTp48GC1UwsOLddff/2lnTt3qnPnzmrYsGHQ893eeecdpaamel3/2rVra9q0afrzzz+1evVqr/3PPPNMNWrUyPPc/b27P2+dOnWUkJCgVatWaffu3UGVCQDgDC+++KJatGihYcOGSSqvK88880y9/PLLlaZ0VaVDhw4aOXKk3/3eeOMN9erVSyeffHKl19x1dHx8vCcvSllZmXbt2qWDBw+qb9++PutBf3WWGSkpKbriiiu0ZMkSffXVV1XuF8i9gxl5eXlav369Jk+erMaNG3u2H3HEERoxYoTnXu1QFe+FjjnmGO3cuVOFhYWS/p7zv3jxYpWVlQVcJiAUCMoRs/wFgFu2bFFcXJw6d+7stT01NVUNGzbUli1bvLa3bdu20jkaNWpkKijLyMhQvXr11LFjR23cuFEbN25UUlKS2rdv7zOZTFU6dOjgd59NmzYpLi5O6enp1e733Xff6eSTT1ZKSooaNGigZs2aeQLvgoICr31bt25dad662c9+KDONJfv379ett96qNm3aKDExUU2bNlWzZs20Z8+eSuUya8uWLerSpUulRgr3cHd/37X7Zsf9eRMTEzVnzhy9++67atGihY499ljdc889ys/PD6p8AAB7lJaW6uWXX9awYcOUm5vrqaP79++v3377TR988IGp85ipn6XyOtrM8qzPPfecjjjiCCUlJalJkyZq1qyZ3n77bZ/1oL86y6zp06erYcOG1c4tD+TewQx3/etruH/37t31xx9/eBLbuvn7vGeeeaYGDx6sCy+8UC1atNBZZ52lV199lQAdtiIoR8zq2LGjJk6cWG0AKMlnkjRfqsoUalQzt9v9+ksvvaS9e/cqPT1dXbp08Tw2b96sxYsX688//zRVhkNbqGtiz549Gjp0qL7++mvddtttWrp0qZYvX+6Zj1Wx4gr2s/ty00036eDBg573qujyyy/XnXfeqTPOOEOvvvqq3n//fS1fvlxNmjQJW4Vq5vNeccUV+vHHHzV79mwlJSXplltuUffu3avtYQAAOMvKlSuVl5enl19+2at+di93Zrbh3Kr6WSpvyHevh/7MM89o2bJlWr58uY4//nif9aBVdbS/3vJA7x1Cxd/nrVOnjj766COtWLFC5557rr755hudeeaZGjFihOmRD4DVSPSGmHbzzTcrIyPDZwDYrl07lZWV6aeffvJKEPbbb79pz549nkzoNbV69Wr9+uuvuu222yolItu9e7cuvvhivfnmm17Dw2uiU6dOKisrU05OTqVkMG6rVq3Szp07tXDhQh177LGe7bm5uZaUwV/5Jk6cqCeffFL9+/ev9Prrr7+uSZMm6f777/dsKy4u1p49e7z2M9uYIpV/1998843Kysq8esvdQ+2C/a47deqkq6++WldffbV++ukn9e7dW/fff78yMjKCOh8AILxefPFFNW/eXI899lil1xYuXKhFixbpiSeesCzo7tSpk7Kzs6vd5/XXX1fHjh21cOFCr7rOnf09lK644go99NBDmjVrVqWl3wK5dzBbR7vr3w0bNlR67YcfflDTpk1Vr169AD5Bubi4OJ1wwgk64YQT9MADD+iuu+7STTfdpA8//FDDhw8P+HxATdFTjph2aABYcWjxP//5T0nlGVUP9cADD0iSxowZY0kZ3EPXr732Wp122mlej4suukhdunQJaAi7PyeddJLi4uJ02223VWq1drciu1uZD21FP3DggB5//HHLylGdm2++WX/99ZfPZdXi4+Mrte7/5z//qdS67a6kKwbrvvzzn/9Ufn6+XnnlFc+2gwcP6j//+Y/q16+voUOHBlT+ffv2qbi42Gtbp06dlJycXGk5PQCAM+3fv18LFy7U2LFjK9XPp512mi677DIVFRVpyZIllr3nqaeeqq+//rrSCi9S9XX0mjVrlJmZaVk5quLuLV+8eLHWr1/v9Vog9w716tUzNZw9LS1NvXv31nPPPedVn2dnZ+v999/33KsFYteuXZW2uTspqKNhF3rKEfNuuukmvfDCC9qwYYPX0mW9evXSpEmT9NRTT3mGZK1du1bPPfecTjrpJE/Cl5ooKSnRG2+8oREjRigpKcnnPieeeKIefvhh7dixo8bvJ0mdO3fWTTfdpNtvv13HHHOMTjnlFCUmJmrdunVq2bKlZs+erUGDBqlRo0aaNGmSpk2bJpfLpRdeeCGo4ejBcDeWVFwnXJLGjh2rF154QSkpKUpPT1dmZqZWrFihJk2aeO3Xu3dvxcfHa86cOSooKFBiYqKOP/54NW/evNI5L774Yj355JOaPHmyvvzyS7Vv316vv/66Pv30Uz300ENKTk4OqPw//vijTjjhBJ1xxhlKT09XrVq1tGjRIv32228666yzArsYAABbLFmyREVFRTrxxBN9vj5gwAA1a9ZML774os4880xL3vPaa6/V66+/rtNPP13nn3++jjrqKO3atUtLlizRE088oV69emns2LFauHChTj75ZI0ZM0a5ubl64oknlJ6ebnq6W01Mnz5dDz74oL7++muvXupA7h2OOuoovfLKK7rqqqt09NFHq379+ho3bpzP97v33ns1evRoDRw4UBdccIFnSbSUlBRTa6dXdNttt+mjjz7SmDFj1K5dO+3YsUOPP/64WrdurSFDhgR8PsAK9JQj5nXu3LnKoeFPP/20Zs2apXXr1umKK67QypUrdcMNN+jll1+25L3ffvtt7dmzp8qKSJLGjRungwcPWvaeUnmF9N///lf79+/XTTfdpFtvvVVbtmzRCSecIElq0qSJ3nrrLaWlpenmm2/WfffdpxEjRvjsuQ6Vm2++2ee8sIcffljnnXeeXnzxRV199dXKy8vTihUrKq1xmpqaqieeeEI7duzQBRdcoAkTJignJ8fne9WpU0erVq3SOeeco+eee05XX321du3apfnz51dao9yMNm3aaMKECVq1apVuuOEG3XDDDSosLNSrr76qU089NeDzAQDC78UXX1RSUpJGjBjh8/W4uDiNGTNGy5Yt086dOy15z/r16+vjjz/W1KlT9c4772jatGl6/PHH1bVrV7Vu3VqSNHnyZN111136+uuvNW3aNL333nvKyMhQ3759LSmDPw0bNtQVV1xRaXsg9w7//ve/dfbZZ2v+/Pk6++yzdfnll1f5fsOHD9eyZcvUpEkT3Xrrrbrvvvs0YMAAffrpp6YT6B3qxBNPVNu2bfXf//5Xl156qR577DEde+yxWrlypVJSUgI+H2AFlxGuri8AAAAAAOCFnnIAAAAAAGxCUA4AAAAAgE0IygEAAAAAsAlBOQAAAAAANiEoBwAAAADAJgTlAAAAAADYpJbdBQi1srIybd++XcnJyXK5XHYXBwAAGYahoqIitWzZUnFxtI/XFHU9AMBpAqnroz4o3759u9q0aWN3MQAAqOSXX35R69at7S5GxKOuBwA4lZm6PuqD8uTkZEnlF6NBgwY2lwYAAKmwsFBt2rTx1FGoGep6AIDTBFLXR31Q7h7G1qBBAypqAICjMNTaGtT1AACnMlPXM5ENAAAAAACbEJQDAAAAAGATgnIAAAAAAGxCUA4AAAAAgE0IygEAAAAAsAlBOQAAAAAANiEoBwAAAADAJgTlAAAAAADYhKAcAAAAAACbEJQDAAAAAGATW4Pyjz76SOPGjVPLli3lcrn05ptver2+cOFC/eMf/1CTJk3kcrm0fv16W8oJAAAAAEAo2BqU7927V7169dJjjz1W5etDhgzRnDlzwlwyAAAAAABCr5adbz569GiNHj26ytfPPfdcSdLmzZvDVCIAAAAAAMLH1qAckaO0zNDa3F3aUVSs5slJ6tehseLjXHYXCwAAAAAiWtQF5SUlJSopKfE8LywstLE00WFZdp5mLc1RXkGxZ1taSpJmjEvXqJ5pNpYMABCLqOsBANEk6rKvz549WykpKZ5HmzZt7C5SRFuWnaepGVleAbkk5RcUa2pGlpZl59lUMgBArKKuBwBEk6gLym+44QYVFBR4Hr/88ovdRYpYpWWGZi3NkeHjNfe2WUtzVFrmaw8AAEKDuh4AEE2ibvh6YmKiEhMT7S5GVFibu6tSD/mhDEl5BcVam7tLAzs1CV/BAAAxjboeABBNbA3K//zzT23cuNHzPDc3V+vXr1fjxo3Vtm1b7dq1S1u3btX27dslSRs2bJAkpaamKjU11ZYyx5IdRVUH5MHsBwAAAADwZuvw9S+++EJHHnmkjjzySEnSVVddpSOPPFK33nqrJGnJkiU68sgjNWbMGEnSWWedpSOPPFJPPPGEbWWOJc2TkyzdDwAAAADgzdae8uOOO06GUfV85MmTJ2vy5MnhKxC89OvQWGkpScovKPY5r9wlKTWlfHk0AAAAAEDgoi7RG6wTH+fSjHHpksoD8EO5n88Yl8565QAAAAAQJIJyVGtUzzTNndhHqSneQ9RTU5I0d2If1ikHAAAAgBqIuuzrsN6onmkakZ6qtbm7tKOoWM2Ty4es00MOAAAAADVDUA5T4uNcLHsGAAAAABZj+DoAAAAAADYhKAcAAAAAwCYE5QAAAAAA2ISgHAAAAAAAmxCUAwAAAABgE4JyAAAAAABsQlAOAAAAAIBNCMoBAAAAALAJQTkAAAAAADYhKAcAAAAAwCYE5QAAAAAA2ISgHAAAAAAAmxCUAwAAAABgE4JyAAAAAABsQlAOAAAAAIBNCMoBAAAAALAJQTkAAAAAADYhKAcAAAAAwCYE5QAAAAAA2ISgHAAAAAAAmxCUAwAAAABgE4JyAAAAAABsQlAOAAAAAIBNCMoBAAAAALAJQTkAAAAAADYhKAcAAAAAwCYE5QAAAAAA2ISgHAAAAAAAmxCUAwAAAABgE4JyAAAAAABsQlAOAAAAAIBNCMoBAAAAALAJQTkAAAAAADapZXcBAIRWaZmhtbm7tKOoWM2Tk9SvQ2PFx7nsLhYAAAAAEZQjxAgI7bUsO0+zluYor6DYsy0tJUkzxqVrVM80G0sGAAAAQCIoRwgRENprWXaepmZkyaiwPb+gWFMzsjR3Yh++BwAAAMBmzClHSLgDwkMDcunvgHBZdp5NJYsNpWWGZi3NqRSQS/Jsm7U0R6VlvvYAAAAAEC4E5bAcAaH91ubuqtQgcihDUl5Bsdbm7gpfoQAAAABUQlAOyxEQ2m9HUdXXP5j9AAAAAIQGQTksR0Bov+bJSZbuBwAAACA0CMphOQJC+/Xr0FhpKUmqKs+9S+VJ9/p1aBzOYgEAAACogKAcliMgtF98nEszxqVLUqXvwf18xrh0lqcDAAAAbEZQDssREDrDqJ5pmjuxj1JTvEckpKYksRwaAAAA4BCsUx7BSssMrc3dpR1FxWqeXN7z7JRA1x0QVlynPJV1ysNqVM80jUhPdezfCQAAABDrCMoj1LLsvEoBb5rDAl4CQmeIj3NpYKcmdhcDqJaTGxkBAABCiaA8Ai3LztPUjKxK64DnFxRrakaWo4YmExAC8CcSGhkBAABChTnlEaa0zNCspTmVAnJJnm2zluaotMzXHgDgLO5GxkMDcunvRsZl2Xk2lQwAACA8CMojzNrcXZVuXg9lSMorKNba3F3hKxQABIFGRgAAAIavR5wdRVUH5MHsF42Ymwor8fcUOoE0MjINBgAARCuC8iAEcpNu9Q198+Qk/zsFsF+0YW4qrMTfU2jRyAgAAEBQHrBAbtJDcUPfr0NjpaUkKb+g2OeQT5fKlx3r16FxUOePZJGUAA/Ox99T6NHICAAAwJzygASSkChUyYvi41yaMS5dUnkAfij38xnj0mNueC1zU2El/p7Cw93IWNWvlUvlDZmx2MgIAABiB0G5SYHcpIf6hn5UzzTNndhHqSnevUepKUkx23tHAjxYib+n8KCREQAAgOHrpgV6k17T5EX+5qKP6pmmEempls1Xj/RkVsxNhZX4ewofdyNjxak+qczdBwAAMYKg3KRQ3KRXta/ZuejxcS5LMhJHQzIr5qbCSvw9hZfVjYwAAACRhOHrJgVyk16TG/pQzUWvSrjfL1SYmwor8fcUfu5GxvG9W2lgpyYE5AAAIGYQlJsUyE16sDf0NZ2LXlpmKHPTTi1ev02Zm3b6nbMeTcmsmJsKK/H3BAAAgHAhKDcpkJv0YG/oa5Jcall2nobMWakJ8z7X9JfXa8K8zzVkzspqe7qjLZlVtCTAC7RxBaERLX9PAAAAcDbmlAcgkIREwSQvCnbeejDrKZeWGfp04+9BvV8wwpVILtLnpoZzfn+kJ/cLh0j/ewIAAIDzEZQHyN9NesVAZ/W1w/Tllt1+b+hLywz9UVRiqgyHzkU3OwR9RHqq5319BX5m3y8Y4U4kZ1UCvHALpnGlJu8V6cn9wiVS/54AAAAQGQjKg1DVTXp1gc743q2qPJ/ZINml8p72Q+ei+xuCLpUPQX905U+aPvywKgM/s+8XqHAGmpHMX+OKS5UbV4LFdwIAAAA4B3PKLRJsFvOqjquoqrnoZoeWP7jiJ73zzfYqAz+z7xeIaEokF2rhmt9/4GCZblz0Ld8JAAAA4BAE5RYINvis7riKqkouFcjQ8psXZ5sesm5FMqtoSyQXCu6kbu+aXHquJvP7l2XnacDsD7Rr719V7sN3AgAAAIQXw9ctEEjweeiwdzNDzyXpljHdNXlwB5891u7l18ycp7pg7FCXDeusK0ccVuNh0sEmrosVgc7tl4Kf3x/ItAUpdr8TAAAAINzoKbdAsMGn2eOaJidWGSAfuvyaVQZ3bmpJdmmzAWRNE8lFIrPTFtyqWtvejEBGZLjF4ncCAAAA2IGg3ALBBp9WBa2jeqbpyuGHmTpX43oJldZOd6tJ4OeLuxc/XO9npVCuFR5okFzT+f1mR2S438up3wkAAAAQjQjKA1BVoBZs8Gll0HrZ8Z2V2iCxytfd57pjfE/P84qvSzVL7FbRob344Xg/qyzLztOQOSs1Yd7nmv7yek2Y97mGzFlZZbK+QAUSJEs1n98f6FB0J34nAAAAQLQiKDepukAt2ODTyqA1Ps6lmSf2kMvPuf55RJrmTuyj1BTv3ncrErv5MqpneN+vpoLNoh8Is0HyeQPb6aWLBuiT646v0XUyOyKjcb3ajvxOAAAAgGjmMgwjqtc+KiwsVEpKigoKCtSgQYOgzlFVkix3sOsOZKpbp7y6QCfY42pyrtIyQ2tzd2lHUbGaJ5f3xoeydzTc7xeM0jJDQ+asrLIX271u+yfXHV+jsmdu2qkJ8z73u99LFw3wSgwYLPfnyi8ornLIfJN6Ccq84QQl1KKdDggHK+om/I3rCQBwmkDqJrKv++FvuTOXypc7G5GeqlE90zQiPTXg4DPY42pyrvg4lyUBn1nhfr9gBJtFP1DuaQtVBcnu4N+qed3uERlTM7Lkkrze0/1XcefJPQnIAQAAABsQlPsRaKAWbPBpZdAaCQGwE4VrCTczQbLV87rd0wgqjqJIDXJEBgAAAABrEJT7wVrbkcPsEPmq9gvnEm52BMlWjsgAAAAAYA2Ccj9YazsymJ1LX91+I9JTwzqs3I4gmVEUAAAAgLPYOon0o48+0rhx49SyZUu5XC69+eabXq8bhqFbb71VaWlpqlOnjoYPH66ffvoprGWM5LW2Y4XZjOn+9luek+/Jhl9RqIaVu4Pk8b1beaY/AAAAAIgdtgble/fuVa9evfTYY4/5fP2ee+7RI488oieeeEJr1qxRvXr1NHLkSBUXh2+oeKSutR0r/CXik8oT8R04WGZqv7IyQyl1a1fap2FdlgsDAAAAYD1bh6+PHj1ao0eP9vmaYRh66KGHdPPNN2v8+PGSpOeff14tWrTQm2++qbPOOits5SRJlnOZTcT3QuZmU/v9e8FXPl/fve+vGpYUAAAAACpz7Jzy3Nxc5efna/jw4Z5tKSkp6t+/vzIzM6sMyktKSlRSUuJ5XlhYaEl5SJLlTGYT7G3Zta9G73Po0nd85wBgr1DV9QAA2MGxCxPn5+dLklq0aOG1vUWLFp7XfJk9e7ZSUlI8jzZt2lhWJub/Oo/ZBHvtGtet0fscuvQdAMBeoazrAQAIN8cG5cG64YYbVFBQ4Hn88ssvdhcpIpSWGcrctFOL129T5qadKi3zNfvaecwm4jt3YPtq9zOLpe8AwH7U9QCAaOLY4eupqamSpN9++01paX/P2f7tt9/Uu3fvKo9LTExUYmJiqIsXVcwuJ+ZE7kR8UzOy5JK8ErkdmogvoVZctfuZbYJg6bu/mV0XHgCsRl0PAIgmju0p79Chg1JTU/XBBx94thUWFmrNmjUaOHCgjSWLLmaXE3MydyK+1BTvgDk1JckrY3p1+z1+dh+WvgvAsuw8DZmzUhPmfa7pL6/XhHmfa8iclRHx9wIAAAA4ia095X/++ac2btzoeZ6bm6v169ercePGatu2ra644grdcccd6tKlizp06KBbbrlFLVu21EknnWRfoaOIv+XEIim5mdlEfNXtFxcnvz3uTr8O4eBuyKn4d+NuyGHpOAAAAMA8W4PyL774QsOGDfM8v+qqqyRJkyZN0rPPPqv/+7//0969e3XxxRdrz549GjJkiJYtW6akJIYQW8HscmJrc3dpYKcm4StYkNyJ+ILdL5xL30Xq0O9oasgBAAAAnMDWoPy4446TYVQ9m9flcum2227TbbfdFsZSxQ6zSctiKblZOJa+i+Q5/NHWkAMAAADYzbGJ3hB6ZpOWxVpyM7M97sGI9KHfNOQAAAAA1nJsojeEntnlxEhuZg1/Q7+l8qHfTl6OjoYcAAAAwFoE5THMvZyYpEqBOcnNrBfI0G+noiEHAAAAsBZBeQwrLTOUUidB5w9ur0b1anu9VnE5MdRcNAz9piEHAAAAsBZzymOUr2Rjjesl6KTeLTUiPdVR2cBrmqncKZnOo2Xodziz1AMAAADRjqA8BlWVbGz33gOa/+lmRwXkNc1U7qRM5+6h3/kFxT7nlbtUHthGwtDvcGSpBwAAAGIBw9djTCQlG3M3HlSch+3OVL4sOy+kx1st2oZ+u7PUj+/dSgM7NYmYcgMAAESL0jJDmZt2avH6bcrctNMR9/AIHD3lMSZS1pn213jgUnnjwYj0VJ/BYE2PDxWGfgMAAMAKThoRipohKI8xkZJsrKaNB05ufGDod3CckhsAAADAblVNR3WPCCVhc2QhKI8xkZJsrKaNB05vfHAP/YY5tAQDAACUc+qIUASPOeVBitT5G5GyznRNGw8ipfEB/jktNwAAAICdAhkRishAT3kQIrnXzp1sbGpGllySVwubk5KN1TRTeTRlOg+VSBgOTkswAACAN6ePCEXg6CkPUDT02rmTjaWmePcSp6YkOWb+SU0zlUdbpnOrLcvO05A5KzVh3uea/vJ6TZj3uYbMWem4v19aggEAALwxIjT60FMegGjqtYuEZGM1zVROpnPfIikxCC3BAAAA3hgRGn0IygPg5IzewYiEZGM1bTyIhMaHcIq0hiVaggEAALxFynRUmEdQHgB67exR08YDOxsfnDZvO9IalmgJBgAAqIwRodGFoDwANe21Ky0z9Pmmncr8+Q9J5YHigI5NaMWKUk5MCBhpDUu0BAMAAPjGiNDoQVAegJr02i3LztP1C7/Vnn1/ebY9+uFGNaxbW3efcjitWVHGqfO2I3E4OC3BAAAAvkXCdFT4R1AegGB77ZZl5+mSjCyf59yz7y9dkpGlJxyUXAs14+R525E6HJyWYAAAAEQrlkQLUKDLiZWWGZq55Du/5521NEelZb7CJGcoLTOUuWmnFq/fpsxNOx1dVrs5eRmvSF4qzt0SPL53Kw3sxLQPAAAARAd6yoMQSK/d2txdyi8s8XtOJyXXqqi6udH0Xlbm9HnbDAcHAOdwWkJQAED4EZQHyez8jUACL6ck1zpUdXOjL8nIUsO6tb3mydudyMwJImHeNsPBAcB+TkwICgAIP4avh1gggZeTkmtJ/udGS/IKyKW/E5kty84Lefmcyj1vu6rw1qXymy67520zHBwA7ONu9K443Yl6FABiD0F5iPXr0FipDRL97ueEIK0if3OjfXEH606fIx9KkTxvGwAQemYavWO5HgWAWENQHmLxcS7NPLGH3/2cGKQFO5zezkRmThFoQkAAQOxwckJQAED4Mac8DEb1TNMTE/tUWqdckhrVra3ZDl2nvKbD6Z04Rz6cmLcNAPDF6QlBAQDhRVAeJu4A7fNNO5X58x+SyufzDujo3Lm8/ta09sdpc+TtYDYhIAAgdkRCQlAAQPgQlIdRfJxLg7s01eAuTe0uiinuudFTM7LkkkwH5i6VD9N22hx5AACcwF+jN/UoAMQW5pSjWlXNjW5Ut7YkEpkBABAoEoICAA5FTzn8qmpu9PKc/Errq6ayvioAAH65G72pRwEABOUwxdfcaBKZAQAQPOpRAIBEUI4aIpEZAADBox4FABCUO0hpmUFrOQAAAADEEIJyh1iWnVdpXlka88oAAAAAIKqRfd0BlmXnaWpGlldALkn5BcWampGlZdl5NpUMAAAAABBKBOU2Ky0zNGtpjs91St3bZi3NUWmZ2VXCAQAAAACRgqDcZmtzd1XqIT+UISmvoFhrc3eFr1AhUFpmKHPTTi1ev02Zm3bSyAAAAAAAYk657XYUVR2QB7OfEzFfHgAAAAB8o6fcZs2TkyzdL9QC7fFmvjwAAAAAVI2ecpv169BYaSlJyi8o9jmv3CUpNaV8ebRQMrMcW6A93v7my7tUPl9+RHoqS78BAAAAiEkE5TaLj3Npxrh0Tc3IkkvyCmDdYeqMcekhDVrNBNvuHu+KAba7x3vuxD6VAvNA5ssP7NTEqo8DAAAAABGD4esOMKpnmuZO7KPUFO8h6qkpST6DXSuZGV4ebIb4WJgvDwAAAAA1QU+5Q4zqmaYR6al+h5Bbyezw8uSk2kH1eEfafHkAAAAACDeCcgeJj3OFdRi32eHlmZt2mjpfxR5vp8yXBwAAAACnYvh6DDM/bNzcmuIVe7zd8+Wlv+fHu4VrvjwAAAAAOBlBeQwzO2x8YMemSktJqhRYu7lUnhjOV4+3nfPlAQAAAMDpGL4ew8wOLx/QqUmNMsTbMV/eFzPLvgEAAABAOBGUmxSNAV0gy7G5e7wrLp2WWs065RXfy85lzwJdYx0AAAAAwsFlGIa5CcMRqrCwUCkpKSooKFCDBg2COke0B3SBfL5IbJyoao11d6kZRg8g3Kyom/A3ricAwGkCqZsIyv2IlYAuEoNtM0rLDA2Zs7LKLPPuIfqfXHd8VHxeAJGBINJaXE8AgNMEUjcxfL0aZtfxHpGeGvEBnd3Dy0PF7LJvFddYBwAAAIBwIPt6NQIJ6OBMZpd9M788HAAAAABYh6C8GgR0kc/ssm9m9wMAAAAAKzF8vRrhCOiidS63U5hd9s3XGusAAAAAEGoE5dUIdUAX7VndnSCQZd8AAAAAINwYvl4Nd0An/R3AudU0oHNnda84Zz2/oFhTM7K0LDsviBLDF/ca66kp3iMaUlOSoiZ7PgAAAIDIRE+5H+6AbuaSHOUX/h1Ap9agRzuWsro7xaieaRqRnspUAQAAAACOQlBumncIXZPl3Vmmyx7RuuwbAAAAgMjF8HU/3MPM8wtLvLb/VlgS9DBzsroDAAAAACR6yqtVcZj55qSzvV7/oayNpix92DPM3GwmdZbpAgAAAABIBOXVOnSYeV1V7rXuFveLMktOkW4rfx4vaULxAs/rVWVSZ5kuAAAQDiy9CgDOR1BejUOHj++TuV5rr970EqlDxouaO/Eor8CcZbr84yYCAICaYelVAIgMBOXVqDh8vH3xAp0b/75ur/2s6XPkJp0jva7yhyRd/4uU1MCT1b1iZVmTrO7RgpsIAABqxp0Tp+KIPPfSq5G4JCgN9gCilcuoSRrxCFBYWKiUlBQVFBSoQYMGAR1bWmZoyJyV1Q4zv6XOazrfWBR0+Uov/FBrS9pRwfxPVTcR7isSiTcRAFBRTeomVMb19Oa+f6lqpRf3NLlPrjs+Yu45aLAHEGkCqZvIvl4N9zBz6e+g0M39fFufa9W+eIHncfNfUwJ7j6eHaeALHTX+zXQNfKGj4lffXfOCRyh/67dL5eu3l5ZFdTsSAAA1EsjSq5HA3WBf8TO5e/2DWQkHAJyEoNwP9zDz1BTvoeypKUmaO7GPhqenem3PKB3hFaRfdeCSwN5w9d3SzJS/H3e1rulHiBhOvIkoLTOUuWmnFq/fpsxNO6OmQSBaPxcAILqWXqXBHkAsYE65CaN6pmlEeqrPeUylZUa1mdQXlR2rzMR//D1E7Jd10jPDzb/5gaLy4PxQMwv8HhaJ866cdhMRrUPlovVzAQDKRdPSq4E02A/s1CR8BQMACxGUmxQf5/L5Yx9wJvU2R3sH1bu3SA8fEVhh/ATpdgVdNW0IcNJNRDQmyJGi93MBAP4WTUuvOq3BHgBCIeCgPDc3Vx9//LG2bNmiffv2qVmzZjryyCM1cOBAJSU5v8U1FKrKpN6iQaIm9GurkoNlyty003eQ2qidd1B9sES6o3lgBagQpE8tXhD2oMuKhgCn3ET4GyrnUvlQuRHpqY4ffXCoaP1cACIf9xbWiqalV53UYA8AoWI6KH/xxRf18MMP64svvlCLFi3UsmVL1alTR7t27dKmTZuUlJSkc845R9ddd53atWsXyjI7UsUh7pv/2KeX1m7Vgyt+8uxjKkitlVh5eHrFnnE/cg9dK13lS7mFMuiyqvfVKTcR0TpULlo/F4DIxb1F6ETL0qtOabAHgFAyFZQfeeSRSkhI0OTJk/XGG2+oTZs2Xq+XlJQoMzNTL7/8svr27avHH39cp59+ekgK7GTuIe7LsvP00IofLemtXpadV6nne3OFoNsfr/1LpMzcny0LuqzufXXCTYRdQ+VCnQeAIYAAnIR7i9CrLidOpHBKgz0AhJKpdcrfe+89jRw50tQJd+7cqc2bN+uoo46qceGsEO61S61cG9TsuTJLTqlZoW/MkxLqBnVo5qadmjDvc7/7vXTRgIAaAuxMVBeqz1SdcOQBsONzAfCNdbWtvbfgekY/kpQCiDSB1E2mesrNVpqS1KRJEzVpErs39FYOETZ7rsyLynu+3UFXoD3puqtCZXbhSqm1uUaVUPW+VpVYLxzCPVQuXMnXGAIIwEm4t0AgoqHXHwCqEnCit/z8fK1Zs0b5+fmSpNTUVPXv31+pqal+jowNVgapgZ7LHXR1KKjZcHc9fbz388HTpRG3+dw1GhOwhHOoXDiTrzEEEIBTcW8BM+xssAeAUDIdlO/du1f/+te/9PLLL8vlcqlx4/LetF27dskwDE2YMEFPPvmk6tYNbhh0tLAySA30XFUFXe2LF3iCrrkT+2jU691Mndfj04fLH26160o35UmK3t7XcM1tD3fyNSfM2QcAN+4tAAAIICifPn261q5dq7ffflvDhw9XfHy8JKm0tFQffPCBLr/8ck2fPl3z5s0LWWEjgZVBajDnMhV09ayQ3X3+P6Utn5r6fJKkv/Z5MsLHS8qU1EELoq73NRxD5exIvsYQQABOwb0FAAAmE71JUqNGjfT2229r0KBBPl//9NNPNXbsWO3evdvSAhYVFemWW27RokWLtGPHDh155JF6+OGHdfTRR5s63o7kL+45wpLvIDXg7OtBnKtGidKynpeWXG5u3yq0L15AAhYTSL4GxCYSk5Wz6t6C6wkAcBrLE71JUllZmRISEqp8PSEhQWVlZeZLadKFF16o7OxsvfDCC2rZsqUyMjI0fPhw5eTkqFWrVpa/nxVGpKfqiuFdNP/Tzdqz/y/P9mCGCAc73LhG8676nFf+cNu9RXr4iIBOsTnpbKlE0uv/e1Rcex2Sonf4PwCYYde9BQAATmK6p/ycc87R999/r2eeeUZHHnmk12tfffWVLrroInXr1k0ZGRmWFW7//v1KTk7W4sWLNWbMGM/2o446SqNHj9Ydd9zh9xzhbj33tWRHwzq1NWVwB112fOeghwjbuURYJWVl0m2NanYOgnQPK0dWAIgM9OyWs+regusJAHCakPSUP/roozr77LN11FFHqVGjRmrevLkkaceOHdqzZ49GjhypRx99tGYlr+DgwYMqLS1VUpJ3wrM6derok08+8XlMSUmJSkpKPM8LCwstLVN1qlraqmD/X3poxY/qmlo/6ODKURlH4+IqB9X/m2NuWsX9YzhIJ/kagFgV7L1FOOt6RzWKAwCikumecrfvv/9en3/+udeyJQMHDlS3bgFm9DZp0KBBSkhI0IIFC9SiRQu99NJLmjRpkjp37qwNGzZU2n/mzJmaNWtWpe2hbj0vLTM0ZM7KKjNpu4chf3Ld8bFRmQcapFc0Y4/kioHrdAhu/IDYQc+ut0DvLcJV1/sa/Ua+FACAGYHU9QEH5eG2adMmnX/++froo48UHx+vPn366LDDDtOXX36p77//vtL+vlrP27RpE/IbHxJ2+VHTIP3qDVJydK5XSzAOxB6C8poJR11f1eg3phYBAMwIyfB1STpw4IDefPNNZWZmerVmDxo0SOPHj682WUuwOnXqpNWrV2vv3r0qLCxUWlqazjzzTHXs2NHn/omJiUpMTLS8HP7YsbRVRKnpcPf7u3o/P/UZ6fDTalYmB6AXBkCsC+beItR1fWmZoVlLc3wm4DRUHpjPWpqjEempNKICAGoszuyOGzduVPfu3TVp0iR99dVXKisrU1lZmb766iudd9556tGjhzZu3BiygtarV09paWnavXu33nvvPY0fPz5k7xWM5slJ/ncKYL+oN7PA+xGoNy4oD+zdj+dPsryIoebuhak45SG/oFhTM7K0LDvPppIBQHjYfW9RlbW5u6qcjiaVB+Z5BcVam7srfIUCAEQt0z3lU6dO1eGHH66vvvqqUvd7YWGhzjvvPF166aV67733LC3ge++9J8Mw1LVrV23cuFHXXnutunXrpilTplj6PjVldmmrsjJDi9dvY5hyRTXtSf/5w4hKHkcvDADYd2/hD6PfAADhZDoo//TTT7V27Vqf4+EbNGig22+/Xf3797e0cFJ50pYbbrhBv/76qxo3bqxTTz1Vd955p2rXrm35e9VEfJxLM8ala2pGllyqvLSVIWn/X6U655k1nu0MU65GxYB64b+kb14O8BzODdID6YWJyRwEAGKCXfcW/jD6DQAQTqaHrzds2FCbN2+u8vXNmzerYcOGFhTJ2xlnnKFNmzappKREeXl5evTRR5WSUsOkYSHiXtoqNcW7km5Yt7wBYc++v7y2M0w5AKc86T3c/Zw3Aj/HocPda5p4robohQEA++4t/HGPfqtqnJJL5Q3r/To0DmexAABRynRP+YUXXqjzzjtPt9xyi0444QS1aNFCkvTbb7/pgw8+0B133KHLL788ZAWNFKN6pmlEeqonm3bT+om6+tX1PveN5GHKtmcM7zLcu+d7/25pTvvAzmFjTzq9MADg3HsLf6PfJGnGuPSIqrcBAM4V0JJoc+bM0cMPP6z8/Hy5/reGtGEYSk1N1RVXXKH/+7//C1lBg2X3sjPRuFRaxGQMr2lveAiDdPe69v5yEMTMuvZAjLG7bnISK+4tQnU9I6a+AwA4TsjXKf/555/122+/SSpftqRDhw7BlTQM7L7xWbx+m6a/vN7vfg+f1Vvje7cKfYFqKKLXbXVYkO6+lpLvXhhHX0sANWJ33eRENbm3COX1tH1kGAAgIoVsnXK3jh07VrlOOLxF0zDliM8YXtMM7xX3v+UPKT74hIPuHAQVe2FS6YUBEIOcem8RH+eKmJFsAIDIFFBQnpOTo0cffVSZmZnKz8+XVN6aPXDgQF122WVKT08PSSEjmdml0iIhWUzUZQyvaZB+e1Pv55d9ITXtEtApKuYgoBcGQKzh3gIAEOtMB+XvvvuuTjrpJPXp00fjx4/3SsayfPly9enTR4sXL9bIkSNDVthIFE3JYqI+Y3hNg/RH+3o/H32P1P9ffg+jFwZArOLeAgCAAOaU9+rVS+PHj9dtt93m8/WZM2dq4cKF+uabbywtYE05Zd6e05PFmJkzF41J6wJS0znpbQdJ579rTVkARDSn1E12s+regusJAHCakCR6q1OnjtavX6+uXbv6fH3Dhg3q3bu39u/fH3iJQ8hJFbVTk8WYbTAgY3gFc4dIv31bs3OEcRm2YDj1bxaIdE6qm+xk1b0F1xMA4DQhSfTWvn17vf3221VWnG+//bbatWsXWEljjBOHKVeVTT2/oFhTM7K8MoBH01D8QwUdeE79xPv52nnSO9cE9uY2rpXuj9NHdwCIfNxbAAAQQE/5a6+9prPPPlujR4/W8OHDveZ9ffDBB1q2bJkWLFigU089NaQFDhSt51Vz93xXlbytqp7vaArWQvpZ/vip8jzzQNkUpEf00ndABKBuKmfVvQXXEwDgNCFbp/yzzz7TI4884jND6vTp0zVw4MCalTwEqKirVpM54v56lyNh2HPYA8/SvypnbA9UGIL0YBtrAJhH3fQ3K+4tuJ4AAKcJ2TrlgwYN0qBBg2pUODhHINnUfQXZVQ3Fj4SedFvWXI+vbf1a6SEI0qNu6TsAjsa9hTmR0NgNAAhOQEE5okvz5CRT+23+Y2+lntOqguxA5qiHU8WbmbIywxmBpwOD9Khf+g4AgmBnUBwJjd0AgOAFFJS/8847WrhwoRo3bqwpU6aoe/funtd2796tU089VStXrrS8kAiNfh0aKy0lqdps6g3r1taDK36q9JqvINuW3mcTfN3MNKxT29SxYQ88rQ7Sb8yTEuoGdAqzjTVN6ycGdF4A8CUS7i3sDIqDbeymZx0AIkec2R0XLFigE088Ufn5+crMzFSfPn304osvel4/cOCAVq9eHZJCIjTc2dSlv+dRu7mfV5VwwL191tIclZaVPwtk2HMolZYZyty0U4vXb9PDK37S1IysSuXas/8vU+cyG6CGzMwC70eg7korD9Tdjx0/+D3E3Vjj79bt6lfXa1l2XuBlAoD/iYR7C3dQXLEecQfFofwd9NfYLXnXw27LsvM0ZM5KTZj3uaa/vF4T5n2uIXNW8psNAA5luqf83nvv1QMPPKBp06ZJkl599VWdf/75Ki4u1gUXXBCyAjpFtLY4j+qZprkT+1TqAUhNSdJZR7fVgyt+rPLYikO8nTDs2VdvRjDcycz6dWhsTcGsUtOe9Mf7ez8//Vmpx8lem6pb+u5QvxWW2DolAUDkc/q9hd0jwILJ8eHUaWQAgKqZDsp/+uknjRs3zvP8jDPOULNmzXTiiSfqr7/+0sknn1zN0ZEtEudyBdKIMKpnmkakp1ba/61vtpt6L3eQbbZX2cx+wTSCVHUjEqiIWnO9YpD+xBAp/1vzx782ufzhdsw10gm3eBprZi7JUX6h7xtCO6ckAIgOTr+3sDvxZaCN3XY3IgAAgmM6KG/QoIF+++03dejQwbNt2LBheuuttzR27Fj9+uuvISmg3SKxxTmYRoT4OFelG4pAg2wzc9TN9D4HU/7qbkT8aVinttdw9lSHN7hU65JPvJ+veVJ69//MH//xfeUPSaMkDWmUrp6FN1e5O5nYAdSE0+8t7B4BFmg9bHcjAgAgOKaD8n79+undd9/VgAEDvLYPHTpUS5cu1dixYy0vnN0iscXZykaEQIPs6oY9m+19Drb8/m5EqvPYOX0U53JF3dQESVL/f0n9/+Vp6Ghc+L3eTrzJ9OH1d+doc9LZXtvaFy+otB+Z2AEEw+n3FlaOAAtGoPWw3Y0IAIDgmE70duWVVyopyXelc9xxx2np0qU677zzLCuYEzglcZlZ/hoRDEkzl3xXKSFMVcwkgqsYZLuHPaemeP+tpKYk+W0QCDahjRTcDYZL5T3wAzo20cBOTTS+dysN7NQkegLy/zk0SdF3Rge1L16g9sUL1KF4gXoU/zfg821OOtvrITkgIR6AiOT0ewt/iS/d9Uio8o8EWg/b3YgAAAiO6Z7yoUOHaujQoVW+PmzYMA0bNsySQjlFpLU4m+ktzi8s0aMrN2r68C6mzlldIriqhnhXNUfdX7Bbk2F3gd5gOH3euFWJBf01dOxTkgYmLtQn1x1ffn7DkGY1DOg9NiedLb1wyAYL1koHEBucfm9hxQiwmgqkHrZqGhkAILxMBeV79+5VvXr1TJ800P2dKtJanM02Djy44kd1Ta1vehh7MEG2rznq/tSkEcTfjUhFTp43bmViwYAbOlwu69dKJ0gH4EOk3FsE0zgdijKYqYed0IgAAAicqaC8c+fOmj59uiZNmqS0NN+Vj2EYWrFihR544AEde+yxuuGGGywtqB0ircU5kMaBQOfCBxNkB6omjSD+bkQMSVcO76L2Tes5et641YkFLRntMbPAq6Gg4hxzvyoG6TP2lAf/AGJaJN1bBDsCzEpm62EnNCIAAAJjKihftWqVbrzxRs2cOVO9evVS37591bJlSyUlJWn37t3KyclRZmamatWqpRtuuEH/+te/Ql3usIi0Fmd3I4KZhGdOzL5a00aQSL8RCUViQatGexx6Q7q4KMf7hjTQnvSKw+Nv3C4lRP7IGgCBibR7C6sap62anlQdJzQiAADMcxmGYXoVqa1bt+q1117Txx9/rC1btmj//v1q2rSpjjzySI0cOVKjR49WfHx8KMsbsMLCQqWkpKigoEANGjQI6hyRtE75suw8XZKRZWrfh8/qrfG9W4W4RIFx9xRLvhtBzPQUh+OGJxQyN+3UhHmf+93vpYsGmL4xLC0zNGTOSr8NHZ455VZ4bpyU+1Hwx0//RmrUzvM0Ur9PoDpW1E3Rwop7i0i5npF0PwEAqJlA6qaAgvJIZFVFHUmBwcMrftSDK37yu18gwV04xepNy+L12zT95fV+9/PVmFLd36cVDR01supuadXsoA//d+3b9E5RZ8/zWPhbQPSLlCAyUkTC9axqelLYfosBAGFFUH6ISKiorVZaZmjw3SuVX+h7GHtIekctFkmNIFYJtqfcTCOGoxo6vn9LeuWcoA+/5q9/6Y3SodzAIqLFYt0USk6/nu5RS1VNL4uEehkAEBiC8kM4vaIOFdt7RxGwYIaaB9Lz4tiGjh3fS48PCP74/pdIo+dYVx4gDGK1bgoVp1/PUExPAgA4WyB1k+l1yhFZIj3pWSwKNLFgoInhwpFBPyjNu3svm7Z/tzSnvfnj1zxR/nBreaR08SqrSgcANWbJShgAgKhFUB7FyL4aeQJpTAl4DfJIUaeRFp+U45lfH6cy/Zw00fzx279irXQAjmLVShgAgOhEUB7lHNs7iiqZbUyJ5p6XQ29MyxSn9sULvF6v8VrpBOkAwqimS34CAKJbnNkd9+7dq6lTp6pVq1Zq1qyZzjrrLP3++++hLBtQrdIyQ5mbdmrx+m3K3LRTpWXRkx7B3ZgyvncrDezUxOfohmjueXHfwFY1pqND8QINTFyo0lv3BBdgz0zxfgCwRazcW7inJ0mq9Lvma3oSACC2mO4pv+WWW/TCCy/onHPOUZ06dbRgwQJdfPHFWrRoUSjLB/gUqmzijk2G5kM097wEOr++UmAeaKBdcf9bd0lx1a+LDKDmYuneglwvAICqmM6+3qFDB91zzz06/fTTJUlffvmlBgwYoP3796tWLeeOgnd6RlYELlRrvTpq2TCToj3LvmXfyVPHlc81D9b1v0hJ/H7AOtRN5ay6t4ik6xlJjb8wj+8VQEUhWRKtdu3a2rJli1q2bOnZVrduXf3www9q27ZtzUocQnZV1Pw4h0ao1noNVaAfDpHYmBCIkPxbevd6ac3c4I+/4lupoXN/9+B8kRREhpJV9xZcT9gp2uthAMEJyZJoZWVlql27tvfBtWqptLQ0uFJGMX6cQ+fzTTstzzge6NJiThPtWfZDkqxw9N3lD7evXpQW/9v88Q8d7v38ghVSm6OtKRsQQ7i3QKSrqlE/v6BYUzOyHN2oD8A5TAflhmHohBNO8BpOtm/fPo0bN04JCQmebVlZWdaWMMLw4xw6y7LzdP0b35raN5CM49GwtBhZ9mvoyHPKH27bsqR5w8wf/8xw7+cTXpG6jrKmbEAU494CkSySG/UZ0Qk4i+mgfMaMGZW2jR8/3tLCRDqrfpz5oaysqsaOqgSScdxsAP9udp4k8X3EglZ9vJPH/fm7dF9n88e/dKb381F3SwOmWlM2IIpwb4FIFqmN+ozoBJynRkE5vJn9cX7201w1TU70GXDzQ1lZdY0dFQWTcdxsAP985hY9n7kl5r+PmFS/mXeQ/lexdGcL88cvu7784dZnknTiI9aVD4hQ3Fsgkplt1A9k9F6oMaIT8OPgAalWgv/9LBZU2vQ//vhDmzdvlsvlUvv27dWkiXNa/+xk9kf39re/9/z/oQEeP5S++WvsqCjQtV79LS1WUax/H5BUO8k7SDcMaVZD88dnPVf+cGt1lHTRSsuKB0Qi7i0Qacw26gcyei+UInm4PRAye/+QPr5f+vzxv7e17idNfkuqlRi2YsQFsvN3332nY489Vi1atFD//v3Vr18/NW/eXMcff7w2bNgQqjJGjGB+dN0B3jvf5FX7QymV/1CWlpkdwB09zDZ2NKxTO6hA2b0mtvR3tvXqxPr3AR9crvIg/dBHILZ9Wb5W+qEPIEZwb4FI5W7Ur+rewaXyzpdARu+FUiDD7YGo9dd+6dNHpNualN9v3dvJOyCXpF/XSiV/hrVYpnvK8/PzNXToUDVr1kwPPPCAunXrJsMwlJOTo3nz5umYY45Rdna2mjdvHsryOlqgPa7S3y2TtyzO1s69B6rdz4nzksLBbGPHY+f00eDOTYN6j1E90zR3Yp9KUweqEsvfB0yqGJg/PlDakRPA8RUC8xl7yoN/IIpwb4FI5m7Un5qRJZfkde/n/rUOdPReKEXicHugxsrKpO8WSitmSQVbzR0zao5UL7z396aD8gcffFDt2rXTp59+qqSkv4OkUaNGaerUqRoyZIgefPBBzZ49OyQFjQTV/ThXx5CqDcgPFYs/lP4aO9zzyAd0rNk/nkOXFns3O0/PZ27xe0wsfh8I0r8zvZ8v/Jf0zcvmj684PP6WP6T42j53BSIF9xaIdFU16qc6MP9MpA23B4K2+RPpg9ukX9aYPyb9JGnYTVKzw0JWrOqYDsqXL1+u66+/3qvSdKtTp46uvfZa3XPPPTFfcQba4xqoWPyhDGdL9KFLi5kJymPx+4BFTnmy/OH2yUPSigCSXt1eYVTIdVukOg2tKBkQNtxbIBoc2qjv5JVzzHZyOGW4PWDa7z9KH94h5Sw2f0yb/tIJt0rth4SuXAEwHZT//PPP6tOnT5Wv9+3bVz///LMlhYp0FX+c/ygq8UruVpXG9Wpr996/+KH0Idwt0VRcCLshV5Q/3HKWSK+ea/74Oe28n0//WmrU3oKCAaHDvQWixaGN+k4VacPtgSr9+Xt5crY1c80f07CtdMIMqccpUlxAadXCwnRQXlRUpAYNGlT5enJysv78M7wT4p3s0B/n0jJDT3+S6zfAu2VMd1264Ct+KKsQzpZoKi7YLv1E73np27KkecPMH/9wL6+n3458TT/X6eHYHhzEJu4tgPCKpOH2gMeBfdK6eeXzwo1Sc8fE1ZaGz5COvlCqXSe05bNAQEuiFRUV+RxiJkmFhYUyDDJR+2I2wBvVM01z41z8UFYjnC3RVFxwlFZ9vIP0gl+lB3uYPvzw907X4Yc8/2LIPPUdfoZ15QOCxL0FEF6RMtweMaysVMp+ozwIL/zV/HEDLpWOuUqqF1ziZzu5DJO1XVxcnFzVZP41DEMul0ulpSZbL8KksLBQKSkpKigoqLY1PhyWZedVCvDSfAR4pWWG54eyab1EySX98WcJP5o2OfT74DuAY5UUSbNbB3/8+MekIydaVx5Uy0l1k52surfgegJAhMv9WPpglvTrOvPH9Di5PDlb0y6hK1cNBFI3me4p//DDD2tcsFh3aMtkfmGxdv1Zosb1EpRSJ0GlZYYn0HP3Bi/LztM1r3/tN4hHaEXCPDFAiclePemlBw8q/o4A/m4XX1r+cBv7oNT3fAsLCFTGvQUAxKjfN0grb5e+X2r+mLYDpeNvkdoPDl25bGK6pzxSObH13EyP+bLsPE3NyKo0B93dnzB3Yh8C8whFzzvCIXPTTk2Y97nXth8SJynJ9VdwJxx8hTRiVs0LBknOrJsiGdcTABzuzx3SR/dJa5/0v69bo/blGdLTT3ZkcjZ/QtJTXtF3333nNZwsPj5ePXqYn98Yq6oKtvMLijU1I0tzJ/bRiPRUzVqa4zMpnKHywHzW0hyNSE8lmLNBTYJqs1MYgJraUVR5ScZuJc95PX8tYaaOjvvR3Ak/faj84dbjZOn0Z4MuH+AL9xYAECUO7CsPwD+4TTLKzB0Tn1CeIf3oCyIiOZuVTAflH3/8sa666iqtW1c+zn/AgAHat2+fJwGLy+XSe++9p+HDh4empFGgtMwwFWwnJ9Wudo1zQ1JeQbHW5u5iWHWYVRVU3zImXY3qJVQbqJtpkCEwh1WaJ/tOnHWo0w/M1EsXDfj7d+SzR6X3bzL3Bt8tKn+4tegpXfKJVM38YKAi7i0AIEqUlUrfvlYehBduM3/cwMukIVdGZHI2K5kOyh9//HGde673mrkffvih2rVrJ8Mw9Mgjj2ju3LlUnNVYm7vLVLCduWmnqfP56glD6FQVVOcVFOvfC7K8tlXs/TbbIMPoB1ilX4fGSktJ8rsUY78Ojf/eOOiy8odb9kLp9Snm3vC3bGlWw7+fxydKN26T4msHUXrECu4tACCC/by6PDnbti/NH9PzVOm4G6WmnUNXrghkOij/4osvdNNN3j0orVu3Vrt27SRJ5557rsaMGWNt6aKM+SDa3DR/Mz1hThHp86irC6p9qdj7bbZBhtEPsIrZpRir/XfY85Tyh9uWz6T5o80VoLREur1Cq/cNv5YnpAP+h3sLAIggO34oT872w1vmj2k3uDw5W7uBoStXFDAdlP/6669KSUnxPH/uueeUmprqed64cWPt3GmuhzdWmQ2iB3ZsqjeytgXWw+Vg0TCP2l9QXVHF3m+zDTKMfoCVRvVM09yJfSr9+0sN9t9fu0Hea6X/vkF6rJ/54ysu2XbVD1KDyPgNQGhwbwEADlb0m/TxfdLap8wf06iDNHyG1H18RCZns4vpoDw5OVmbNm1SmzZtJEmnnHKK1+u5ublkPPXD7HDSAZ2a1LyHyyGiZR51MMHyob3fZhtkImn0AyLDoUsxWj5SpVlX7yC9KF+6v6v54x/o5v3832uk5t1874uoxL0FADjIgb3SmifK54WbVSupPDlb3/Ol2tzHBst0UN6/f389//zzOu6443y+/uyzz6p///5WlSsqBTKc1PIeLhtE0zzqmgTLO4qKNfaIloHP70VYRfoUi+rEx7nCMy0iOdU7SC8pqtw7Xp3HK9Qh//5cat7dmrLBkbi3AAAblZVK37xaPi+8KM/8cYMulwZfKdVjyqVVTAflV111lYYPH64mTZro2muvVfPmzSVJO3bs0Jw5c5SRkaH3338/ZAWNFlUF243q1dYd43t6Bdsh7eEKg2iaR+0e5RDIEHa35slJ1szvRchEwxQLR0pM9g7SS/+S5rSXDvxp7vjHB3g/P/99qS0BWjTh3gIAwuznVdKKWdL2LL+7ehx+unTcDVKTTiErVqxzGe51R0x4/PHHdeWVV+rgwYNq0KCBXC6XCgoKVKtWLd1///267LLL/J8kzAJZtD2c3vkmTzcvztauvQc820IVBNjVA7h4/TZNf3m93/0ePqu3xvduFfLy1NTsd3L05Ee5pvd3935/ct3xnutN8Oc8VU2xcP8LiZQpFhHJMKSnjpPy1gd3/FkvSd3+aWWJwsapdZMdrLi34HoCQBV2fC+tvCPA5GxDpBNupSG8hgKpmwIKyiXpl19+0euvv66ffvpJktSlSxeddtppnvlgTuPEijqcQYCdQWDmpp2aMO9zv/t5rZPsUFV9Z1Wp7ruM5mHSkaa0zNCQOSurHAHhq2EFIZb5mPTejcEdO+4R6ahJ1pYnRJxYN9mppvcWsXA9qTsAmFKUL310r7TuafPHNO5YPi+8+4kkZ7NQSIPySOO0ijqcQYDdPYDuz+pvHrXTAx5/35kkxbmkskM+JL3fkSGaGo6i1jevSQsvDO7YYTdJQ//P2vJYxGl1U6SL9uvJKCsAVfpzh/T21dL3S8wfU7tueU/4UVNIzhZCgdRNpueUwxrhmmfthCRr0TKP2sxyaGWGdMuY7mqanEgPRgRhqboIcMTp5Q+3TSulF042d+yHd5Y/JKlVX2nodVKXEZKLf5uIHNGyikksYVQDQurgASnjFGnzx4EdN3i6NPgKqS5JhZ2IoDzMwhUEOCXJWjRkkTf7XTRNToyIufH4G0vVRaBOx3snj9v+Vfm8dH+2fSEtOCS4b9ZdOvYaqcfJUly85cUErOCEBnYEhlENsJxhSMtvlT57JLDjDj9DGnZD+dB0OB5BeZiFKwhwUg9gpGeRJ3CLXu6s+ixVF8FaHukdpO/cJP2nj//jfv9eeuOC8ockpbSRjr1W6jVBqpUQmrICAXJKAzvMYVQDLPPt63/XT2bVaSxNeJnkbBGKoDzMwhUEOC2QDGadZKcM/yJwi17RMsUCh2jSqcJa6X9KX/xX+ug+qaSg6uMKfpGWTit/SFJSw/Igve8UKaFeSIsMVMVJDeyoHqMaUCNmR31VdPgZ0slPkpwtChCUh1m4goBIDySdNPyLwC26RcMUC1Qjsb40eFr5Q5IOlkjrX5RW3ysVba/6uOI90vs3lT8kyRUvHX+TNOQq5qQjbJzWwI6qMaoBASn6TXr4COlggA1qjTtJF6+SkqIvoWWsCzgob9SokVw+bkhcLpeSkpLUuXNnTZ48WVOmTLGkgNEoHEFAJAeSThz+ReAW3SJ9igUCUCtR6nt++UOSSg9K3y0qXz7mjw1VH2eUSh/cJu3ZKo17ODxljSHcW/gW6Q3ssYRRDajWwRLp+ZOkrZ8Ffuy0r5gXHgMCDspvvfVW3XnnnRo9erT69esnSVq7dq2WLVumSy+9VLm5uZo6daoOHjyoiy66yPICR4twBAHhDCStGmru5OFfBG7RLZgpFogC8bW8M7wbhvTjMmn1PdL2rMr7538b3vLFCO4tfIvkBvZYw6gGeDEM6b2bpM8fC/zY8xZLHY+zvEhwtoCD8k8++UR33HGHLrnkEq/tTz75pN5//3298cYbOuKII/TII4/EVMUZjHAEAWYCyZoG1FYONXf68C8CNyDKuVxS19HlD6n8xmrzJ+U96clp0jFX21u+KMW9RdUYqRUZGNUAff2KtOjiwI8bdbc0YKr15UFEcRmG4eu3o0r169fX+vXr1blzZ6/tGzduVO/evfXnn39q06ZNOuKII7R3715LCxuMQBZtj0U1DairGmruDukDHWq+eP02TX95vd/9Hj6rN8uPAYhY1E3eanpvEQvX0ynJT1E19z2R5HtUA9nXo8y2L6V5xwd+3BFnSSfNJTlbDAikbgr4r6Fx48ZaunRppe1Lly5V48blrX979+5VcnJyoKdGmLkrj4o90+6528uy86o93t9Qc6l8qHlpmfl2HycN/yotM5S5aacWr9+mzE07A/ocAADzuLfwzz1Sa3zvVhrYqQkBuQO5RzWkpnjfo6SmJBGQR4OifOm2ptLMlPKH2YC8aVfphl/LVwaZWSCdQrZ0VBbw8PVbbrlFU6dO1YcffuiZ97Vu3Tq98847euKJJyRJy5cv19ChQ60tKSxlxdztUAw1d8rwLydlfweAaMe9BaIF+WeiyF/F0nPjpF/XBn7stPVS4w6WFwnRK+Cg/KKLLlJ6eroeffRRLVy4UJLUtWtXrV69WoMGDZIkXX01c+6czoqAOhSZRp2Q1MaJ2d8BIJpxb4FoQv6ZCGUY0rIbpDVzAz/2vCVSRxoNEbyg1ikfPHiwBg8ebHVZEEZWBNShGmpuZ1IbJ2d/B4Boxr0FgLBb/5L05iX+96to9L1S/yCSugFVCCooLy0t1Ztvvqnvv/9ektSjRw+deOKJio+Pt7RwCB0rAup+HRortUGS8gt9B+41GWpu1/Avp2d/B4Boxb0FgJD79Qvp6RMCP673ROnE/zAXHCETcFC+ceNG/fOf/9S2bdvUtWtXSdLs2bPVpk0bvf322+rUqZPlhYT1rJi7vTwnX8UHS32+ZsVQczuGf4ViSD4AoHrcWwAIicLt0gPpks+73Wo0T5cuWC4l1g9JsYCKAm7umTZtmjp16qRffvlFWVlZysrK0tatW9WhQwdNmzbN0sKVlpbqlltuUYcOHVSnTh116tRJt99+uwJcxQ0+uOduS38H0G5mAmr3vOs9+/7y+XrDurUjcu61k7K/A0CsCOe9BYAo9lex9NSwvzOkP9BdpgPy6d/8nSH935kE5AirgHvKV69erc8//9yzRIkkNWnSRHfffbflc8HmzJmjuXPn6rnnnlOPHj30xRdfaMqUKUpJSaGStkCwc7erm3ftllgrTiPSUy0uceg5Jfs7AMSScN5bAIgihiG9e5209snAj530ltThGOvLBAQh4KA8MTFRRUVFlbb/+eefSkhIsKRQbp999pnGjx+vMWPGSJLat2+vl156SWvXBrE0AXwKZu62v3nXkpRfWBKR866dkP0dAGJNOO8tAES4rzKkxZcGftyY+6WjL7S+PIAFAg7Kx44dq4svvljPPPOMZy3RNWvW6JJLLtGJJ55oaeEGDRqkp556Sj/++KMOO+wwff311/rkk0/0wAMPVHlMSUmJSkpKPM8LCwstLVM0CnTudrTPu7Yz+zsAxKJA7y2o64EY8sta6ZkRgR/X5zxp3COSi44UOF/AQfkjjzyiSZMmaeDAgapdu7Yk6eDBgzrxxBP18MMPW1q466+/XoWFherWrZvi4+NVWlqqO++8U+ecc06Vx8yePVuzZs2ytBzwFgvzru3K/g4AsSjQewvqeiCKFW7/31zwALU4XLrgPSmhnvVlAkLMZQSZNe2nn37SDz/8IEnq3r27OnfubGnBJOnll1/Wtddeq3vvvVc9evTQ+vXrdcUVV+iBBx7QpEmTfB7jq/W8TZs2KigoUIMGDSwvYywqLTM0ZM5Kv/OuP7nueIJYAPChsLBQKSkp1E0VmL23oK4Hoshf+6X/jpLy1gd+7BXfSg3bWl4kwAqB1PVBB+Xh0KZNG11//fW69NK/543ccccdysjI8FTa/nDjExru7OuS73nXkZh5HQDChbrJWlxPIIIYhvT21dIXzwR+7JR3pXaDrC8TEAKB1E2mhq9fddVVpt+8uvnegdq3b5/i4rxXbYuPj1dZWZll74HgMO8aAFATdt1bALBB1vPSkssDP27sQ1LfKZYXB3AaU0H5V199ZepkLosTKYwbN0533nmn2rZtqx49euirr77SAw88oPPPP9/S90FwmHcNAAiWXfcWAMJg6+fSf0cGftxRU6SxD5KcDTHH0cPXi4qKdMstt2jRokXasWOHWrZsqQkTJujWW281vUQKQ9oAAE5D3WQtridgs4JfpQd7BH5c6hHS+e9JCXWtLxNgs6iZU24FKmoAgNNQN1mL6wmEWXGBdHeQCdauyJYatrG2PIADWT6nHAAAAECMKiuTbmsU3LHnvye1HWBteYAoQ1AO/E9pmcH8eAAAAEl68lgp7+vAjxv3sHTUZMuLA0QzgnJA5Uu8Vcwkn0YmeQAAECs+uk9aeXvgxx01uTxLOsnZgKARlCPmuddcr5hcIb+gWFMzsmxZc51eewCAE1AfRbEtn0nzRwd37P/lSnUbW1seIIYRlCOmlZYZmrU0p1JALkmGJJekWUtzNCI9NWw3IfTaAwCcgPooyuz9Q7q3U3DHXviB1LqvteUB4EFQjpi2NneX181GRYakvIJirc3dpYGdmoS8PE7stQcAxB7qoyhQk+Rs/7hTGnSZteUBUCWCclgq0oa57SiqOiAPZr+acGKvvS+R9h0DAAITKfURfHh8kLTju8CPazdYmvKO9eUBYApBOaoVSAAWicPcmicnWbpfTTit196XSPyOAQCBiYT6CP+zao606q7gjp2xh+RsgEMQlKNKgQRgkTrMrV+HxkpLSVJ+QbHPHgGXpNSU8saIUHNSr70vkfodAwAC4/T6KKblfiw9Nza4Y6/bItVpaGlxAFiDoBw+BRKARfIwt/g4l2aMS9fUjCy5JK/P4C7pjHHpYSm3k3rtK4rk7xgAEBgn10cx588d0n1dgjv2og+lVn2sLQ+AkCAoRyWBBmCRNsyt4pD8EempmjuxT6VRAalhHpbtpF77iiLtOwYABM/J9VHUKyuVbgvyuo6aIw24xNryAAgLgnJUEmgAFknD3Kobkv/JdcfbmsDMSb32FUXSdwwAqBkn10dR6dF+0h8bAj+uw1Bp0hLrywMg7AjKUUmgAVikDHOLhDnRo3qmOaLXvqJI+Y4BANZwan0UFVbeKX10T3DHkpwNiEoE5agk0AAsEoa5RdKc6FE90zQiPdVRy45FwncMALCWE+ujiJSzWHr1vOCOvX6rlJRibXkAOA5BOSoJNACLhGFukTYnOj7O5YhyuEXCdwwAsJ7T6qOIsHOT9J8gE6xdvFpq2dvS4gBwPoJyVBJMAOb0YW7Mia45p3/HAADYovSgdHuQDRf/vE/qd5G15QEQcQjK4VMwAZiTh7kxJ9oaTv6OAQAIm5lBDinvPEKa+Lq1ZQEQ8QjKUaVgAjCnDnNjTrR1nPodAwAQMq9Nkb5bGNyxJGcD4AdBOaoVLQEYc6IBAIBp2Qul16cEd+z/5Up1aeQHIk1pmWHbaFCCcsQM5kQDAACfdm+WHu4V3LHnvye1HWBpcQCE17LsvEoxQloYYwSCcsQU5kQDAACV/iXd3jS4Y4+7QTruemvLA8A2y7LzNDUjq9IU1/yCYk3NyNLciX1CHpgTlCPmRMuQfAAAEIBgk7M1bCdd8Y21ZQHgCKVlhmYtzfGZc8pQ+TTXWUtzNCI9NaSdeATlgE3snLcCAEDUe+Vc6fslwR1LcjYgJqzN3eU1ZL0iQ1JeQbHW5u4KaaceQTlgA7vnrQAAEHW+eU1aeGFwx16/VUoKsicdQMTaUVR1QB7MfsEiKAfCzAnzVgAAkYcRVhXs+ll65Mjgjr1ghdTmaGvLAyDiNE9OsnS/YBGUA2HklHkrAIDIwggrSQdLpDuaB3fs8TdLx15rbXkARLx+HRorLSVJ+QXFPu/PXSpfqalfh9Auc0hQDoSIrx4Np8xbAQBEjpgeYRVscrYmnaXLv7S2LACiTnycSzPGpWtqRpZcktfvrLt7bMa49JB3lhGUAyFQVY/GP3ummjo+1PNWAACRIeZGWN3eXCotCe7YmQXWlgVATBjVM01zJ/apdO+eyjrlQOSqrkfjmU83mzpHqOetwHrM9QSin1X/zgM5T9SPsPrgdunj+4I79vpfpKQG1pbHItQJQGQZ1TNNI9JTbft3S1AeBfjh9y9c18hfj4Ykxbkkw5Ct81ZgLeZ6AtHPqn/ngZ7HKZmBLbMtS5o3LLhjz3ld6jLC2vKEAHUCEJni41y2NW4SlEc4fvj9C+c18tejIUll/4vG7Zy3AuvE9FxPIEZY9e88mPM4JTNw0P7aL91pbupWJYefIZ06z9ryhBh1AoBgxNldAATP/cNfMQh0//Avy86zqWTOEe5rZLan4vzB7ZWa4n0DlZqSRGUdYcyMjJi1NEelZb72ABAJrPp3Hux53JmBq2qqdam8odlRI6xmpvz9CDQgn1nw9yPCAnLqBADBoqc8QsVc4pcg2HGNzPZUjEhP1U1j0kMypJ7pDOET9XM9AVj27zzY85jNDCxJmZt22vPbH2yGdCmqkrNRJwAIFkF5hOKH3z87rlEgax2GYt4K0xnCK+rmegKoxKp/5zU5j7/MwJI0ZM7K8P32v3+z9Nl/gjv2ui1SnYaWFscpqBMABIugPELxw++fHdfIzrUOmccWfhE/1xOAX1b9O6/pearKDLw8Jz/0v/1bMqX5o4I79txFUqfja/b+ESKW6gRG5QHWIiiPULH0wx8su66RHWsdMp3BHoGMjAAQmaz6d27FeSqOsArZb/+BvdJdLc3vf6je50gnPR7csREuVuoERuUB1iMoj1Cx8sNfE3Zeo3Cvdch0BnvYOTICQGj46gE0+++8ut7DUPxeWPrbz7zwGouFOoFReUBoEJRHqFj44a8pu69RONc6ZDqDfewYGQEgNKrrAfT379xM76HVvxc1+u0nCA+JaK4TGJUHhA5BeQSL5h9+q8TKNWI6g73CPTICgPXM9AB+ct3xPv+dV3fsJRlZunJ4F7VvWk/Nk5M0Ij3Vst+LgH77Xz5H+uGtgN9DknT9VimpBkF8jInWOoFReUDoEJRHuGj94bdSLFwjpjPYL5wjIwBYK5AewIr/zs2sTf3gip8826yce1vdb//QuK/1XMKc8icvBHjiGErOFirRWCcwKg8IHYLyKBCNP/xWi/ZrZPdQfQCIZDXpAfR3bEVWzr099Lc/Wfv0bdKFwZ3oiDOlU56qUVkQ/RiVB4QOQTlChuUywitWhuoDgNVq0gMYaK+gpXNvZ6ZolKTcYGIg5oUjQIzKA0KHoBwhwXIZ9oiFofoAYLWa9AAG0ysY9NxbkrPBRozKA0KHoByWY7kMe0X7UH0AsFpNegD9HVsdv73s88dIWz4J8Kz/c90WqU7D4I4FqsCoPCA0CMphKZbLiFxMNwAQq8z0AN4yJt3nb2R1x/pTqZf9m1elhRcF9yEmLpQ6nxDcsUAAGJUHWI+gHJZiuYzIxHSD2EVjDFCuuh7AE3ul6fa3q/6NrOrYqnh63lsYwQ9J73mqdNp/gzsWqCFG5QHWIiiHpVguI/Iw3SB20RgDePPVA7h7b4kuXfCV39/Iisdu/mOfHlrxoyTv3vPNSWeX/0+JpPsCLCDzwgEgKhGURzin9XKxXEZkYbpB7KIxBvDt0B7A0jJDQ+asNP0bWbH3sGtqfY16vVvwhSEIB4CYQFAewZzYy8VyGZGF6QaxicYYwJygfiPnnSBt+0KSNCrQN7x+q5RUgwzrAICIRFAeYqHqyXZqLxfLZUQWphvEJhpjAHPM/PYdH5elgS+cHdwbnPWS1O2fwR0LAIgaBOUhFKqebKf3crFcRuRgukFsojEGMMfXb1897dd3SRcEd8IuI6VzXq1hqQAA0YagPERC2ZMdCb1cLJcRGZhuEJtojAHMcf9GZpacEvxJmBcOAPCDoDwEQt2THSm9XCyX4XxMN4hNNMYAfuz4QXq8v+IlZQZ6LEE4ACBAcXYXIBoF0pMdDHq5YCX3dIPUFO+/l9SUJDJwRyl3Y4z0d+OLG40xiHlrnpIe7296967Fz6p98QINTFyoZaf9EMKCAQCiFT3lIRDqnmx6uWA1phvEHnI/AFXIfqPal0eW3K0NRttK2/NYThAAECSC8hAIdU82Q44RCkw3iD00xgA+jLhNeudqKf/b8ucTF6q04/EaMmdltaPg3FhOEAAQKILyEAhHTza9XACsQGMMUEHb/tIln3htWrtpp6mA3AmJVgEAkYegPATC1ZNNLxcAAKEX6HQzuxOtAgAiC0F5iISrJ9vqXq7SMoMgHwCAQwQ63YxEqwCAQBCUh1Ck9WQvy86r1IiQxnB4ABGKRkZYxT0tzd8QdhKtAgCCQVAeYpEyX3NZdp6mZmRVmgOfTzZZABGIRkZY6dBpab5yxRyKRKsAgECxTjlUWmZo1tIcnzca7m2zluaotMzfrQgA2M/dyFixV9PdyLgsO8+mkiGSuaelpaX4HpqelpJEAzYAICj0lENrc3dVOySPbLIAIoW/RkaXWLIKwTt0Wlp+wX7t2ntAjesnKrUB0yMAAMEjKIfpLLFkkwXgdDQyItQiZVoaACByEJTDdJZYsskCcDoaGYFyJDoEgMhBUA5PVtn8gmKfQz7JJgsgUtDICJDoEAAiDYne4MkqK5UH4IdyPyebLIBI4G5krOrXyqXy4IRGRkQrEh0CQOQhKIekv7PKplbIKptKNlkAEYRGRsQyVlMBgMjE8HV4HJpVljloACKVu5Gx4vDdVIbvIsqR6BAAIhNBObyQVRZANKCREbGIRIcAEJkIygEAUYlGRsQaEh0CQGQiKAcAAKhGpCwvxmoqABCZCMoBAACqEEnLi7kTHU7NyJJL8grMSXQIAM5F9nUAAAAfInF5MVZTAYDIQ085AABABf6WF3OpfHmxEempjut5JtEhAEQWgnIAAIAKIn15MRIdAkDkcPzw9fbt28vlclV6XHrppXYXDQAARCmWFwMAhIvje8rXrVun0tJSz/Ps7GyNGDFCp59+uo2lAgAA0YzlxQAA4eL4oLxZs2Zez++++2516tRJQ4cOtalEAAAg2rG8GAAgXBw/fP1QBw4cUEZGhs4//3y5XL6TlZSUlKiwsNDrAQAAokc46nr38mLS38uJubG8GADAShEVlL/55pvas2ePJk+eXOU+s2fPVkpKiufRpk2b8BUQAACEXLjqepYXAyorLTOUuWmnFq/fpsxNO1Va5mssCYBAuAzDiJh/SSNHjlRCQoKWLl1a5T4lJSUqKSnxPC8sLFSbNm1UUFCgBg0ahKOYAABUq7CwUCkpKdRNQQp3XV9aZrC8GCBpWXaeZi3N8VqZIC0lSTPGpdNIBVQQSF3v+Dnlblu2bNGKFSu0cOHCavdLTExUYmJimEoFAADCLdx1vR3Li9EQAKdZlp2nqRlZlXIs5BcUa2pGFqNHgBqImKB8/vz5at68ucaMGWN3UQAAQAwKV6BMbyScprTM0KylOT6THhoqz7Mwa2mORqSn0ngEBCEigvKysjLNnz9fkyZNUq1aEVFkAAAQRcIVKNMbCSdam7vL62+/IkNSXkGx1ubuCvuoEiAaRESitxUrVmjr1q06//zz7S4KAACIMe5AuWJQ4g6Ul2XnWfI+/nojpfLeSBJrIdx2FFUdkAezHwBvERGU/+Mf/5BhGDrssMPsLgoAAIgh4QyUA+mNBMKpeXKS/50C2A+AN8aCAwAAHKK0zNDnm3Yq8+c/tG33/rAN26U3Ek7Vr0NjpaUkKb+g2GcDlUvlSwX269A43EUDogJBOQAAwP8sy87T9Qu/1Z59fwV0nBWBMr2RcKr4OJdmjEvX1IwsuSSvwNyd1m3GuHSSvAFBiojh6wAAAKG2LDtPl2RkBRyQS9YEyu7eyKrCGpfKk8vRGwk7jOqZprkT+yg1xftvPTUliQSEQA3RUw4AAGJeaZmhmUu+C/g4K4ft0hsJpxvVM00j0lPDsjQgEEsIygEAQMxbm7tL+YUlAR0TikDZ3RtZcfm1VNYpj1jhWt8+XOLjXCx7BliMoBwAAMS8YOaEhypQpjcyeoRrfXsAkY2gHAAAxLxA5oRfNqyTBnduFtJAmd7IyOde375itnL3+vbMwwbgRlAOAABiXr8OjZXaINHvEPa0lCRdOaKrJxiPtqHJsIa/9e1dKl/ffkR6Kn8vAAjKAQAA4uNcmnliD12SkVXtfofOH7d6aDIBfvRYm7srbOvbA4h8BOUAAAAqn8v9xMQ+Ptcpb1S3tmafcrgn2LZ6aDJzj6OL2RwFVqxvDyDyEZQDAAD8jzvJ2uebdirz5z8klc/tHtCxideQdSuHJjP3OPqYzVFgxfr2ACIfQTkAAMAh4uNcGtylqQZ3aerzdSuHJjP3ODr169BYaSlJyi8o9vndWrm+PYDIF2d3AQAAACKJlUOTAwnwETni41yaMS5d0t/r2buFYn17AJGNoBwAACAAZoccN62XqMxNO7V4/TZlbtqp0rLKfaZOnXtcWmb4LTuqN6pnmuZO7KPUFO+/l9SUJKYkAPDC8HUAAIAAmBmanFK3tq5+7WvlF1afuM2Jc49JOmcdd44CsuoDqA495QAAAAHwNzTZkLRn319eAbn0d+K2Zdl5nm3uAL+qEM2l8oA4XHOP3UnnKg6p91V2mBMfV54scHzvVhrYqQkBOYBKCMoBAAACVNXQ5BYNEtWwbm2fx7h71WctzfEMB3fS3GN/Seck77IDAKzB8HUgSpSWGQyPA0KEf1/wxdfQ5DLD0DlPr6nyGF+Z2d0BfsUh46lhHjJuZVZ5AIB5BOVAFGD+HxA6/PuKTWYbYtxDk90Wr99m6vyfbvzD65xOmHvs1KRzABDtCMqBCOee/1dxMKF7/h8ZXoHg8e8rNtWkIcZsQrZHP9yoBWu36OTerTQ8PdUTgNvZA+3EpHMAEAuYUw5EMOb/AaHDv6/YVNNEZ/4Stx1q196/9MynmzVh3ucaMmel7UnUnJZ0DgBiBUE5EMECmf8HIDD8+4o9VjTEVJe4rTpOyG7upKRzABBLCMqBCMb8PyB0+PcVe6xqiKkqM3t1nDL6oqqyp6YkMV0DAEKEOeVABGP+HxA6/PuKPVY2xLgTtz24fIMe/XCTqfM6Jbu5E5LOAUAsoacciGDM/wNCh39fscfqhpj4OJcGd24WcDmcMPrCnXRufO9WGtipCQE5AIQQQTkQwZj/B4QO/75iTygaYgJJ/ObG6AsAiC0E5UCEY/4fEDr8+4otoWiIOfSc/jD6AgBik8swjKhey6WwsFApKSkqKChQgwYN7C4OEDKlZQbz/4AQsfrfF3WTtay+nr7WKW9cr7buGN9T/zyipWXnPJT7r4nGHgCIDoHUTSR6A6KEe/4fAOvx7yu2jOqZprIy6ebF2dq194Ck8jXFb3/7e8XFuYIKmg9NnrY8J19vrt/uObdUPvpixrh0AnIAiEH0lAMAEGbUTdYKRU/51IysSuuVW9mbzegmAIhu9JQDAAAEobTM0KylOZUCcql8yTKXytcSH5GeWqMgmtEXAAA3Er0BAAD8z9rcXVXO+5a81xI3o7TMUOamnVq8fpsyN+1UaVlUD1AEAASBnnIAAID/MbtGuJn9fCV3S2PuOACgAnrKAQAA/sfsGuH+9nPPS6/Y655fUKypGVlalp0XdBkBANGFoBwAAOB/+nVorLSUpErrlLuZWUvc37x0qXxeOkPZAQASQTkAAIBHfJxLM8alS1KlwNz9fMa49GqTvFk9Lx0AEN0IygEAAA4xqmea5k7so9QU7yHqqSlJppZDs3JeejQjCR4AlCPRGwAAQAWjeqZpRHpqUGuJWzUvPZqRBA8A/kZQDgAA4EOwa4m756XnFxT7nFfuUnmve3Xz0qOZOwlexWvjToLnbzRCaZkRVGMJADgVQTkAAICF3PPSp2ZkySV5BZ9m56VHK39J8FwqT4I3Ij3V5/Whhx1ANGJOOQAAgMVqOi89WtUkCR7LzAGIVvSUAwAAhEBN5qVHq2CT4NW0hx0AnIygHAAAIESCnZcerYJNghdIDzvXG0CkISgHAACIIXYmSgs2CR7LzAGIZgTlAAAAMcLuRGnBJsFjmTkA0YxEbwAAADHAKYnSgkmC5+5hr6o/36XyxoVYXWYOQGSjpxwAACDKOS1RWqBJ8FhmDkA0o6ccAAAgytVkKbJQcSfBG9+7lQZ2auI3oGaZOQDRip5yAACAKBctidJYZg5ANCIoBwAAiHLRlCiNZeYARBuGrwMAAEQ5EqUBgHMRlAMAAEQ5d6I0SZUCcxKlAYC9CMoBAABiAInSAMCZmFMOAADgQKVlhuUJzUiUBgDOQ1AOAADgMMuy8zRraY7XMmZpKUmaMS69xj3aJEoDAGdh+DoAAICDLMvO09SMrErriucXFGtqRpaWZefZVDIAQCgQlAMAADhEaZmhWUtzZPh4zb1t1tIclZb52gMAEIkIygEAABxibe6uSj3khzIk5RUUa23urvAVCgAQUgTlAAAADrGjqOqAPJj9AADOR1AOAADgEM2Tk/zvFMB+AADnIygHAABwiH4dGistJUlVLVDmUnkW9n4dGoezWACAECIoBwAAcIj4OJdmjEuXpEqBufv5jHHprCsOAFGEoBwAAMBBRvVM09yJfZSa4j1EPTUlSXMn9qnxOuUAAGepZXcBAAAAYkVpmaG1ubu0o6hYzZPLh6H76vUe1TNNI9JTTe0LAIhsBOUAAABhsCw7T7OW5ngteZaWkqQZ49J99n7Hx7k0sFOTcBYRAGADhq8DAACE2LLsPE3NyKq0Bnl+QbGmZmRpWXaeTSUDANiNoBwAACCESssMzVqaI8PHa+5ts5bmqLTM1x4AgGhHUA4AABBCa3N3VeohP5QhKa+gWGtzd4WvUAAAxyAoBwAACKEdRVUH5MHsBwCILgTlAAAAIdQ8Ocn/TgHsBwCILgTlAAAAIdSvQ2OlpSSpqsXMXCrPwt6vQ+NwFgsA4BAE5QAAACEUH+fSjHHpklQpMHc/nzEunTXIASBGEZQDAACE2KieaZo7sY9SU7yHqKemJGnuxD4+1ykHAMSGWnYXAAAAIBaM6pmmEempWpu7SzuKitU8uXzIOj3kABDbCMoBAADCJD7OpYGdmthdDACAgzB8HQAAAAAAmzg+KN+2bZsmTpyoJk2aqE6dOjr88MP1xRdf2F0sAAAAAABqzNHD13fv3q3Bgwdr2LBhevfdd9WsWTP99NNPatSokd1FAwAAAACgxhwdlM+ZM0dt2rTR/PnzPds6dOhgY4kAAAAAALCOo4evL1myRH379tXpp5+u5s2b68gjj9S8efOqPaakpESFhYVeDwAAED2o6wEA0cTRQfnPP/+suXPnqkuXLnrvvfc0depUTZs2Tc8991yVx8yePVspKSmeR5s2bcJYYgAAEGrU9QCAaOIyDMOwuxBVSUhIUN++ffXZZ595tk2bNk3r1q1TZmamz2NKSkpUUlLieV5YWKg2bdqooKBADRo0CHmZAQDwp7CwUCkpKdRNQaKuBwA4XSB1vaPnlKelpSk9Pd1rW/fu3fXGG29UeUxiYqISExNDXTQAAGAT6noAQDRx9PD1wYMHa8OGDV7bfvzxR7Vr186mEgEAAAAAYB1HB+VXXnmlPv/8c911113auHGjFixYoKeeekqXXnqp3UUDAAAAAKDGHB2UH3300Vq0aJFeeukl9ezZU7fffrseeughnXPOOXYXDQAAAACAGnP0nHJJGjt2rMaOHWt3MQAAAAAAsJyje8oBAAAAAIhmju8pryn3im+FhYU2lwQAgHLuOsnBq5JGFOp6AIDTBFLXR31QXlRUJElq06aNzSUBAMBbUVGRUlJS7C5GxKOuBwA4lZm63mVEeTN9WVmZtm/fruTkZLlcrhqdq7CwUG3atNEvv/zidwH4WMU1qh7Xxz+uUfW4Pv5FwjUyDENFRUVq2bKl4uKYSVZTgdb1kfA3Eiqx/Nml2P78fHY+O589vAKp66O+pzwuLk6tW7e29JwNGjSIuT/qQHGNqsf18Y9rVD2uj39Ov0b0kFsn2Lre6X8joRTLn12K7c/PZ+ezxxo7P7vZup7meQAAAAAAbEJQDgAAAACATQjKA5CYmKgZM2YoMTHR7qI4Fteoelwf/7hG1eP6+Mc1gj+x/DcSy59diu3Pz2fns8eaSPrsUZ/oDQAAAAAAp6KnHAAAAAAAmxCUAwAAAABgE4JyAAAAAABsQlAOAAAAAIBNCMoD8Nhjj6l9+/ZKSkpS//79tXbtWruL5BgfffSRxo0bp5YtW8rlcunNN9+0u0iOMnv2bB199NFKTk5W8+bNddJJJ2nDhg12F8sx5s6dqyOOOEINGjRQgwYNNHDgQL377rt2F8ux7r77brlcLl1xxRV2F8UxZs6cKZfL5fXo1q2b3cWCQ8VqfR6rdXUs18HUr3+Ltboz1uvFbdu2aeLEiWrSpInq1Kmjww8/XF988YXdxaoSQblJr7zyiq666irNmDFDWVlZ6tWrl0aOHKkdO3bYXTRH2Lt3r3r16qXHHnvM7qI40urVq3XppZfq888/1/Lly/XXX3/pH//4h/bu3Wt30RyhdevWuvvuu/Xll1/qiy++0PHHH6/x48fru+++s7tojrNu3To9+eSTOuKII+wuiuP06NFDeXl5nscnn3xid5HgQLFcn8dqXR3LdTD1a7lYrTtjtV7cvXu3Bg8erNq1a+vdd99VTk6O7r//fjVq1MjuolXNgCn9+vUzLr30Us/z0tJSo2XLlsbs2bNtLJUzSTIWLVpkdzEcbceOHYYkY/Xq1XYXxbEaNWpkPP3003YXw1GKioqMLl26GMuXLzeGDh1qTJ8+3e4iOcaMGTOMXr162V0MRADq83KxXFfHeh0ca/VrrNadsVwvXnfddcaQIUPsLkZA6Ck34cCBA/ryyy81fPhwz7a4uDgNHz5cmZmZNpYMkaqgoECS1LhxY5tL4jylpaV6+eWXtXfvXg0cONDu4jjKpZdeqjFjxnj9FuFvP/30k1q2bKmOHTvqnHPO0datW+0uEhyG+hxS7NbBsVq/xnLdGav14pIlS9S3b1+dfvrpat68uY488kjNmzfP7mJVq5bdBYgEf/zxh0pLS9WiRQuv7S1atNAPP/xgU6kQqcrKynTFFVdo8ODB6tmzp93FcYxvv/1WAwcOVHFxserXr69FixYpPT3d7mI5xssvv6ysrCytW7fO7qI4Uv/+/fXss8+qa9euysvL06xZs3TMMccoOztbycnJdhcPDkF9jlisg2O5fo3lujOW68Wff/5Zc+fO1VVXXaUbb7xR69at07Rp05SQkKBJkybZXTyfCMqBMLv00kuVnZ0dM/N6zOratavWr1+vgoICvf7665o0aZJWr14dMzcO1fnll180ffp0LV++XElJSXYXx5FGjx7t+f8jjjhC/fv3V7t27fTqq6/qggsusLFkAJwkFuvgWK1fY73ujOV6saysTH379tVdd90lSTryyCOVnZ2tJ554wrFBOcPXTWjatKni4+P122+/eW3/7bfflJqaalOpEIkuu+wyvfXWW/rwww/VunVru4vjKAkJCercubOOOuoozZ49W7169dLDDz9sd7Ec4csvv9SOHTvUp08f1apVS7Vq1dLq1av1yCOPqFatWiotLbW7iI7TsGFDHXbYYdq4caPdRYGDUJ/Htlitg2O1fqXu9BZL9WJaWlqlRqfu3bs7evg+QbkJCQkJOuqoo/TBBx94tpWVlemDDz6IqTk5CJ5hGLrsssu0aNEirVy5Uh06dLC7SI5XVlamkpISu4vhCCeccIK+/fZbrV+/3vPo27evzjnnHK1fv17x8fF2F9Fx/vzzT23atElpaWl2FwUOQn0em6iDvcVK/Urd6S2W6sXBgwdXWvbwxx9/VLt27WwqkX8MXzfpqquu0qRJk9S3b1/169dPDz30kPbu3aspU6bYXTRH+PPPP71a3nJzc7V+/Xo1btxYbdu2tbFkznDppZdqwYIFWrx4sZKTk5Wfny9JSklJUZ06dWwunf1uuOEGjR49Wm3btlVRUZEWLFigVatW6b333rO7aI6QnJxcae5jvXr11KRJk5iZE+nPNddco3Hjxqldu3bavn27ZsyYofj4eE2YMMHuosFhYrk+j9W6Opbr4FiuX2O97ozlevHKK6/UoEGDdNddd+mMM87Q2rVr9dRTT+mpp56yu2hVszv9eyT5z3/+Y7Rt29ZISEgw+vXrZ3z++ed2F8kxPvzwQ0NSpcekSZPsLpoj+Lo2koz58+fbXTRHOP/884127doZCQkJRrNmzYwTTjjBeP/99+0ulqPF0rIuZpx55plGWlqakZCQYLRq1co488wzjY0bN9pdLDhUrNbnsVpXx3IdTP3qLZbqzlivF5cuXWr07NnTSExMNLp162Y89dRTdhepWi7DMIxwNgIAAAAAAIByzCkHAAAAAMAmBOUAAAAAANiEoBwAAAAAAJsQlAMAAAAAYBOCcgAAAAAAbEJQDgAAAACATQjKAQAAAACwCUE5EAKrVq2Sy+XSnj177C6KrWLlOhx33HG64oor7C4GACCMYqWO8ydWrgN1PUKJoBxRw+VyVfuYOXNm2MoyaNAg5eXlKSUlJehzbN682av8jRs31tChQ/Xxxx9bWNLY9uyzz6phw4Z2FwMAYBJ1PQJFXY9IQFCOqJGXl+d5PPTQQ2rQoIHXtmuuuSag8/3111+Vth04cMDUsQkJCUpNTZXL5QroPX1ZsWKF8vLy9NFHH6lly5YaO3asfvvttxqfFwCASENdDyAaEZQjaqSmpnoeKSkpcrlcXttefvllde/eXUlJSerWrZsef/xxz7HulupXXnlFQ4cOVVJSkl588UVNnjxZJ510ku688061bNlSXbt2lSS98MIL6tu3r5KTk5Wamqqzzz5bO3bs8Jyv4lAudyvte++9p+7du6t+/foaNWqU8vLy/H6uJk2aKDU1VT179tSNN96owsJCrVmzxuu8h3rzzTcr3SDccccdat68uZKTk3XhhRfq+uuvV+/evT2vHzx4UNOmTVPDhg3VpEkTXXfddZo0aZJOOukkzz5lZWWaPXu2OnTooDp16qhXr156/fXXvd7nnXfe0WGHHaY6depo2LBh2rx5s9frZq/D008/XeV3deDAAV122WVKS0tTUlKS2rVrp9mzZ0uSDMPQzJkz1bZtWyUmJqply5aaNm2a32vsNnPmTPXu3VsvvPCC2rdvr5SUFJ111lkqKiry7LN3716dd955ql+/vtLS0nT//fdXOk9JSYmuueYatWrVSvXq1VP//v21atUqSVJxcbF69Oihiy++2LP/pk2blJycrP/+97+mywoAsYi6vhx1PXU9oowBRKH58+cbKSkpnucZGRlGWlqa8cYbbxg///yz8cYbbxiNGzc2nn32WcMwDCM3N9eQZLRv396zz/bt241JkyYZ9evXN84991wjOzvbyM7ONgzDMJ555hnjnXfeMTZt2mRkZmYaAwcONEaPHu15vw8//NCQZOzevdtTntq1axvDhw831q1bZ3z55ZdG9+7djbPPPrvKz+Au01dffWUYhmHs27fPuOaaawxJxrvvvuvzcxqGYSxatMg49J92RkaGkZSUZPz3v/81NmzYYMyaNcto0KCB0atXL88+d9xxh9G4cWNj4cKFxvfff29ccsklRoMGDYzx48d77dOtWzdj2bJlxqZNm4z58+cbiYmJxqpVqwzDMIytW7caiYmJxlVXXWX88MMPRkZGhtGiRYuAr4O/7+ree+812rRpY3z00UfG5s2bjY8//thYsGCBYRiG8dprrxkNGjQw3nnnHWPLli3GmjVrjKeeeqrKa1zx+s2YMcOoX7++ccoppxjffvut8dFHHxmpqanGjTfe6Nln6tSpRtu2bY0VK1YY33zzjTF27FgjOTnZmD59umefCy+80Bg0aJDx0UcfGRs3bjTuvfdeIzEx0fjxxx8NwzCMr776ykhISDDefPNN4+DBg8aAAQOMk08+ucpyAgAqo66nrqeuR7QgKEdUqvgD3KlTJ8+Pudvtt99uDBw40DCMvyvFhx56yGufSZMmGS1atDBKSkqqfb9169YZkoyioiLDMHxX1JKMjRs3eo557LHHjBYtWlR5TneZ6tSpY9SrV89wuVyGJOOoo44yDhw44PNzGkblirp///7GpZde6rXP4MGDvSrqFi1aGPfee6/n+cGDB422bdt6Kuri4mKjbt26xmeffeZ1ngsuuMCYMGGCYRiGccMNNxjp6eler1933XUBXwd/39Xll19uHH/88UZZWVmla3b//fcbhx12mOf6+OOroq5bt65RWFjo2Xbttdca/fv3NwzDMIqKioyEhATj1Vdf9by+c+dOo06dOp6KesuWLUZ8fLyxbds2r/c64YQTjBtuuMHz/J577jGaNm1qXHbZZUZaWprxxx9/mCozAKAcdT11vRnU9YgEDF9H1Nu7d682bdqkCy64QPXr1/c87rjjDm3atMlr3759+1Y6/vDDD1dCQoLXti+//FLjxo1T27ZtlZycrKFDh0qStm7dWmU56tatq06dOnmep6WleQ2Dq8orr7yir776Sm+88YY6d+6sZ599VrVr1/Z7nNuGDRvUr18/r22HPi8oKNBvv/3mtS0+Pl5HHXWU5/nGjRu1b98+jRgxwusaPv/8855r+P3336t///5e7zNw4MBK5anuOpj5riZPnqz169era9eumjZtmt5//33PuU4//XTt379fHTt21EUXXaRFixbp4MGDpq+VJLVv317Jyck+y7dp0yYdOHDA63M2btzYM9RRkr799luVlpbqsMMO8/oMq1ev9vp7u/rqq3XYYYfp0Ucf1X//+181adIkoHICAP5GXU9dHwjqejhNLbsLAITan3/+KUmaN29epYokPj7e63m9evUqHV9x2969ezVy5EiNHDlSL774opo1a6atW7dq5MiR1SaHqVi5ulwuGYbht/xt2rRRly5d1KVLFx08eFAnn3yysrOzlZiYqLi4uErn8JW0pqbc1/Dtt99Wq1atvF5LTEwM6FzVXQcz31WfPn2Um5urd999VytWrNAZZ5yh4cOH6/XXX1ebNm20YcMGrVixQsuXL9e///1v3XvvvVq9erXpmxtf5SsrKzP9+f7880/Fx8fryy+/rPT3Vb9+fc//79ixQz/++KPi4+P1008/adSoUabfAwDgjbq+5qjrqethH3rKEfVatGihli1b6ueff1bnzp29Hh06dAj4fD/88IN27typu+++W8ccc4y6detmqhXcCqeddppq1arlSYbSrFkzFRUVae/evZ591q9f73VM165dtW7dOq9thz5PSUlRixYtvLaVlpYqKyvL8zw9PV2JiYnaunVrpWvYpk0bSVL37t3/v737CYlqDeM4/puJUUbt1GB/xoUpNaRCiQOBFpWLCoNq0cKKjAQtclEDhiYUhBURBVZkQRjERP+QsEVhizCoRahhLUZBaBNtCxpCiiyc5646985Vb3av3YP2/ewO5z3vPOc9i4dnzjnP0YsXL9J+p6+v76fOb6rXynEc7dy5U9euXVNnZ6e6urr04cMHSVIwGNS2bdt06dIlPX36VL29vRocHPypOCazbNkyBQIBt/mOJCWTSb1+/drdjkajGhsb07t378adQzgcdsfV1dVp5cqVunHjhlpaWjQ8PDwtMQLA74hcT64n12Mm4045fgsnTpxQLBbTvHnztHnzZo2OjmpgYEDJZFKHDx/+qbmWLFmijIwMtbe3q6GhQUNDQzp16tQvijydz+dTLBZTa2urDhw4oPLycmVlZeno0aOKxWLq7+9XPB5PO+bQoUPav3+/Vq1apTVr1qizs1OJREJLly5NG3PmzBlFIhEVFxervb1dyWTS7ew6d+5cNTU1qbGxUalUSmvXrtXHjx/1/PlzOY6j2tpaNTQ0qK2tTc3Nzdq3b59evnw5Lpap+NG1On/+vPLy8hSNRuX3+3Xv3j2Fw2HNnz9f8XhcY2Nj7rrcunVLwWBQBQUF/2XZXTk5Oaqvr1dzc7Nyc3O1aNEiHTt2TH7/n/9vLl++XDU1Ndq7d6/a2toUjUb1/v17PXnyRKWlpdqyZYuuXLmi3t5eJRIJ5efnq7u7WzU1Nerr6xv3+CQAYGrI9eT66UCuhye8fKEd+FUmaopy+/ZtKysrs4yMDAuFQrZ+/Xq7f/++mY3vfvpdbW1tWlfS7+7cuWOFhYWWmZlpq1evtgcPHqQdP1Hzlx81afm7yWL69OmThUIhO3v2rDtPJBKxYDBoW7dutY6OjnHznjx50hYsWGA5OTlWV1dnsVjMKioq3P3fvn2zgwcPmuM4FgqFrKWlxaqrq23Xrl3umFQqZRcvXrSioiILBAK2cOFCq6qqsmfPnrljHj58aJFIxDIzM23dunV2/fr1f7UO/3StOjo6rKyszLKzs81xHNuwYYO9evXKnau8vNwcx7Hs7GyrqKiwnp6eSdd4ouYvf22KY2Z24cIFKygocLdHRkZsz549lpWVZYsXL7Zz585ZZWVlWkfWr1+/2vHjx62wsNACgYDl5eXZ9u3bLZFI2PDwsAWDwbQGN8lk0vLz8+3IkSOTxgoASEeuJ9eT6zFb+Mym8KILgFll06ZNCofDunnz5oT7U6mUSkpKtGPHjv/tzgAAAJg+5Hpg5uDxdWCW+/z5s65evaqqqirNmTNHd+/edZujfPf27Vs9fvxYlZWVGh0d1eXLl/XmzRvt3r3bw8gBAMBUkOuBmY2iHJjlfD6fHj16pNOnT+vLly8qKipSV1eXNm7c6I7x+/2Kx+NqamqSmWnFihXq6elRSUmJh5EDAICpINcDMxuPrwMAAAAA4BE+iQYAAAAAgEcoygEAAAAA8AhFOQAAAAAAHqEoBwAAAADAIxTlAAAAAAB4hKIcAAAAAACPUJQDAAAAAOARinIAAAAAADxCUQ4AAAAAgEf+AHVGft8TyC6HAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fit = df.copy()\n", + "fit[\"mean\"] = linear_reg_model(x_data).detach().cpu().numpy()\n", + "\n", + "fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)\n", + "african_nations = fit[fit[\"cont_africa\"] == 1]\n", + "non_african_nations = fit[fit[\"cont_africa\"] == 0]\n", + "fig.suptitle(\"Regression Fit\", fontsize=16)\n", + "ax[0].plot(non_african_nations[\"rugged\"], non_african_nations[\"rgdppc_2000\"], \"o\")\n", + "ax[0].plot(non_african_nations[\"rugged\"], non_african_nations[\"mean\"], linewidth=2)\n", + "ax[0].set(xlabel=\"Terrain Ruggedness Index\", ylabel=\"log GDP (2000)\", title=\"Non African Nations\")\n", + "ax[1].plot(african_nations[\"rugged\"], african_nations[\"rgdppc_2000\"], \"o\")\n", + "ax[1].plot(african_nations[\"rugged\"], african_nations[\"mean\"], linewidth=2)\n", + "ax[1].set(xlabel=\"Terrain Ruggedness Index\", ylabel=\"log GDP (2000)\", title=\"African Nations\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We notice that the relationship between terrain ruggedness has an inverse relationship with GDP for non-African nations, but it positively affects the GDP for African nations. It is however unclear how robust this trend is. In particular, we would like to understand how the regression fit would vary due to parameter uncertainty. To address this, we will build a simple Bayesian model for linear regression. [Bayesian modeling](http://mlg.eng.cam.ac.uk/zoubin/papers/NatureReprint15.pdf) offers a systematic framework for reasoning about model uncertainty. Instead of just learning point estimates, we're going to learn a _distribution_ over parameters that are consistent with the observed data." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Bayesian Regression with Pyro's Stochastic Variational Inference (SVI)\n", + "\n", + "### Model\n", + "\n", + "In order to make our linear regression Bayesian, we need to put priors on the parameters $w$ and $b$. These are distributions that represent our prior belief about reasonable values for $w$ and $b$ (before observing any data).\n", + "\n", + "Making a Bayesian model for linear regression is very intuitive using `PyroModule` as earlier. Note the following:\n", + "\n", + " - The `BayesianRegression` module internally uses the same `PyroModule[nn.Linear]` module. However, note that we replace the `weight` and the `bias` of the this module with `PyroSample` statements. These statements allow us to place a prior over the `weight` and `bias` parameters, instead of treating them as fixed learnable parameters. For the bias component, we set a reasonably wide prior since it is likely to be substantially above 0.\n", + " - The `BayesianRegression.forward` method specifies the generative process. We generate the mean value of the response by calling the `linear` module (which, as you saw, samples the `weight` and `bias` parameters from the prior and returns a value for the mean response). Finally we use the `obs` argument to the `pyro.sample` statement to condition on the observed data `y_data` with a learned observation noise `sigma`. The model returns the regression line given by the variable `mean`." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "from pyro.nn import PyroSample\n", + "\n", + "class BatchedLinear(torch.nn.Linear):\n", + " def forward(self, x):\n", + " return torch.einsum(\"...ij,...nj->...ni\", self.weight, x) + self.bias\n", + "\n", + "class BayesianRegression(PyroModule):\n", + " def __init__(self, in_features, out_features):\n", + " super().__init__()\n", + " self.linear = PyroModule[BatchedLinear](in_features, out_features)\n", + " self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))\n", + " self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))\n", + " \n", + " @PyroSample\n", + " def sigma(self):\n", + " return dist.Uniform(0., 10.)\n", + " \n", + " def forward(self, x, y=None):\n", + " sigma = self.sigma\n", + " mean = self.linear(x).squeeze(-1)\n", + " with pyro.plate(\"data\", x.shape[0]):\n", + " obs = pyro.sample(\"obs\", dist.Normal(mean, sigma), obs=y)\n", + " return mean" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Using an AutoGuide\n", + "\n", + "In order to do inference, i.e. learn the posterior distribution over our unobserved parameters, we will use Stochastic Variational Inference (SVI). The guide determines a family of distributions, and `SVI` aims to find an approximate posterior distribution from this family that has the lowest KL divergence from the true posterior. \n", + "\n", + "Users can write arbitrarily flexible custom guides in Pyro, but in this tutorial, we will restrict ourselves to Pyro's [autoguide library](http://docs.pyro.ai/en/dev/infer.autoguide.html). In the next [tutorial](bayesian_regression_ii.ipynb), we will explore how to write guides by hand.\n", + "\n", + "To begin with, we will use the `AutoDiagonalNormal` guide that models the distribution of unobserved parameters in the model as a Gaussian with diagonal covariance, i.e. it assumes that there is no correlation amongst the latent variables (quite a strong modeling assumption as we shall see in [Part II](bayesian_regression_ii.ipynb)). Under the hood, this defines a `guide` that uses a `Normal` distribution with learnable parameters corresponding to each `sample` statement in the model. e.g. in our case, this distribution should have a size of `(5,)` correspoding to the 3 regression coefficients for each of the terms, and 1 component contributed each by the intercept term and `sigma` in the model. \n", + "\n", + "Autoguide also supports learning MAP estimates with `AutoDelta` or composing guides with `AutoGuideList` (see the [docs](http://docs.pyro.ai/en/dev/infer.autoguide.html) for more information)." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "from pyro.infer.autoguide import AutoDiagonalNormal\n", + "\n", + "model = BayesianRegression(3, 1)\n", + "guide = AutoDiagonalNormal(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Optimizing the Evidence Lower Bound\n", + "\n", + "We will use stochastic variational inference (SVI) (for an introduction to SVI, see [SVI Part I](svi_part_i.ipynb)) for doing inference. Just like in the non-Bayesian linear regression model, each iteration of our training loop will take a gradient step, with the difference that in this case, we'll use the Evidence Lower Bound (ELBO) objective computed by `pyro.infer.Trace_ELBO` instead of the PyTorch MSE loss. " + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[iteration 0050] loss: 6350.9521\n", + "[iteration 0100] loss: 5906.7910\n", + "[iteration 0150] loss: 5671.1948\n", + "[iteration 0200] loss: 5489.0088\n", + "[iteration 0250] loss: 5285.6274\n", + "[iteration 0300] loss: 5026.8740\n", + "[iteration 0350] loss: 4727.3896\n", + "[iteration 0400] loss: 4390.3813\n", + "[iteration 0450] loss: 4010.3523\n", + "[iteration 0500] loss: 3485.6401\n", + "[iteration 0550] loss: 2713.5918\n", + "[iteration 0600] loss: 2320.7119\n", + "[iteration 0650] loss: 2318.7930\n", + "[iteration 0700] loss: 2319.1804\n", + "[iteration 0750] loss: 2318.4934\n", + "[iteration 0800] loss: 2318.6567\n", + "[iteration 0850] loss: 2318.9456\n", + "[iteration 0900] loss: 2318.7878\n", + "[iteration 0950] loss: 2318.5378\n", + "[iteration 1000] loss: 2318.7805\n", + "[iteration 1050] loss: 2318.7441\n", + "[iteration 1100] loss: 2320.6340\n", + "[iteration 1150] loss: 2318.8325\n", + "[iteration 1200] loss: 2319.2222\n", + "[iteration 1250] loss: 2318.6689\n", + "[iteration 1300] loss: 2318.8411\n", + "[iteration 1350] loss: 2318.6680\n", + "[iteration 1400] loss: 2318.4939\n", + "[iteration 1450] loss: 2319.4907\n", + "[iteration 1500] loss: 2319.1812\n" + ] + } + ], + "source": [ + "from pyro.infer import Trace_ELBO\n", + "\n", + "elbo = Trace_ELBO(num_particles=10, vectorize_particles=True)(model, guide)\n", + "\n", + "# initialize guide\n", + "guide(x_data, y_data);\n", + "\n", + "# Define loss and optimizer\n", + "optim = torch.optim.Adam(guide.parameters(), lr=0.03)\n", + "\n", + "# optimize\n", + "for j in range(1500 if not smoke_test else 2):\n", + " loss = elbo(x_data, y_data)\n", + " optim.zero_grad()\n", + " loss.backward()\n", + " optim.step()\n", + " if (j + 1) % 50 == 0:\n", + " print(\"[iteration %04d] loss: %.4f\" % (j + 1, loss.item()))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can examine the optimized parameter values by fetching from Pyro's param store." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loc Parameter containing:\n", + "tensor([-2.2698, -1.9349, -0.2027, 0.3860, 9.2165])\n", + "scale_unconstrained Parameter containing:\n", + "tensor([-3.9453, -3.1433, -4.3532, -3.7626, -3.7387])\n" + ] + } + ], + "source": [ + "guide.requires_grad_(False)\n", + "\n", + "for name, value in guide.named_parameters():\n", + " print(name, value)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As you can see, instead of just point estimates, we now have uncertainty estimates (`AutoDiagonalNormal.scale`) for our learned parameters. Note that Autoguide packs the latent variables into a single tensor, in this case, one entry per variable sampled in our model. Both the `loc` and `scale` parameters have size `(5,)`, one for each of the latent variables in the model, as we had remarked earlier.\n", + "\n", + "To look at the distribution of the latent parameters more clearly, we can make use of the `AutoDiagonalNormal.quantiles` method which will unpack the latent samples from the autoguide, and automatically constrain them to the site's support (e.g. the variable `sigma` must lie in `(0, 10)`). We see that the median values for the parameters are quite close to the Maximum Likelihood point estimates we obtained from our first model." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'sigma': tensor([0.9256, 0.9366, 0.9476]),\n", + " 'linear.weight': tensor([[[-1.9634, -0.2113, 0.3705]],\n", + " \n", + " [[-1.9349, -0.2027, 0.3860]],\n", + " \n", + " [[-1.9064, -0.1941, 0.4014]]]),\n", + " 'linear.bias': tensor([[9.2007],\n", + " [9.2165],\n", + " [9.2324]])}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "guide.quantiles([0.25, 0.5, 0.75])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Evaluation\n", + "\n", + "To evaluate our model, we'll generate some predictive samples and look at the posteriors. For this we will make use of the [Predictive](http://docs.pyro.ai/en/stable/inference_algos.html#pyro.infer.predictive.Predictive) utility class.\n", + "\n", + " - We generate 800 samples from our trained model. Internally, this is done by first generating samples for the unobserved sites in the `guide`, and then running the model forward by conditioning the sites to values sampled from the `guide`. Refer to the [Model Serving](#Model-Serving-via-TorchScript) section for insight on how the `Predictive` class works.\n", + " - Note that in `return_sites`, we specify both the outcome (`\"obs\"` site) as well as the return value of the model (`\"_RETURN\"`) which captures the regression line. Additionally, we would also like to capture the regression coefficients (given by `\"linear.weight\"`) for further analysis.\n", + " - The remaining code is simply used to plot the 90% CI for the two variables from our model." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "ename": "RuntimeError", + "evalue": "The size of tensor a (170) must match the size of tensor b (800) at non-singleton dimension 1\n Trace Shapes: \n Param Sites: \n Sample Sites: \n sigma dist 800 1 | \n value 800 1 | \nlinear.weight dist 800 1 | 1 3\n value 800 | 1 3\n linear.bias dist 800 1 | 1 \n value 800 | 1 ", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m~/development/pyro/pyro/poutine/trace_messenger.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 173\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 174\u001b[0;31m \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 175\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mValueError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/development/pyro/pyro/poutine/messenger.py\u001b[0m in \u001b[0;36m_context_wrap\u001b[0;34m(context, fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/development/pyro/pyro/poutine/messenger.py\u001b[0m in \u001b[0;36m_context_wrap\u001b[0;34m(context, fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/autograd/grad_mode.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/development/pyro/pyro/poutine/messenger.py\u001b[0m in \u001b[0;36m_context_wrap\u001b[0;34m(context, fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/development/pyro/pyro/nn/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 432\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pyro_context\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 433\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 434\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1050\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1051\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1052\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, y)\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0msigma\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msigma\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m \u001b[0mmean\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 21\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"data\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/development/pyro/pyro/nn/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 432\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pyro_context\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 433\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 434\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1050\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1051\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1052\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meinsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"...ij,...nj->...ni\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (170) must match the size of tensor b (800) at non-singleton dimension 1", + "\nThe above exception was the direct cause of the following exception:\n", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m return_sites=(\"linear.weight\", \"obs\", \"_RETURN\"))\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0msamples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredict_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1049\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1050\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1051\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1052\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1053\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/development/pyro/pyro/infer/predictive.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 278\u001b[0m \u001b[0mparallel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparallel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 279\u001b[0m \u001b[0mmodel_args\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 280\u001b[0;31m \u001b[0mmodel_kwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 281\u001b[0m )\n\u001b[1;32m 282\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/development/pyro/pyro/infer/predictive.py\u001b[0m in \u001b[0;36m_predictive\u001b[0;34m(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)\u001b[0m\n\u001b[1;32m 137\u001b[0m trace = poutine.trace(\n\u001b[1;32m 138\u001b[0m \u001b[0mpoutine\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcondition\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvectorize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreshaped_samples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 139\u001b[0;31m ).get_trace(*model_args, **model_kwargs)\n\u001b[0m\u001b[1;32m 140\u001b[0m \u001b[0mpredictions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0msite\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshape\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mreturn_site_shapes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/development/pyro/pyro/poutine/trace_messenger.py\u001b[0m in \u001b[0;36mget_trace\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 196\u001b[0m \u001b[0mCalls\u001b[0m \u001b[0mthis\u001b[0m \u001b[0mpoutine\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mreturns\u001b[0m \u001b[0mits\u001b[0m \u001b[0mtrace\u001b[0m \u001b[0minstead\u001b[0m \u001b[0mof\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mfunction\u001b[0m\u001b[0;31m'\u001b[0m\u001b[0ms\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 197\u001b[0m \"\"\"\n\u001b[0;32m--> 198\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 199\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmsngr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/development/pyro/pyro/poutine/trace_messenger.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[0mexc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mexc_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mu\"{}\\n{}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexc_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshapes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 179\u001b[0m \u001b[0mexc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mexc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwith_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraceback\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 180\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mexc\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 181\u001b[0m self.msngr.trace.add_node(\n\u001b[1;32m 182\u001b[0m \u001b[0;34m\"_RETURN\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"_RETURN\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"return\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mret\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/development/pyro/pyro/poutine/trace_messenger.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 172\u001b[0m )\n\u001b[1;32m 173\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 174\u001b[0;31m \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 175\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mValueError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 176\u001b[0m \u001b[0mexc_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexc_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraceback\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexc_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/development/pyro/pyro/poutine/messenger.py\u001b[0m in \u001b[0;36m_context_wrap\u001b[0;34m(context, fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_context_wrap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/development/pyro/pyro/poutine/messenger.py\u001b[0m in \u001b[0;36m_context_wrap\u001b[0;34m(context, fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_context_wrap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/autograd/grad_mode.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/development/pyro/pyro/poutine/messenger.py\u001b[0m in \u001b[0;36m_context_wrap\u001b[0;34m(context, fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_context_wrap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/development/pyro/pyro/nn/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 431\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 432\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pyro_context\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 433\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 434\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 435\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__getattr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1049\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1050\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1051\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1052\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1053\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, y)\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0msigma\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msigma\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m \u001b[0mmean\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 21\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"data\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0mobs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"obs\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNormal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msigma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/development/pyro/pyro/nn/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 431\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 432\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pyro_context\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 433\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 434\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 435\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__getattr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1049\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1050\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1051\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1052\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1053\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mBatchedLinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLinear\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meinsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"...ij,...nj->...ni\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mBayesianRegression\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mPyroModule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (170) must match the size of tensor b (800) at non-singleton dimension 1\n Trace Shapes: \n Param Sites: \n Sample Sites: \n sigma dist 800 1 | \n value 800 1 | \nlinear.weight dist 800 1 | 1 3\n value 800 | 1 3\n linear.bias dist 800 1 | 1 \n value 800 | 1 " + ] + } + ], + "source": [ + "from pyro.infer import Predictive\n", + "\n", + "predict_fn = Predictive(model, guide=guide, num_samples=800, parallel=True,\n", + " return_sites=(\"linear.weight\", \"obs\", \"_RETURN\"))\n", + "\n", + "samples = predict_fn(x_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def summary(samples):\n", + " site_stats = {}\n", + " for k, v in samples.items():\n", + " site_stats[k] = {\n", + " \"mean\": torch.mean(v, 0),\n", + " \"std\": torch.std(v, 0),\n", + " \"5%\": v.kthvalue(int(len(v) * 0.05), dim=0)[0],\n", + " \"95%\": v.kthvalue(int(len(v) * 0.95), dim=0)[0],\n", + " }\n", + " return site_stats\n", + "\n", + "pred_summary = summary(samples)\n", + "\n", + "mu = pred_summary[\"_RETURN\"]\n", + "y = pred_summary[\"obs\"]\n", + "predictions = pd.DataFrame({\n", + " \"cont_africa\": x_data[:, 0],\n", + " \"rugged\": x_data[:, 1],\n", + " \"mu_mean\": mu[\"mean\"],\n", + " \"mu_perc_5\": mu[\"5%\"],\n", + " \"mu_perc_95\": mu[\"95%\"],\n", + " \"y_mean\": y[\"mean\"],\n", + " \"y_perc_5\": y[\"5%\"],\n", + " \"y_perc_95\": y[\"95%\"],\n", + " \"true_gdp\": y_data,\n", + "})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)\n", + "african_nations = predictions[predictions[\"cont_africa\"] == 1]\n", + "non_african_nations = predictions[predictions[\"cont_africa\"] == 0]\n", + "african_nations = african_nations.sort_values(by=[\"rugged\"])\n", + "non_african_nations = non_african_nations.sort_values(by=[\"rugged\"])\n", + "fig.suptitle(\"Posterior predictive distribution with 90% CI\", fontsize=16)\n", + "ax[0].plot(non_african_nations[\"rugged\"], non_african_nations[\"y_mean\"])\n", + "ax[0].fill_between(non_african_nations[\"rugged\"], \n", + " non_african_nations[\"y_perc_5\"],\n", + " non_african_nations[\"y_perc_95\"],\n", + " alpha=0.5)\n", + "ax[0].plot(non_african_nations[\"rugged\"], non_african_nations[\"true_gdp\"], \"o\")\n", + "ax[0].set(xlabel=\"Terrain Ruggedness Index\",\n", + " ylabel=\"log GDP (2000)\",\n", + " title=\"Non African Nations\")\n", + "idx = np.argsort(african_nations[\"rugged\"])\n", + "\n", + "ax[1].plot(african_nations[\"rugged\"], african_nations[\"y_mean\"])\n", + "ax[1].fill_between(african_nations[\"rugged\"],\n", + " african_nations[\"y_perc_5\"],\n", + " african_nations[\"y_perc_95\"],\n", + " alpha=0.5)\n", + "ax[1].plot(african_nations[\"rugged\"], african_nations[\"true_gdp\"], \"o\")\n", + "ax[1].set(xlabel=\"Terrain Ruggedness Index\",\n", + " ylabel=\"log GDP (2000)\",\n", + " title=\"African Nations\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We observe that the outcome from our model and the 90% CI accounts for the majority of the data points that we observe in practice. It is usually a good idea to do such posterior predictive checks to see if our model gives valid predictions. \n", + "\n", + "Finally, let us revisit our earlier question of how robust the relationship between terrain ruggedness and GDP is against any uncertainty in the parameter estimates from our model. For this, we plot the distribution of the slope of the log GDP given terrain ruggedness for nations within and outside Africa. As can be seen below, the probability mass for African nations is largely concentrated in the positive region and vice-versa for other nations, lending further credence to the original hypothesis. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "weight = samples[\"linear.weight\"]\n", + "weight = weight.reshape(weight.shape[0], 3)\n", + "gamma_within_africa = weight[:, 1] + weight[:, 2]\n", + "gamma_outside_africa = weight[:, 1]\n", + "fig = plt.figure(figsize=(10, 6))\n", + "sns.distplot(gamma_within_africa, kde_kws={\"label\": \"African nations\"},)\n", + "sns.distplot(gamma_outside_africa, kde_kws={\"label\": \"Non-African nations\"})\n", + "fig.suptitle(\"Density of Slope : log(GDP) vs. Terrain Ruggedness\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training, Serialization and Serving via TorchScript\n", + "\n", + "Finally, note that `model`, `guide`, `elbo(model, guide)` and the `Predictive` utility class are all `torch.nn.Module` instances, and can be serialized as [TorchScript](https://pytorch.org/docs/stable/jit.html). \n", + "\n", + "Here, we show how we can train and serve a Pyro model as a [torch.jit.ModuleScript](https://pytorch.org/docs/stable/jit.html#torch.jit.ScriptModule), which can be run separately as a C++ program without a Python runtime. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create fresh copies of model, guide, elbo\n", + "model = BayesianRegression(3, 1)\n", + "guide = AutoDiagonalNormal(model)\n", + "elbo = Trace_ELBO(num_particles=10, vectorize_particles=True)(model, guide)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This time, we will JIT-compile the ELBO loss function, resulting in substantially faster training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Populate guide parameters\n", + "elbo(x_data, y_data);\n", + "optim = torch.optim.Adam(guide.parameters(), lr=0.03)\n", + "\n", + "# Temporarily disable runtime validation and compile ELBO\n", + "with pyro.validation_enabled(False):\n", + " elbo = torch.jit.trace(elbo, (x_data, y_data), check_trace=False, strict=False)\n", + "\n", + "# optimize\n", + "for j in range(1500 if not smoke_test else 2):\n", + " loss = elbo(x_data, y_data)\n", + " optim.zero_grad()\n", + " loss.backward()\n", + " optim.step()\n", + " if (j + 1) % 50 == 0:\n", + " print(\"[iteration %04d] loss: %.4f\" % (j + 1, loss.item()))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also JIT-compile `pyro.infer.Predictive`. Because `Predictive.forward` returns a dictionary, we must pass the keyword argument `strict=False` to `torch.jit.trace`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predict_fn = Predictive(model, guide=guide, num_samples=800, parallel=True,\n", + " return_sites=(\"linear.weight\", \"obs\", \"_RETURN\"))\n", + "\n", + "predict_module = torch.jit.trace(predict_fn, (x_data,), check_trace=False, strict=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can serialize and save the compiled predictor using [torch.jit.save](https://pytorch.org/docs/stable/jit.html#torch.jit.save). This saved model `reg_predict.pt` can be loaded with PyTorch's C++ API using `torch::jit::load(filename)`, or using the Python API as we do below. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "torch.jit.save(predict_module, '/tmp/reg_predict.pt')\n", + "pred_loaded = torch.jit.load('/tmp/reg_predict.pt')\n", + "pred_loaded(x_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let us check that our `Predictive` module was indeed serialized correctly, by generating samples from the loaded module and regenerating the previous plot." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "weight = pred_loaded(x_data)[\"linear.weight\"]\n", + "weight = weight.reshape(weight.shape[0], 3).detach()\n", + "gamma_within_africa = weight[:, 1] + weight[:, 2]\n", + "gamma_outside_africa = weight[:, 1]\n", + "fig = plt.figure(figsize=(10, 6))\n", + "sns.distplot(gamma_within_africa, kde_kws={\"label\": \"African nations\"},)\n", + "sns.distplot(gamma_outside_africa, kde_kws={\"label\": \"Non-African nations\"})\n", + "fig.suptitle(\"Loaded TorchScript Module : log(GDP) vs. Terrain Ruggedness\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the next section, we'll look at how to write guides for variational inference as well as compare the results with inference via HMC." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### References\n", + " 1. McElreath, D., *Statistical Rethinking, Chapter 7*, 2016\n", + " 2. Nunn, N. & Puga, D., *[Ruggedness: The blessing of bad geography in Africa\"](https://diegopuga.org/papers/rugged.pdf)*, Review of Economics and Statistics 94(1), Feb. 2012" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "anaconda-cloud": {}, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorial/source/index.rst b/tutorial/source/index.rst index c1f329300b..968dba814a 100644 --- a/tutorial/source/index.rst +++ b/tutorial/source/index.rst @@ -91,6 +91,7 @@ List of Tutorials bayesian_regression bayesian_regression_ii + bayesian_regression_module tensor_shapes modules jit