Skip to content

Commit

Permalink
Update example
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Feb 16, 2024
1 parent 05dda75 commit df2b2da
Showing 1 changed file with 32 additions and 32 deletions.
64 changes: 32 additions & 32 deletions examples/GP_sGP.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"id": "136rQl-Z67Xf"
},
Expand Down Expand Up @@ -70,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {
"id": "VQ1rLUzqha2i"
},
Expand All @@ -90,7 +90,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {
"id": "XCoyWlKt67Xk"
},
Expand All @@ -110,7 +110,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {
"id": "KtGDc11Ehh7r"
},
Expand All @@ -133,7 +133,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {
"id": "V5isV5Ho67Xl"
},
Expand All @@ -144,7 +144,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {
"id": "gUyKDZjM67Xl"
},
Expand Down Expand Up @@ -179,7 +179,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {
"id": "LAvbGDom67Xl"
},
Expand All @@ -205,7 +205,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand Down Expand Up @@ -248,7 +248,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
Expand Down Expand Up @@ -279,7 +279,7 @@
],
"source": [
"# Get random number generator keys (see JAX documentation for why it is neccessary)\n",
"rng_key, rng_keposterior_meanict = gpax.utils.get_keys()\n",
"rng_key, rng_key_predict = gpax.utils.get_keys()\n",
"\n",
"# Initialize model\n",
"gp_model = gpax.ExactGP(1, kernel='Matern')\n",
Expand All @@ -288,7 +288,7 @@
"gp_model.fit(rng_key, X, y, num_chains=1)\n",
"\n",
"# Get GP prediction\n",
"posterior_mean, f_samples = gp_model.predict(rng_keposterior_meanict, X_test, n=200)"
"posterior_mean, f_samples = gp_model.predict(rng_key_predict, X_test, n=200)"
]
},
{
Expand All @@ -302,7 +302,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {
"id": "lnxdYcLL67Xm",
"outputId": "584a3a74-32e5-4f13-d3e5-0d2bcdf9cfe7",
Expand Down Expand Up @@ -348,7 +348,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {
"id": "OjxPG_gY3U2c"
},
Expand All @@ -371,7 +371,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"metadata": {
"id": "zdrtXqGPKzUe"
},
Expand Down Expand Up @@ -401,7 +401,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"metadata": {
"id": "lqXxUSGeqGhm"
},
Expand Down Expand Up @@ -436,7 +436,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand Down Expand Up @@ -524,7 +524,7 @@
" gp_model.fit(rng_key, X, y, num_warmup=NUM_WARMUP, num_samples=NUM_SAMPLES)\n",
"\n",
" # Get GP prediction\n",
" posterior_mean, f_samples = gp_model.predict(rng_keposterior_meanict, X_test, n=200)\n",
" posterior_mean, f_samples = gp_model.predict(rng_key_predict, X_test, n=200)\n",
"\n",
" # Plot results\n",
" _, ax = plt.subplots(dpi=100)\n",
Expand Down Expand Up @@ -572,7 +572,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"metadata": {
"id": "qRocZMUIVsp4"
},
Expand All @@ -584,7 +584,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
Expand Down Expand Up @@ -745,7 +745,7 @@
}
],
"source": [
"rng_key, rng_keposterior_meanict = gpax.utils.get_keys(1)\n",
"rng_key, rng_key_predict = gpax.utils.get_keys(1)\n",
"\n",
"for i in range(6):\n",
" print(\"\\nExploration step {}\".format(i+1))\n",
Expand All @@ -754,7 +754,7 @@
" gp_model.fit(rng_key, X, y, print_summary=1, num_warmup=NUM_WARMUP, num_samples=NUM_SAMPLES)\n",
" # Compute acquisition function (here it is simply the uncertinty in prediciton)\n",
" # and get the coordinate of the next point to measure\n",
" obj = gpax.acquisition.UE(rng_keposterior_meanict, gp_model, X_test)\n",
" obj = gpax.acquisition.UE(rng_key_predict, gp_model, X_test)\n",
" next_point_idx = obj.argmax()\n",
" # Append the 'suggested' point\n",
" X = np.append(X, X_test[next_point_idx])\n",
Expand All @@ -773,7 +773,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
Expand Down Expand Up @@ -806,17 +806,17 @@
}
],
"source": [
"rng_key, rng_keposterior_meanict = gpax.utils.get_keys(1)\n",
"rng_key, rng_key_predict = gpax.utils.get_keys(1)\n",
"# Update GP posterior\n",
"gp_model = gpax.ExactGP(1, kernel='Matern', mean_fn=piecewise1, mean_fn_prior=piecewise1_priors)\n",
"gp_model.fit(rng_key, X, y)\n",
"# Get GP prediction\n",
"posterior_mean, f_samples = gp_model.predict(rng_keposterior_meanict, X_test, n=200)"
"posterior_mean, f_samples = gp_model.predict(rng_key_predict, X_test, n=200)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand Down Expand Up @@ -871,7 +871,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
Expand Down Expand Up @@ -1106,15 +1106,15 @@
"source": [
"X, y = Xo, yo # start from the same set of observations\n",
"\n",
"rng_key, rng_keposterior_meanict = gpax.utils.get_keys(1)\n",
"rng_key, rng_key_predict = gpax.utils.get_keys(1)\n",
"\n",
"for i in range(9):\n",
" print(\"\\nExploration step {}\".format(i+1))\n",
" # Obtain/update GP posterior\n",
" gp_model = gpax.ExactGP(1, kernel='Matern', mean_fn=piecewise2, mean_fn_prior=piecewise2_priors)\n",
" gp_model.fit(rng_key, X, y, print_summary=1, num_warmup=NUM_WARMUP, num_samples=NUM_SAMPLES)\n",
" # Compute acquisition function and get coordinate of the next point\n",
" obj = gpax.acquisition.UE(rng_keposterior_meanict, gp_model, X_test)\n",
" obj = gpax.acquisition.UE(rng_key_predict, gp_model, X_test)\n",
" next_point_idx = obj.argmax()\n",
" # Append the 'suggested' point\n",
" X = np.append(X, X_test[next_point_idx])\n",
Expand All @@ -1133,7 +1133,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
Expand Down Expand Up @@ -1166,17 +1166,17 @@
}
],
"source": [
"rng_key, rng_keposterior_meanict = gpax.utils.get_keys(1)\n",
"rng_key, rng_key_predict = gpax.utils.get_keys(1)\n",
"# Update GP posterior\n",
"gp_model = gpax.ExactGP(1, kernel='Matern', mean_fn=piecewise2, mean_fn_prior=piecewise2_priors)\n",
"gp_model.fit(rng_key, X, y, num_warmup=NUM_WARMUP, num_samples=NUM_SAMPLES)\n",
"# Get GP prediction\n",
"posterior_mean, f_samples = gp_model.predict(rng_keposterior_meanict, X_test, n=200)"
"posterior_mean, f_samples = gp_model.predict(rng_key_predict, X_test, n=200)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand Down

0 comments on commit df2b2da

Please sign in to comment.