Skip to content

Commit

Permalink
Update example
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Feb 15, 2024
1 parent 97d02f8 commit 023752e
Showing 1 changed file with 53 additions and 64 deletions.
117 changes: 53 additions & 64 deletions examples/gpax_simpleGP.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -70,30 +70,13 @@
{
"cell_type": "code",
"metadata": {
"id": "VQ1rLUzqha2i",
"outputId": "462cff6e-ea94-48f5-a6b0-a8bd1f255bb4",
"colab": {
"base_uri": "https://localhost:8080/"
}
"id": "VQ1rLUzqha2i"
},
"source": [
"!pip install -q git+https://github.com/ziatdinovmax/gpax.git"
"!pip install -q gpax"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m312.7/312.7 kB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m371.0/371.0 kB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Building wheel for gpax (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n"
]
}
]
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
Expand Down Expand Up @@ -154,7 +137,7 @@
"height": 449
},
"id": "-I4RQ2xCi0VV",
"outputId": "5fb08e30-3dab-4147-a1a1-93e23817b330"
"outputId": "e80e0974-0e92-47fd-c1e0-27885f78b6c4"
},
"source": [
"np.random.seed(0)\n",
Expand Down Expand Up @@ -214,24 +197,24 @@
"base_uri": "https://localhost:8080/"
},
"id": "c7kXm_lui6Dy",
"outputId": "0b260517-4273-4ce7-b04e-1911da22a3a5"
"outputId": "c8daf638-fea1-420b-9c3b-099999352d6a"
},
"source": [
"# Get random number generator keys for training and prediction\n",
"rng_key, rng_key_predict = gpax.utils.get_keys()\n",
"key1, key2 = gpax.utils.get_keys()\n",
"\n",
"# Initialize model\n",
"gp_model = gpax.ExactGP(1, kernel='RBF')\n",
"# Run Hamiltonian Monte Carlo to obtain posterior samples for kernel parameters and model noise\n",
"gp_model.fit(rng_key, X, y, num_chains=1)"
"gp_model.fit(key1, X, y, num_chains=1)"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"sample: 100%|██████████| 4000/4000 [00:14<00:00, 277.25it/s, 7 steps of size 5.52e-01. acc. prob=0.90] \n"
"sample: 100%|██████████| 4000/4000 [00:10<00:00, 375.60it/s, 7 steps of size 5.52e-01. acc. prob=0.90] \n"
]
},
{
Expand Down Expand Up @@ -263,12 +246,12 @@
"\n",
"$$𝜇^{post}_*= \\frac{1}{L} ∑_{i=1}^L 𝜇_*^i,$$\n",
"\n",
"which corresponds to the ```y_pred``` in the code cell below, and\n",
"which corresponds to the ```posterior_mean``` in the code cell below, and\n",
"samples\n",
"\n",
"$$f_*^i∼MVNormal(𝜇^i_*, 𝛴^i_*)$$\n",
"\n",
"from multivariate normal distributions for all the pairs of predictive means and covariances (```y_sampled``` in the code cell below). Note that model noise is absorbed into the kernel computation function."
"from multivariate normal distributions for all the pairs of predictive means and covariances (```f_samples``` in the code cell below). Note that model noise is absorbed into the kernel computation function."
]
},
{
Expand All @@ -281,7 +264,7 @@
"X_test = np.linspace(-1, 1, 100)\n",
"# Get the GP prediction. Here n stands for the number of samples from each MVNormal distribution\n",
"# (the total number of MVNormal distributions is equal to the number of HMC samples)\n",
"y_pred, y_sampled = gp_model.predict(rng_key_predict, X_test, n=200)"
"posterior_mean, f_samples = gp_model.predict(key2, X_test, n=200)"
],
"execution_count": 5,
"outputs": []
Expand All @@ -303,17 +286,17 @@
"height": 449
},
"id": "lJIdx7fUnX-W",
"outputId": "29465923-a54e-44c3-ffe6-6d38d06dcbce"
"outputId": "f8076d8c-cbc1-49de-9b71-2c6d6665bbb5"
},
"source": [
"_, ax = plt.subplots(dpi=100)\n",
"ax.set_xlabel(\"$x$\")\n",
"ax.set_ylabel(\"$y$\")\n",
"ax.scatter(X, y, marker='x', c='k', zorder=1, label=\"Noisy observations\", alpha=0.7)\n",
"for y1 in y_sampled:\n",
"for y1 in f_samples:\n",
" ax.plot(X_test, y1.mean(0), lw=.1, zorder=0, c='r', alpha=.1)\n",
"l, = ax.plot(X_test, y_sampled[0].mean(0), lw=1, c='r', alpha=1, label=\"Sampled predictions\")\n",
"ax.plot(X_test, y_pred, lw=1.5, zorder=1, c='b', label='Posterior mean')\n",
"l, = ax.plot(X_test, f_samples[0].mean(0), lw=1, c='r', alpha=1, label=\"Sampled predictions\")\n",
"ax.plot(X_test, posterior_mean, lw=1.5, zorder=1, c='b', label='Posterior mean')\n",
"ax.legend(loc='upper left')\n",
"l.set_alpha(0)\n",
"ax.set_ylim(-1.8, 2.2);"
Expand Down Expand Up @@ -345,7 +328,7 @@
"cell_type": "code",
"metadata": {
"id": "7R0jWHFLtQ5b",
"outputId": "a2123685-171f-4d42-be56-88d63acecb70",
"outputId": "d4878597-20e5-41cf-bd66-c23aea38556e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 449
Expand All @@ -356,8 +339,10 @@
"ax.set_xlabel(\"$x$\")\n",
"ax.set_ylabel(\"$y$\")\n",
"ax.scatter(X, y, marker='x', c='k', zorder=2, label=\"Noisy observations\", alpha=0.7)\n",
"ax.plot(X_test, y_pred, lw=1.5, zorder=2, c='b', label='Posterior mean')\n",
"ax.fill_between(X_test, y_pred - y_sampled.std(axis=(0,1)), y_pred + y_sampled.std(axis=(0,1)),\n",
"ax.plot(X_test, posterior_mean, lw=1.5, zorder=2, c='b', label='Posterior mean')\n",
"ax.fill_between(X_test,\n",
" posterior_mean - f_samples.std(axis=(0,1)),\n",
" posterior_mean + f_samples.std(axis=(0,1)),\n",
" color='r', alpha=0.3, label=\"Model uncertainty\")\n",
"ax.legend(loc='upper left')\n",
"ax.set_ylim(-1.8, 2.2);"
Expand Down Expand Up @@ -426,7 +411,7 @@
],
"metadata": {
"id": "znLfcvK0HSqY",
"outputId": "becc4a51-9c6d-40ac-93c6-2972a66881a0",
"outputId": "e4197eaf-64cd-4ea2-93a4-887ed12e9521",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 449
Expand Down Expand Up @@ -470,26 +455,26 @@
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4506ef74-5e1e-4bb5-b239-8fcfbc045c3d",
"outputId": "5ccd70c9-a530-40dc-b403-a6cbba39fcd1",
"id": "Qwx5D237IVdC"
},
"source": [
"# Get random number generator keys for training and prediction\n",
"rng_key, rng_key_predict = gpax.utils.get_keys()\n",
"key1, key2 = gpax.utils.get_keys()\n",
"\n",
"# Initialize model\n",
"gp_model = gpax.ExactGP(1, kernel='RBF')\n",
"\n",
"# Run Hamiltonian Monte Carlo to obtain posterior samples for kernel parameters and model noise\n",
"gp_model.fit(rng_key, X, y, num_chains=1)"
"gp_model.fit(key1, X, y, num_chains=1)"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"sample: 100%|██████████| 4000/4000 [00:07<00:00, 547.81it/s, 15 steps of size 2.69e-01. acc. prob=0.78]\n"
"sample: 100%|██████████| 4000/4000 [00:05<00:00, 706.59it/s, 15 steps of size 2.69e-01. acc. prob=0.78]\n"
]
},
{
Expand Down Expand Up @@ -523,7 +508,7 @@
"source": [
"X_test = np.linspace(-1, 1, 100)\n",
"\n",
"y_pred, y_sampled = gp_model.predict(rng_key_predict, X_test, n=200)"
"posterior_mean, f_samples = gp_model.predict(key2, X_test, n=200)"
],
"execution_count": 10,
"outputs": []
Expand All @@ -544,18 +529,18 @@
"base_uri": "https://localhost:8080/",
"height": 449
},
"outputId": "d77186a5-f0f2-465b-a653-cebe8503fda6",
"outputId": "8ee0ee9f-16a8-4628-abe2-235e4ef75081",
"id": "rv-9uBCHIhIo"
},
"source": [
"_, ax = plt.subplots(dpi=100)\n",
"ax.set_xlabel(\"$x$\")\n",
"ax.set_ylabel(\"$y$\")\n",
"ax.scatter(X, y, marker='x', c='k', zorder=1, label=\"Noisy observations\", alpha=0.7)\n",
"for y1 in y_sampled:\n",
"for y1 in f_samples:\n",
" ax.plot(X_test, y1.mean(0), lw=.1, zorder=0, c='r', alpha=.1)\n",
"l, = ax.plot(X_test, y_sampled[0].mean(0), lw=1, c='r', alpha=1, label=\"Sampled predictions\")\n",
"ax.plot(X_test, y_pred, lw=1.5, zorder=1, c='b', label='Posterior mean')\n",
"l, = ax.plot(X_test, f_samples[0].mean(0), lw=1, c='r', alpha=1, label=\"Sampled predictions\")\n",
"ax.plot(X_test, posterior_mean, lw=1.5, zorder=1, c='b', label='Posterior mean')\n",
"ax.legend(loc='upper left')\n",
"l.set_alpha(0)\n",
"ax.set_ylim(-1.2, 1.2);"
Expand Down Expand Up @@ -595,7 +580,7 @@
{
"cell_type": "code",
"metadata": {
"outputId": "2790be52-fe38-4c51-c09b-5501738cf87f",
"outputId": "103d087a-9a93-4be5-c1bb-c31a5f8b5438",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 449
Expand All @@ -607,8 +592,10 @@
"ax.set_xlabel(\"$x$\")\n",
"ax.set_ylabel(\"$y$\")\n",
"ax.scatter(X, y, marker='x', c='k', zorder=2, label=\"Noisy observations\", alpha=0.7)\n",
"ax.plot(X_test, y_pred, lw=1.5, zorder=2, c='b', label='Posterior mean')\n",
"ax.fill_between(X_test, y_pred - y_sampled.std(axis=(0,1)), y_pred + y_sampled.std(axis=(0,1)),\n",
"ax.plot(X_test, posterior_mean, lw=1.5, zorder=2, c='b', label='Posterior mean')\n",
"ax.fill_between(X_test,\n",
" posterior_mean - f_samples.std(axis=(0,1)),\n",
" posterior_mean + f_samples.std(axis=(0,1)),\n",
" color='r', alpha=0.3, label=\"Model uncertainty (2$\\sigma$)\")\n",
"ax.plot(X_test, f(X_test), color='k', alpha=0.7, zorder=1, label='Ground truth')\n",
"ax.legend(loc='upper left')\n",
Expand Down Expand Up @@ -684,7 +671,7 @@
],
"metadata": {
"id": "UnBPC1RoXAZ8",
"outputId": "3e72357a-d26c-41ec-c2f6-e7fcc6b8aa10",
"outputId": "98e3fa51-fbfd-4fa3-9607-fbdbdb70f6b9",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 430
Expand Down Expand Up @@ -717,17 +704,17 @@
"cell_type": "code",
"source": [
"# Get random number generator keys for training and prediction\n",
"rng_key, rng_key_predict = gpax.utils.get_keys()\n",
"key1, key2 = gpax.utils.get_keys()\n",
"\n",
"# Initialize model\n",
"gp_model = gpax.ExactGP(1, kernel='RBF', lengthscale_prior_dist=lengthscale_prior_dist)\n",
"\n",
"# Run Hamiltonian Monte Carlo to obtain posterior samples for kernel parameters and model noise\n",
"gp_model.fit(rng_key, X, y, num_chains=1)"
"gp_model.fit(key1, X, y, num_chains=1)"
],
"metadata": {
"id": "nV9SLaAEnv6v",
"outputId": "f02ef5d1-4bac-4fdd-bf07-9704293e442f",
"outputId": "ec8dae3c-7ae9-458c-898d-265ab62190c1",
"colab": {
"base_uri": "https://localhost:8080/"
}
Expand All @@ -738,7 +725,7 @@
"output_type": "stream",
"name": "stderr",
"text": [
"sample: 100%|██████████| 4000/4000 [00:07<00:00, 550.79it/s, 7 steps of size 4.26e-01. acc. prob=0.94] \n"
"sample: 100%|██████████| 4000/4000 [00:06<00:00, 647.22it/s, 7 steps of size 4.26e-01. acc. prob=0.94]\n"
]
},
{
Expand Down Expand Up @@ -767,7 +754,7 @@
{
"cell_type": "code",
"source": [
"y_pred, y_sampled = gp_model.predict(rng_key_predict, X_test, n=200)"
"posterior_mean, f_samples = gp_model.predict(key2, X_test, n=200)"
],
"metadata": {
"id": "W9woiXs5oFrR"
Expand All @@ -791,10 +778,10 @@
"ax.set_xlabel(\"$x$\")\n",
"ax.set_ylabel(\"$y$\")\n",
"ax.scatter(X, y, marker='x', c='k', zorder=1, label=\"Noisy observations\", alpha=0.7)\n",
"for y1 in y_sampled:\n",
"for y1 in f_samples:\n",
" ax.plot(X_test, y1.mean(0), lw=.1, zorder=0, c='r', alpha=.1)\n",
"l, = ax.plot(X_test, y_sampled[0].mean(0), lw=1, c='r', alpha=1, label=\"Sampled predictions\")\n",
"ax.plot(X_test, y_pred, lw=1.5, zorder=1, c='b', label='Posterior mean')\n",
"l, = ax.plot(X_test, f_samples[0].mean(0), lw=1, c='r', alpha=1, label=\"Sampled predictions\")\n",
"ax.plot(X_test, posterior_mean, lw=1.5, zorder=1, c='b', label='Posterior mean')\n",
"ax.legend(loc='upper left')\n",
"l.set_alpha(0)\n",
"ax.set_ylim(-1.2, 1.2);"
Expand All @@ -805,9 +792,9 @@
"height": 449
},
"id": "y9gHGmwD3lv-",
"outputId": "b02ecbbf-4091-42f3-f398-037f915cff6c"
"outputId": "24b2604a-a4b6-4715-d1b4-0d3476b8ef62"
},
"execution_count": 17,
"execution_count": 18,
"outputs": [
{
"output_type": "display_data",
Expand Down Expand Up @@ -846,22 +833,24 @@
"ax.set_xlabel(\"$x$\")\n",
"ax.set_ylabel(\"$y$\")\n",
"ax.scatter(X, y, marker='x', c='k', zorder=2, label=\"Noisy observations\", alpha=0.7)\n",
"ax.plot(X_test, y_pred, lw=1.5, zorder=2, c='b', label='Posterior mean')\n",
"ax.fill_between(X_test, y_pred - y_sampled.std(axis=(0,1)), y_pred + y_sampled.std(axis=(0,1)),\n",
"ax.plot(X_test, posterior_mean, lw=1.5, zorder=2, c='b', label='Posterior mean')\n",
"ax.fill_between(X_test,\n",
" posterior_mean - f_samples.std(axis=(0,1)),\n",
" posterior_mean + f_samples.std(axis=(0,1)),\n",
" color='r', alpha=0.3, label=\"Model uncertainty (2$\\sigma$)\")\n",
"ax.plot(X_test, f(X_test), color='k', alpha=0.7, zorder=1, label='Ground truth')\n",
"ax.legend(loc='upper left')\n",
"ax.set_ylim(-1.2, 1.2);"
],
"metadata": {
"id": "fm0e70PIoJvE",
"outputId": "85b7ca44-26f0-4e44-d4b9-47f82ca8b5d1",
"outputId": "8584ac26-b827-454d-aa78-75800863bd69",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 449
}
},
"execution_count": 18,
"execution_count": 19,
"outputs": [
{
"output_type": "display_data",
Expand Down

0 comments on commit 023752e

Please sign in to comment.