final run
juanitorduz committed Oct 4, 2024
1 parent 53ccfe7 commit dcabf06
Showing 2 changed files with 52 additions and 52 deletions.
7 changes: 6 additions & 1 deletion Python/numpyro_hierarchical_forecasting_1.ipynb
@@ -712,7 +712,7 @@
"source": [
"## Inference with SVI\n",
"\n",
"We now fir the model to the data using stochastic variational inference."
"We now fit the model to the data using stochastic variational inference."
]
},
{
@@ -870,6 +870,9 @@
" return jnp.average(per_obs_crps, weights=sample_weight)\n",
"\n",
"\n",
"# For the purposes of comparison, we clip the predictions to be non-negative.\n",
"# But we keep the original values in the idata object.\n",
"\n",
"crps_train = crps(\n",
" y_train,\n",
" jnp.array(idata_train[\"posterior_predictive\"][\"obs\"].sel(chain=0)).clip(min=0),\n",
@@ -922,6 +925,8 @@
")\n",
"for i, ax in enumerate(axes):\n",
" for j, hdi_prob in enumerate([0.94, 0.5]):\n",
" # For the purposes of comparison, we clip the predictions to be non-negative.\n",
" # But we keep the original values in the idata object.\n",
" az.plot_hdi(\n",
" time_train[time_train >= T1 - 24 * 7],\n",
" idata_train[\"posterior_predictive\"][\"obs\"]\n",
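The comments added in this hunk describe clipping the predictions only for scoring while leaving the stored samples untouched. A minimal sketch of that idea follows. It uses NumPy instead of the notebook's JAX, and the helper `crps_from_samples` and the toy arrays are hypothetical stand-ins, not the notebook's `crps` function.

```python
import numpy as np


def crps_from_samples(truth, samples):
    """Sample-based CRPS, averaged over observations.

    samples has shape (n_samples, ...) and truth has shape (...).
    Uses CRPS = E|X - y| - 0.5 * E|X - X'| with X, X' drawn from the samples.
    """
    abs_error = np.abs(samples - truth).mean(axis=0)
    pairwise_spread = np.abs(samples[:, None] - samples[None, :]).mean(axis=(0, 1))
    return float((abs_error - 0.5 * pairwise_spread).mean())


# Hypothetical toy data: 2 posterior draws for 2 non-negative observations.
truth = np.array([0.0, 1.0])
samples = np.array([[-1.0, 1.0], [1.0, 1.0]])

crps_raw = crps_from_samples(truth, samples)
# Clip a copy only for scoring; `samples` itself is left untouched.
crps_clipped = crps_from_samples(truth, samples.clip(min=0))
```

Because `clip` returns a new array, the negative draw is still present in `samples` afterwards, which mirrors keeping the original values in the `idata` object.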
97 changes: 46 additions & 51 deletions Python/numpyro_hierarchical_forecasting_2.ipynb
@@ -6,9 +6,7 @@
"source": [
"# From Pyro to NumPyro: Forecasting Hierarchical Models - Part I\n",
"\n",
"In this notebook we provide a NumPyro implementation of the first model presented in the Pyro forecasting documentation: [Forecasting III: hierarchical models](https://pyro.ai/examples/forecasting_iii.html). This model generalizes the local level model with seasonality presented in the univariate example [Forecasting I: univariate, heavy tailed](https://pyro.ai/examples/forecasting_i.html) (see [From Pyro to NumPyro: Forecasting a univariate, heavy tailed time series](https://juanitorduz.github.io/numpyro_forecasting-univariate/) for the corresponding NumPyro implementation).\n",
"\n",
"In this example, we continue working with the BART train ridership [dataset](https://www.bart.gov/about/reports/ridership)."
"In this second notebook, we continue working on the NumPyro implementation of the hierarchical forecasting models presented in Pyro's forecasting documentation: [Forecasting III: hierarchical models](https://pyro.ai/examples/forecasting_iii.html). In this second part, we extend the model described in the first part [From Pyro to NumPyro: Forecasting Hierarchical Models - Part I](https://juanitorduz.github.io/numpyro_hierarchical_forecasting_1/) by adding all stations to the model."
]
},
{
@@ -20,9 +18,20 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 18,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n",
"The jaxtyping extension is already loaded. To reload it, use:\n",
" %reload_ext jaxtyping\n"
]
}
],
"source": [
"import arviz as az\n",
"import jax\n",
@@ -63,7 +72,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 19,
"metadata": {},
"outputs": [
{
@@ -87,12 +96,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"For this first model, we just model the rides to Embarcadero station, from each of the other $50$ stations."
"In this second example, we model all the rides from all stations to all other stations."
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 20,
"metadata": {},
"outputs": [
{
@@ -113,12 +122,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"For training purposes we will use data from 90 days before the test data."
"## Train - Test Split\n",
"\n",
"As in the first example, for training purposes we use data from the 90 days before the test data."
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
@@ -129,7 +140,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 22,
"metadata": {},
"outputs": [
{
@@ -154,7 +165,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
@@ -167,27 +178,16 @@
"time_test = jnp.array(range(T1, T2))\n",
"t_max_test = time_test.size\n",
"\n",
"assert time_train.size + time_test.size == time.size\n",
"assert y_train.shape == (n_stations, n_stations, t_max_train)\n",
"assert y_test.shape == (n_stations, n_stations, t_max_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As in the example before ([From Pyro to NumPyro: Forecasting a univariate, heavy tailed time series](https://juanitorduz.github.io/numpyro_forecasting-univariate/)), we use the covariates input tensor to encode the data size. We can of course use this tensor to encode other covariates, but for this example we will not use them."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"covariates = jnp.zeros_like(y)\n",
"covariates_train = jnp.zeros_like(y_train)\n",
"covariates_test = jnp.zeros_like(y_test)"
"covariates_test = jnp.zeros_like(y_test)\n",
"\n",
"assert time_train.size + time_test.size == time.size\n",
"assert y_train.shape == (n_stations, n_stations, t_max_train)\n",
"assert y_test.shape == (n_stations, n_stations, t_max_test)\n",
"assert covariates.shape == y.shape\n",
"assert covariates_train.shape == y_train.shape\n",
"assert covariates_test.shape == y_test.shape"
]
},
{
@@ -196,19 +196,7 @@
"source": [
"## Repeating Seasonal Features\n",
"\n",
"In the univariate case example we studied a very handy function to generate Fourier modes, [`periodic_features`](https://docs.pyro.ai/en/stable/ops.html?highlight=periodic_features#pyro.ops.tensor_utils.periodic_features). In this case, there is a very handy function to repeat the seasonal features, [`periodic_repeat`](https://docs.pyro.ai/en/stable/ops.html#pyro.ops.tensor_utils.periodic_repeat). It has two main parameters:\n",
"\n",
"- `size` (int) – Desired size of the result along dimension `dim`.\n",
"- `dim` (int) – The tensor dimension along which to repeat.\n",
"\n",
"Let's see some example from the docstrings.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The translation from PyTorch to JAX is not that hard (thank you GitHub Copilot 😅)."
"We also need the JAX version of the [`periodic_repeat`](https://docs.pyro.ai/en/stable/ops.html#pyro.ops.tensor_utils.periodic_repeat) function.\n"
]
},
{
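The hunk above refers to a JAX version of Pyro's `periodic_repeat`, whose body is not shown in this diff. The following is a minimal sketch of the same behaviour, written with NumPy for portability. The name `periodic_repeat_np` and its `axis` argument are hypothetical and are not the notebook's implementation.

```python
import numpy as np


def periodic_repeat_np(array, size, axis):
    """Tile `array` along `axis` and truncate so that axis has length `size`."""
    period = array.shape[axis]
    # Ceiling division: number of whole periods needed to cover `size`.
    n_tiles = -(-size // period)
    reps = [1] * array.ndim
    reps[axis] = n_tiles
    tiled = np.tile(array, reps)
    # Keep only the first `size` entries along the repeated axis.
    return np.take(tiled, np.arange(size), axis=axis)


x = np.array([[1, 2, 3], [4, 5, 6]])
rows_repeated = periodic_repeat_np(x, 4, axis=0)
cols_repeated = periodic_repeat_np(x, 4, axis=1)
```

In the model this kind of helper stretches a one-week seasonal pattern (`24 * 7` hourly values) over the full length of the time axis.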
@@ -261,7 +249,7 @@
"source": [
"## Model Specification\n",
"\n",
"This first hierarchical model extends the local level model with seasonality seen in the univariate case example, [From Pyro to NumPyro: Forecasting a univariate, heavy tailed time series](https://juanitorduz.github.io/numpyro_forecasting-univariate/). "
"In this model, the local level dynamic is driven by the destination station. On the other hand, the seasonal components and the noise scales come as a sum of the origin and destination stations. The model structure is very similar to the one presented in the first example."
]
},
{
@@ -277,6 +265,7 @@
" # Get the time and feature dimensions\n",
" n_series, n_series, t_max = covariates.shape\n",
"\n",
" # Define the plates to be able to use them below\n",
" origin_plate = numpyro.plate(\"origin\", n_series, dim=-3)\n",
" destin_plate = numpyro.plate(\"destin\", n_series, dim=-2)\n",
" hour_of_week_plate = numpyro.plate(\"hour_of_week\", 24 * 7, dim=-1)\n",
@@ -308,9 +297,14 @@
" \"destin_seasonal\", dist.Normal(loc=0, scale=5)\n",
" )\n",
"\n",
" # We model a static pairwise station->station affinity, which e.g.\n",
" # can compensate for the fact that people tend not to travel from\n",
" # a station to itself.\n",
" with origin_plate, destin_plate:\n",
" pairwise = numpyro.sample(\"pairwise\", dist.Normal(0, 1))\n",
"\n",
" # We model the origin and destination scales separately\n",
" # and then add them together to get the final scale.\n",
" with origin_plate:\n",
" origin_scale = numpyro.sample(\"origin_scale\", dist.LogNormal(-5, 5))\n",
" with destin_plate:\n",
@@ -333,9 +327,10 @@
" transition_fn, init=jnp.zeros((n_series,)), xs=jnp.arange(t_max)\n",
" )\n",
"\n",
" # We need to transpose the prediction levels to match the shape of the data\n",
" pred_levels = pred_levels.transpose(1, 0)\n",
"\n",
" # # Compute the mean of the model\n",
" # Compute the mean of the model\n",
" mu = pred_levels + seasonal_repeat + pairwise\n",
"\n",
" # Sample the observations\n",
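The transpose added in this hunk makes the destination-by-time levels line up with the `(origin, destin, time)` layout of the data. Below is a shape-only sketch of how the three terms of `mu` could broadcast. It uses NumPy, and the sizes, the random-walk stand-in for the `scan`, and the random seasonal terms are all hypothetical assumptions; only the plate layout (`origin` at dim -3, `destin` at dim -2, time at dim -1) is taken from the diff.

```python
import numpy as np

n_series, t_max, period = 3, 10, 4  # hypothetical toy sizes
rng = np.random.default_rng(0)

# A scan over time stacks one (n_series,) level vector per step: (t_max, n_series).
steps = rng.normal(size=(t_max, n_series))
pred_levels = np.cumsum(steps, axis=0)
# Transpose to (destin, time) so it aligns with the last two data dimensions.
pred_levels = pred_levels.transpose(1, 0)

# Seasonal terms: origin at dim -3, destination at dim -2, hour at dim -1.
origin_seasonal = rng.normal(size=(n_series, 1, period))
destin_seasonal = rng.normal(size=(1, n_series, period))
seasonal = origin_seasonal + destin_seasonal  # (origin, destin, period)
n_tiles = -(-t_max // period)
seasonal_repeat = np.tile(seasonal, (1, 1, n_tiles))[..., :t_max]

# Static pairwise affinity, constant over time.
pairwise = rng.normal(size=(n_series, n_series, 1))

# Broadcasts to (origin, destin, time).
mu = pred_levels + seasonal_repeat + pairwise
```

Without the transpose, the `(t_max, n_series)` array would not align with the trailing `(destin, time)` axes, so the sum would either fail or silently mix up stations and time steps.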
@@ -595,7 +590,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's plot the prior predictive distribution for the first $8$ stations."
"Let's plot the prior predictive distribution for the first $8$ stations for the destination station `ANTC`."
]
},
{
@@ -663,7 +658,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Overall, the prior ranges look very reasonable."
"Overall, the prior ranges look very reasonable (even too wide).\n"
]
},
{
@@ -672,7 +667,7 @@
"source": [
"## Inference with SVI\n",
"\n",
"We now fir the model to the data using stochastic variational inference."
"We now fit the model to the data using stochastic variational inference. This model takes longer to fit than the first one ($3.5$ minutes versus $45$ seconds)."
]
},
{
@@ -747,7 +742,7 @@
"source": [
"## Posterior Predictive Check\n",
"\n",
"Next, we generate posterior predictive samples for the forecast for each of the $50$ stations."
"Next, we generate posterior predictive samples for the forecast for each of the station pairs."
]
},
{
@@ -802,7 +797,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"As in the univariate case example, we compute the CRPS for the training and test data."
"To evaluate the model performance, we compute the CRPS for the training and test data. For comparison purposes, we clip the predictions to be non-negative."
]
},
{