-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from onnela-lab/rev
Revisions.
- Loading branch information
Showing
23 changed files
with
1,650 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,14 @@ | ||
FROM python:3.10 | ||
# Install R and dependencies. | ||
RUN apt-get update && apt-get install -y \ | ||
r-base \ | ||
r-cran-devtools \ | ||
&& rm -rf /var/lib/apt/lists/* | ||
WORKDIR /workdir | ||
COPY setup.R . | ||
RUN Rscript setup.R | ||
|
||
# Install Python dependencies and compile cmdstan. | ||
COPY requirements.txt . | ||
RUN pip install --no-cache-dir -r requirements.txt | ||
RUN python -m cmdstanpy.install_cmdstan --verbose --version 2.33.0 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
getting_started |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
Install the packages. We explicitly specify the repos here so we don't get asked for the mirror when | ||
rendering the Rmarkdown. | ||
|
||
```{r} | ||
install.packages( | ||
"cmdstanr", | ||
repos = c("https://mc-stan.org/r-packages/", "http://cran.us.r-project.org") | ||
) | ||
install.packages("gptoolsStan", repos=c("http://cran.us.r-project.org")) | ||
``` | ||
|
||
Compile and run the model. | ||
|
||
```{r} | ||
library(cmdstanr) | ||
library(gptoolsStan) | ||
model <- cmdstan_model( | ||
stan_file="getting_started.stan", | ||
include_paths=gptools_include_path(), | ||
) | ||
fit <- model$sample( | ||
data=list(n=100, sigma=1, length_scale=0.1, period=1), | ||
chains=1, | ||
iter_warmup=500, | ||
iter_sampling=50 | ||
) | ||
f <- fit$draws("f") | ||
dim(f) | ||
``` | ||
|
||
Expected output: `[1] 50 1 100` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d2b9dac2", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
">>> import cmdstanpy\n", | ||
">>> from gptools.stan import get_include\n", | ||
">>>\n", | ||
">>> model = cmdstanpy.CmdStanModel(\n", | ||
"... stan_file=\"getting_started.stan\",\n", | ||
"... stanc_options={\"include-paths\": get_include()},\n", | ||
"... )\n", | ||
">>> fit = model.sample(\n", | ||
"... data = {\"n\": 100, \"sigma\": 1, \"length_scale\": 0.1, \"period\": 1},\n", | ||
"... chains=1,\n", | ||
"... iter_warmup=500,\n", | ||
"... iter_sampling=50,\n", | ||
"... )\n", | ||
">>> fit.f.shape" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "34ffcd17", | ||
"metadata": {}, | ||
"source": [ | ||
"Expected output: `(50, 100)`" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
--- | ||
jupytext: | ||
text_representation: | ||
extension: .md | ||
format_name: myst | ||
format_version: 0.13 | ||
jupytext_version: 1.15.1 | ||
kernelspec: | ||
display_name: Python 3 (ipykernel) | ||
language: python | ||
name: python3 | ||
--- | ||
|
||
```{code-cell} ipython3 | ||
>>> import cmdstanpy | ||
>>> from gptools.stan import get_include | ||
>>> | ||
>>> model = cmdstanpy.CmdStanModel( | ||
... stan_file="getting_started.stan", | ||
... stanc_options={"include-paths": get_include()}, | ||
... ) | ||
>>> fit = model.sample( | ||
... data = {"n": 100, "sigma": 1, "length_scale": 0.1, "period": 1}, | ||
... chains=1, | ||
... iter_warmup=500, | ||
... iter_sampling=50, | ||
... ) | ||
>>> fit.f.shape | ||
``` | ||
|
||
Expected output: `(50, 100)` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
functions { | ||
#include gptools/util.stan | ||
#include gptools/fft.stan | ||
} | ||
|
||
data { | ||
int n; | ||
real<lower=0> sigma, length_scale, period; | ||
} | ||
|
||
transformed data { | ||
vector [n %/% 2 + 1] cov_rfft = | ||
gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, period) + 1e-9; | ||
} | ||
|
||
parameters { | ||
vector [n] f; | ||
} | ||
|
||
model { | ||
f ~ gp_rfft(zeros_vector(n), cov_rfft); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "031e9913", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from gptools.util.kernels import ExpQuadKernel, MaternKernel\n", | ||
"from gptools.util.fft.fft1 import transform_irfft, evaluate_rfft_scale\n", | ||
"import matplotlib as mpl\n", | ||
"from matplotlib import pyplot as plt\n", | ||
"import numpy as np\n", | ||
"from pathlib import Path\n", | ||
"\n", | ||
"mpl.style.use(\"../jss.mplstyle\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "250db79b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"np.random.seed(9) # Seed picked for good legend positioning. Works for any though.\n", | ||
"fig, axes = plt.subplots(2, 2)\n", | ||
"length_scale = 0.2\n", | ||
"kernels = {\n", | ||
" \"squared exp.\": lambda period: ExpQuadKernel(1, length_scale, period),\n", | ||
" \"Matern ³⁄₂\": lambda period: MaternKernel(1.5, 1, length_scale, period),\n", | ||
"}\n", | ||
"\n", | ||
"x = np.linspace(0, 1, 101, endpoint=False)\n", | ||
"z = np.random.normal(0, 1, x.size)\n", | ||
"\n", | ||
"for ax, (key, kernel) in zip(axes[1], kernels.items()):\n", | ||
" value = kernel(None).evaluate(0, x[:, None])\n", | ||
" line, = axes[0, 0].plot(x, value, ls=\"--\")\n", | ||
" rfft = kernel(1).evaluate_rfft([x.size])\n", | ||
" value = np.fft.irfft(rfft, x.size)\n", | ||
" axes[0, 1].plot(rfft, label=key)\n", | ||
" axes[0, 0].plot(x, value, color=line.get_color())\n", | ||
"\n", | ||
" for maxf, ls in [(x.size // 2 + 1, \"-\"), (5, \"--\"), (3, \":\")]:\n", | ||
" rfft_scale = evaluate_rfft_scale(cov_rfft=rfft, size=x.size)\n", | ||
" rfft_scale[maxf:] = 0\n", | ||
" f = transform_irfft(z, np.zeros_like(z), rfft_scale=rfft_scale)\n", | ||
" ax.plot(x, f, ls=ls, color=line.get_color(), label=fr\"$\\xi_\\max={maxf}$\")\n", | ||
"\n", | ||
" ax.set_xlabel(\"position $x$\")\n", | ||
" ax.set_ylabel(f\"{key} GP $f$\")\n", | ||
"\n", | ||
"ax = axes[0, 0]\n", | ||
"ax.set_ylabel(\"kernel $k(0,x)$\")\n", | ||
"ax.set_xlabel(\"position $x$\")\n", | ||
"ax.legend([\n", | ||
" mpl.lines.Line2D([], [], ls=\"--\", color=\"gray\"),\n", | ||
" mpl.lines.Line2D([], [], ls=\"-\", color=\"gray\"),\n", | ||
"], [\"standard\", \"periodic\"], fontsize=\"small\")\n", | ||
"ax.text(0.05, 0.05, \"(a)\", transform=ax.transAxes)\n", | ||
"ax.yaxis.set_ticks([0, 0.5, 1])\n", | ||
"\n", | ||
"ax = axes[0, 1]\n", | ||
"ax.set_yscale(\"log\")\n", | ||
"ax.set_ylim(1e-5, x.size)\n", | ||
"ax.set_xlabel(r\"frequency $\\xi$\")\n", | ||
"ax.set_ylabel(r\"Fourier kernel $\\tilde k=\\phi\\left(k\\right)$\")\n", | ||
"ax.legend(fontsize=\"small\", loc=\"center right\")\n", | ||
"ax.text(0.95, 0.95, \"(b)\", transform=ax.transAxes, ha=\"right\", va=\"top\")\n", | ||
"\n", | ||
"ax = axes[1, 0]\n", | ||
"ax.legend(fontsize=\"small\", loc=\"lower center\")\n", | ||
"ax.text(0.95, 0.95, \"(c)\", transform=ax.transAxes, ha=\"right\", va=\"top\")\n", | ||
"\n", | ||
"ax = axes[1, 1]\n", | ||
"ax.legend(fontsize=\"small\", loc=\"lower center\")\n", | ||
"ax.sharey(axes[1, 0])\n", | ||
"ax.text(0.95, 0.95, \"(d)\", transform=ax.transAxes, ha=\"right\", va=\"top\")\n", | ||
"\n", | ||
"for ax in [axes[0, 0], *axes[1]]:\n", | ||
" ax.xaxis.set_ticks([0, 0.5, 1])\n", | ||
"\n", | ||
"fig.tight_layout()\n", | ||
"fig.savefig(\"kernels.pdf\", bbox_inches=\"tight\")\n", | ||
"fig.savefig(\"kernels.png\", bbox_inches=\"tight\")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "eb2c936b", | ||
"metadata": {}, | ||
"source": [ | ||
"# Linear regression example from Section 2 of the manuscript" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "42ecaa12", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"\n", | ||
"np.random.seed(0)\n", | ||
"n = 100\n", | ||
"p = 3\n", | ||
"X = np.random.normal(0, 1, (n, p))\n", | ||
"theta = np.random.normal(0, 1, p)\n", | ||
"sigma = np.random.gamma(2, 2)\n", | ||
"y = np.random.normal(X @ theta, sigma)\n", | ||
"\n", | ||
"print(f\"coefficients: {theta}\")\n", | ||
"print(f\"observation noise scale: {sigma}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "07069f61", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import cmdstanpy\n", | ||
"\n", | ||
"model = cmdstanpy.CmdStanModel(stan_file=\"linear.stan\")\n", | ||
"fit = model.sample(data={\"n\": n, \"p\": p, \"X\": X, \"y\": y}, seed=0)\n", | ||
"\n", | ||
"print(fit.summary())" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.