Add simple notebook comparing the results of ExactGP and viGP #50

Merged: merged 5 commits (Oct 16, 2023)
6 changes: 2 additions & 4 deletions .github/workflows/notebook_smoke.yml
@@ -8,12 +8,11 @@ on:
   branches:
     - '*'
   tags:
-    - '*'
+    - '*'
 
 jobs:
   build-linux:
 
-
     strategy:
       matrix:
         python-version: ['3.9', '3.10', '3.11']
@@ -49,6 +48,5 @@ jobs:
         pip install ipython
         pip install nbformat
         pip install seaborn
-        cd examples
-        ipython -c "%run simpleGP.ipynb"
+        bash scripts/test_notebooks.sh

371 changes: 371 additions & 0 deletions examples/compare_GPs.ipynb
@@ -0,0 +1,371 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/ziatdinovmax/gpax/blob/main/examples/compare_GPs.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QNbovYISOvRm"
},
"source": [
"# Compare SimpleGP and viGP\n",
"\n",
"This is a simple notebook to compare timings and results of two different commonly used GPs. One trained using NUTS, and the other trained using SVI.\n",
"\n",
"*Prepared by Matthew R. Carbone & Maxim Ziatdinov (2023)*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Background"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Depending on the amount of data you have, the number of dimensions the inputs have, and your time budget for training, you may want to use a GP fit using stochastic variational inference vs. Markov chain monte carlo. The following compares some strengths and weaknesses of the two methods.\n",
"\n",
"**Hamiltonian Monte Carlo (HMC)/No U-Turn Sampler (NUTS)**\n",
"\n",
"- Sensitivity to Priors: This can be perceived as a strength or a weakness, depending on the context. However, many researchers appreciate it because it offers a more intuitive grasp of the model.\n",
"- Reliable Uncertainty Estimates: Offers robust evaluations of uncertainties as it directly samples from the posterior. The variational methods are known to lead to underestimation of uncertainties.\n",
"- Integration with Classical Bayesian Models: This is particularly evident when you consider the combination of Gaussian Processes with traditional Bayesian models, as demonstrated in structured GP and hypothesis learning.\n",
"- Comprehensive Convergence Diagnostics: Indicators such as n_eff, r_hat, and acc_prob for each inferred parameter.\n",
"- Speed Limitations: One of the primary drawbacks is its computational speed.\n",
"\n",
"**Stochastic Variational Inference (SVI)**\n",
"\n",
"- Efficiency: It's significantly faster and is memory-efficient (performs equally well with 32-bit precision)\n",
"- Acceptable Trade-offs: For many real-world tasks, the slight decrease in the accuracy of predictive uncertainty estimates is either negligible or acceptable.\n",
"- Convergence Indicator Limitations: The loss may not be a very good indicator of convergence - can easily overshoot or undershoot."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HdtH0tCPQ2de"
},
"source": [
"## Install & Import"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "86iUwKxLO7qE"
},
"source": [
"Install GPax package:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VQ1rLUzqha2i",
"outputId": "44157aab-4e21-4966-ec79-ccf85cd4bbaa"
},
"outputs": [],
"source": [
"!pip install gpax"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vygoK7MTjJWB"
},
"source": [
"Import needed packages:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" # For use on Google Colab\n",
" import gpax\n",
"\n",
"except ImportError:\n",
" # For use locally (where you're using the local version of gpax)\n",
" print(\"Assuming notebook is being run locally, attempting to import local gpax module\")\n",
" import sys\n",
" sys.path.append(\"..\")\n",
" import gpax"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KtGDc11Ehh7r"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"gpax.utils.enable_x64() # enable double precision"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o8h5uDi9Q-8Y"
},
"source": [
"## Create data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cwowQv9KjB8k"
},
"source": [
"Generate some noisy observations:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 382
},
"id": "-I4RQ2xCi0VV",
"outputId": "9a2a93dd-eade-48b7-8f46-480787f7204d"
},
"outputs": [],
"source": [
"np.random.seed(0)\n",
"\n",
"NUM_INIT_POINTS = 25 # number of observation points\n",
"NOISE_LEVEL = 0.1 # noise level\n",
"\n",
"# Generate noisy data from a known function\n",
"f = lambda x: np.sin(10*x)\n",
"\n",
"X = np.random.uniform(-1., 1., NUM_INIT_POINTS)\n",
"y = f(X) + np.random.normal(0., NOISE_LEVEL, NUM_INIT_POINTS)\n",
"\n",
"# Plot generated data\n",
"plt.figure(dpi=100)\n",
"plt.xlabel(\"$x$\")\n",
"plt.ylabel(\"$y$\")\n",
"plt.scatter(X, y, marker='x', c='k', zorder=1, label='Noisy observations')\n",
"plt.ylim(-1.8, 2.2);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "quGwYqtfRCjn"
},
"source": [
"## Standard `ExactGP`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Vbo1zf05r8i5"
},
"source": [
"Next, we initialize and train a GP model. We are going to use an RBF kernel, $k_{RBF}=𝜎exp(-\\frac{||x_i-x_j||^2}{2l^2})$, which is a \"go-to\" kernel functions in GP."
]
},
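{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the RBF kernel above is easy to write out directly in NumPy. The following is a minimal sketch for 1D inputs; gpax's internal implementation may parameterize the variance and lengthscale differently:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def rbf_kernel(x1, x2, variance=1.0, lengthscale=1.0):\n",
"    # Squared Euclidean distances between all pairs of 1D inputs\n",
"    sq_dist = (x1[:, None] - x2[None, :]) ** 2\n",
"    return variance * np.exp(-0.5 * sq_dist / lengthscale**2)\n",
"\n",
"# Covariance matrix of the training inputs with themselves\n",
"rbf_kernel(X, X).shape"
]
},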
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "c7kXm_lui6Dy",
"outputId": "f00b7b7f-4853-4bb3-d9da-75157ec10105"
},
"outputs": [],
"source": [
"# Get random number generator keys for training and prediction\n",
"rng_key, rng_key_predict = gpax.utils.get_keys()\n",
"\n",
"# Initialize model\n",
"gp_model_1 = gpax.ExactGP(1, kernel='RBF')\n",
"\n",
"# Run Hamiltonian Monte Carlo to obtain posterior samples for kernel parameters and model noise\n",
"gp_model_1.fit(rng_key, X, y, num_chains=1)"
]
},
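{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because NUTS samples from the posterior directly, we can inspect the draws for each hyperparameter. A minimal sketch, assuming the `get_samples` method used in other gpax examples:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Posterior draws for the kernel hyperparameters and model noise\n",
"# (get_samples is assumed here based on other gpax examples)\n",
"samples = gp_model_1.get_samples()\n",
"for name, s in samples.items():\n",
"    s = np.asarray(s)\n",
"    print(f\"{name}: shape={s.shape}, mean={s.mean():.3f}\")"
]
},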
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Standard `viGP`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get random number generator keys for training and prediction\n",
"rng_key, rng_key_predict = gpax.utils.get_keys()\n",
"\n",
"# Initialize model\n",
"gp_model_2 = gpax.viGP(1, kernel='RBF')\n",
"\n",
"# Run Hamiltonian Monte Carlo to obtain posterior samples for kernel parameters and model noise\n",
"gp_model_2.fit(rng_key, X, y)"
]
},
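{
"cell_type": "markdown",
"metadata": {},
"source": [
"To quantify the speed difference on your own hardware, the two `fit` calls can be wrapped in a simple wall-clock timer. A minimal sketch using the standard library (timings include JIT compilation overhead on the first call):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"# Fit fresh copies of both models and report the wall-clock time of each\n",
"for name, model in [(\"ExactGP (NUTS)\", gpax.ExactGP(1, kernel='RBF')),\n",
"                    (\"viGP (SVI)\", gpax.viGP(1, kernel='RBF'))]:\n",
"    key, _ = gpax.utils.get_keys()\n",
"    start = time.perf_counter()\n",
"    model.fit(key, X, y)\n",
"    print(f\"{name}: {time.perf_counter() - start:.1f} s\")"
]
},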
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZGoIdNDknIyW"
},
"outputs": [],
"source": [
"X_test = np.linspace(-1, 1, 100)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y_pred_1, y_sampled_1 = gp_model_1.predict(rng_key_predict, X_test, n=200)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y_pred_2, y_sampled_2 = gp_model_2.predict(rng_key_predict, X_test, n=200)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that SVI (the `viGP`) is significantly faster. SVI is usually better to use on larger datasets and is more easily scalable. In this case, they produce similar results."
]
},
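{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the data were generated from a known function, we can quantify \"similar results\" by comparing each predictive mean against the noiseless ground truth on the test grid:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Root-mean-square error of each model's predictive mean vs. the true function\n",
"truth = f(X_test)\n",
"rmse_nuts = np.sqrt(np.mean((np.asarray(y_pred_1) - truth) ** 2))\n",
"rmse_svi = np.sqrt(np.mean((np.asarray(y_pred_2) - truth) ** 2))\n",
"print(f\"NUTS RMSE: {rmse_nuts:.3f}, SVI RMSE: {rmse_svi:.3f}\")"
]
},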
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y_sampled_1.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y_sampled_2.shape # Note shape difference between predict methods"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9Z8JdLXMngRn"
},
"source": [
"Plot the obtained results:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 382
},
"id": "7R0jWHFLtQ5b",
"outputId": "b43c503c-90bc-4026-cd53-95c4528872e8"
},
"outputs": [],
"source": [
"_, ax = plt.subplots(1, 1, figsize=(6, 2), dpi=200)\n",
"\n",
"ax.set_xlabel(\"$x$\")\n",
"ax.set_ylabel(\"$y$\")\n",
"ax.plot(X_test, y_pred_1, lw=1.5, zorder=2, c='r', label='NUTS/MCMC')\n",
"ax.fill_between(X_test, y_pred_1 - y_sampled_1.std(axis=(0,1)), y_pred_1 + y_sampled_1.std(axis=(0,1)),\n",
" color='r', alpha=0.3, linewidth=0)\n",
"\n",
"\n",
"ax.set_xlabel(\"$x$\")\n",
"ax.set_ylabel(\"$y$\")\n",
"ax.plot(X_test, y_pred_2, lw=1.5, zorder=2, c='b', label='SVI')\n",
"ax.fill_between(X_test, y_pred_2 - np.sqrt(y_sampled_2), y_pred_2 + np.sqrt(y_sampled_2),\n",
" color='b', alpha=0.3, linewidth=0)\n",
"\n",
"\n",
"\n",
"ax.set_ylim(-1.8, 2.2)\n",
"\n",
"ax.scatter(X, y, marker='x', c='k', zorder=2, label=\"Noisy observations\", alpha=0.7)\n",
"\n",
"ax.legend(loc='upper left', ncols=3)\n",
"\n",
"plt.show()"
]
}
],
"metadata": {
"colab": {
"authorship_tag": "ABX9TyODwU6tDoyKQzfYTxvWNvTp",
"include_colab_link": true,
"name": "simpleGP.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}