diff --git a/Python/numpyro_hierarchical_forecasting_1.ipynb b/Python/numpyro_hierarchical_forecasting_1.ipynb
index 9507166..94e97d0 100644
--- a/Python/numpyro_hierarchical_forecasting_1.ipynb
+++ b/Python/numpyro_hierarchical_forecasting_1.ipynb
@@ -712,7 +712,7 @@
    "source": [
     "## Inference with SVI\n",
     "\n",
-    "We now fir the model to the data using stochastic variational inference."
+    "We now fit the model to the data using stochastic variational inference."
    ]
   },
   {
@@ -870,6 +870,9 @@
     "    return jnp.average(per_obs_crps, weights=sample_weight)\n",
     "\n",
     "\n",
+    "# For the purposes of comparison, we clip the predictions to be non-negative.\n",
+    "# But we keep the original values in the idata object.\n",
+    "\n",
     "crps_train = crps(\n",
     "    y_train,\n",
     "    jnp.array(idata_train[\"posterior_predictive\"][\"obs\"].sel(chain=0)).clip(min=0),\n",
@@ -922,6 +925,8 @@
     ")\n",
     "for i, ax in enumerate(axes):\n",
     "    for j, hdi_prob in enumerate([0.94, 0.5]):\n",
+    "        # For the purposes of comparison, we clip the predictions to be non-negative.\n",
+    "        # But we keep the original values in the idata object.\n",
     "        az.plot_hdi(\n",
     "            time_train[time_train >= T1 - 24 * 7],\n",
     "            idata_train[\"posterior_predictive\"][\"obs\"]\n",
diff --git a/Python/numpyro_hierarchical_forecasting_2.ipynb b/Python/numpyro_hierarchical_forecasting_2.ipynb
index a56a43b..d6fbd57 100644
--- a/Python/numpyro_hierarchical_forecasting_2.ipynb
+++ b/Python/numpyro_hierarchical_forecasting_2.ipynb
@@ -6,9 +6,7 @@
    "source": [
     "# From Pyro to NumPyro: Forecasting Hierarchical Models - Part I\n",
     "\n",
-    "In this notebook we provide a NumPyro implementation of the first model presented in the Pyro forecasting documentation: [Forecasting III: hierarchical models](https://pyro.ai/examples/forecasting_iii.html). This model generalizes the local level model with seasonality presented in the univariate example [Forecasting I: univariate, heavy tailed](https://pyro.ai/examples/forecasting_i.html) (see [From Pyro to NumPyro: Forecasting a univariate, heavy tailed time series](https://juanitorduz.github.io/numpyro_forecasting-univariate/) for the corresponding NumPyro implementation).\n",
-    "\n",
-    "In this example, we continue working with the BART train ridership [dataset](https://www.bart.gov/about/reports/ridership)."
+    "In this second notebook, we continue working on the NumPyro implementation of the hierarchical forecasting models presented in Pyro's forecasting documentation: [Forecasting III: hierarchical models](https://pyro.ai/examples/forecasting_iii.html). We extend the model described in the first part, [From Pyro to NumPyro: Forecasting Hierarchical Models - Part I](https://juanitorduz.github.io/numpyro_hierarchical_forecasting_1/), by adding all stations to the model."
    ]
   },
   {
@@ -20,9 +18,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 18,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The autoreload extension is already loaded. To reload it, use:\n",
+      "  %reload_ext autoreload\n",
+      "The jaxtyping extension is already loaded. To reload it, use:\n",
+      "  %reload_ext jaxtyping\n"
+     ]
+    }
+   ],
    "source": [
     "import arviz as az\n",
     "import jax\n",
@@ -63,7 +72,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 19,
    "metadata": {},
    "outputs": [
     {
@@ -87,12 +96,12 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "For this first model, we just model the rides to Embarcadero station, from each of the other $50$ stations."
+    "In this second example, we model all the rides from all stations to all other stations."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 20,
    "metadata": {},
    "outputs": [
     {
@@ -113,12 +122,14 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "For training purposes we will use data from 90 days before the test data."
+    "## Train - Test Split\n",
+    "\n",
+    "As in the first example, for training purposes we will use data from the 90 days before the test data."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 21,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -129,7 +140,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 22,
    "metadata": {},
    "outputs": [
     {
@@ -154,7 +165,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 23,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -167,27 +178,16 @@
     "time_test = jnp.array(range(T1, T2))\n",
     "t_max_test = time_test.size\n",
     "\n",
-    "assert time_train.size + time_test.size == time.size\n",
-    "assert y_train.shape == (n_stations, n_stations, t_max_train)\n",
-    "assert y_test.shape == (n_stations, n_stations, t_max_test)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "As in the example before ([From Pyro to NumPyro: Forecasting a univariate, heavy tailed time series](https://juanitorduz.github.io/numpyro_forecasting-univariate/)), we use the covariates input tensor to encode the data size. We can of course use this tensor to encode other covariates, but for this example we will not use them."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 7,
-   "metadata": {},
-   "outputs": [],
-   "source": [
     "covariates = jnp.zeros_like(y)\n",
     "covariates_train = jnp.zeros_like(y_train)\n",
-    "covariates_test = jnp.zeros_like(y_test)"
+    "covariates_test = jnp.zeros_like(y_test)\n",
+    "\n",
+    "assert time_train.size + time_test.size == time.size\n",
+    "assert y_train.shape == (n_stations, n_stations, t_max_train)\n",
+    "assert y_test.shape == (n_stations, n_stations, t_max_test)\n",
+    "assert covariates.shape == y.shape\n",
+    "assert covariates_train.shape == y_train.shape\n",
+    "assert covariates_test.shape == y_test.shape"
    ]
   },
   {
@@ -196,19 +196,7 @@
    "source": [
     "## Repeating Seasonal Features\n",
     "\n",
-    "In the univariate case example we studied a very handy function to generate Fourier modes, [`periodic_features`](https://docs.pyro.ai/en/stable/ops.html?highlight=periodic_features#pyro.ops.tensor_utils.periodic_features). In this case, there is a very handy function to repeat the seasonal features, [`periodic_repeat`](https://docs.pyro.ai/en/stable/ops.html#pyro.ops.tensor_utils.periodic_repeat). It has two main parameters:\n",
-    "\n",
-    "- `size` (int) – Desired size of the result along dimension `dim`.\n",
-    "- `dim` (int) – The tensor dimension along which to repeat.\n",
-    "\n",
-    "Let's see some example from the docstrings.\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "The translation from PyTorch to JAX is not that hard (thank you GitHub Copilot 😅)."
+ "We also need the JAX version of the [`periodic_repeat`](https://docs.pyro.ai/en/stable/ops.html#pyro.ops.tensor_utils.periodic_repeat) function.\n" ] }, { @@ -261,7 +249,7 @@ "source": [ "## Model Specification\n", "\n", - "This first hierarchical model extends the local level model with seasonality seen in the univariate case example, [From Pyro to NumPyro: Forecasting a univariate, heavy tailed time series](https://juanitorduz.github.io/numpyro_forecasting-univariate/). " + "In this model, the local level dynamic is driven by the destination station. On the other hand, the seasonal components and the noise scales come as a sum of the origin and destination stations. The model structure is very similar to the one presented in the first example." ] }, { @@ -277,6 +265,7 @@ " # Get the time and feature dimensions\n", " n_series, n_series, t_max = covariates.shape\n", "\n", + " # Define the plates to be able to use them below\n", " origin_plate = numpyro.plate(\"origin\", n_series, dim=-3)\n", " destin_plate = numpyro.plate(\"destin\", n_series, dim=-2)\n", " hour_of_week_plate = numpyro.plate(\"hour_of_week\", 24 * 7, dim=-1)\n", @@ -308,9 +297,14 @@ " \"destin_seasonal\", dist.Normal(loc=0, scale=5)\n", " )\n", "\n", + " # We model a static pairwise station->station affinity, which e.g.\n", + " # can compensate for the fact that people tend not to travel from\n", + " # a station to itself.\n", " with origin_plate, destin_plate:\n", " pairwise = numpyro.sample(\"pairwise\", dist.Normal(0, 1))\n", "\n", + " # We model the origin and destination scales separately\n", + " # and then add them together to get the final scale.\n", " with origin_plate:\n", " origin_scale = numpyro.sample(\"origin_scale\", dist.LogNormal(-5, 5))\n", " with destin_plate:\n", @@ -333,9 +327,10 @@ " transition_fn, init=jnp.zeros((n_series,)), xs=jnp.arange(t_max)\n", " )\n", "\n", + " # We need to transpose the prediction levels to match the shape of the data\n", " pred_levels = pred_levels.transpose(1, 0)\n", "\n", - " # # Compute the mean of the model\n", + " # Compute the mean of the model\n", " mu = pred_levels + seasonal_repeat + pairwise\n", "\n", " # Sample the observations\n", @@ -595,7 +590,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Let's plot the prior predictive distribution for the first $8$ stations." + "Let's plot the prior predictive distribution for the first $8$ stations for the destination station `ANTC`." ] }, { @@ -663,7 +658,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Overall, the prior ranges look very reasonable." + "Overall, the prior ranges look very reasonable (even too wide).\n" ] }, { @@ -672,7 +667,7 @@ "source": [ "## Inference with SVI\n", "\n", - "We now fir the model to the data using stochastic variational inference." + "We now fit the model to the data using stochastic variational inference. This time the model runs for longer as compared to the first one ($45$ seconds to $3.5$ minutes)." ] }, { @@ -747,7 +742,7 @@ "source": [ "## Posterior Predictive Check\n", "\n", - "Next, we generate posterior predictive samples for the forecast for each of the $50$ stations." + "Next, we generate posterior predictive samples for the forecast for each of the stations pairs." ] }, { @@ -802,7 +797,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "As in the univariate case example, we compute the CRPS for the training and test data." + "To evaluate the model performance,we compute the CRPS for the training and test data. 
For comparison purposes, we clip the data to ensure the predictions are non-negative." ] }, {
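The diff above says the second notebook "also need[s] the JAX version of the `periodic_repeat` function" but does not show the implementation. As a rough sketch of what such a translation could look like, assuming the tiling-and-truncation semantics documented for `pyro.ops.tensor_utils.periodic_repeat` (the notebook's actual code may differ):

```python
import jax.numpy as jnp


def periodic_repeat(x: jnp.ndarray, size: int, dim: int) -> jnp.ndarray:
    """Tile `x` periodically along axis `dim` and truncate to length `size`,
    mirroring the semantics of pyro.ops.tensor_utils.periodic_repeat."""
    dim = dim % x.ndim  # canonicalize negative axes
    repeats = [1] * x.ndim
    repeats[dim] = -(-size // x.shape[dim])  # ceil(size / period)
    return jnp.take(jnp.tile(x, repeats), jnp.arange(size), axis=dim)


# For example, repeating a (hypothetical) hour-of-week profile over a
# 90-day horizon of hourly observations:
profile = jnp.arange(24 * 7)
assert periodic_repeat(profile, size=24 * 90, dim=-1).shape == (24 * 90,)
```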
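The model description states that the noise scales "come as a sum of the origin and destination stations". With the plate dimensions used in the diff (`origin` at `dim=-3`, `destin` at `dim=-2`), that sum is plain array broadcasting. A standalone shape illustration with hypothetical values (not the notebook's code):

```python
import jax.numpy as jnp

n_series = 4  # illustrative number of stations

# Inside the model, origin_scale is sampled under the origin plate (dim=-3)
# and destin_scale under the destin plate (dim=-2), giving these shapes:
origin_scale = jnp.full((n_series, 1, 1), 0.1)
destin_scale = jnp.full((n_series, 1), 0.2)

# Broadcasting the sum yields one scale per (origin, destination) pair.
scale = origin_scale + destin_scale
assert scale.shape == (n_series, n_series, 1)
```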
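The diff also adds a comment explaining the `pred_levels.transpose(1, 0)` step: `scan` stacks its per-step outputs along a new leading time axis, while the observations carry time in the last axis. A minimal plain-JAX sketch of a random-walk level dynamic that makes this concrete (the function name and `drift_scale` are illustrative; the notebook uses NumPyro's `scan` inside the model instead):

```python
import jax
import jax.numpy as jnp


def sample_levels(rng_key, drift_scale, n_series, t_max):
    """Random-walk local levels, one series per destination station."""

    def transition_fn(level, noise_t):
        new_level = level + drift_scale * noise_t
        return new_level, new_level  # (carry, per-step output)

    noise = jax.random.normal(rng_key, (t_max, n_series))
    _, pred_levels = jax.lax.scan(transition_fn, jnp.zeros((n_series,)), noise)
    # scan returns shape (t_max, n_series); move time to the last axis
    # to match data laid out as (..., n_series, t_max).
    return pred_levels.transpose(1, 0)


levels = sample_levels(jax.random.PRNGKey(0), 0.1, n_series=4, t_max=10)
assert levels.shape == (4, 10)
```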