-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
66 changed files
with
10,329 additions
and
1 deletion.
There are no files selected for viewing
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
209 changes: 209 additions & 0 deletions
209
docs/docs/doctrees/nbsphinx/tutorials/README_example.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Quick start: Forecasting with synthetic data\n", | ||
"\n", | ||
"In this notebook, we train Treeffuser on synthethic data and then visualize both the original and model-generated samples to explore how well Treeffuser captures the underlying distribution of the data." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Getting started" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We first install `treeffuser` and import the relevant libraries." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip install treeffuser\n", | ||
"\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import numpy as np\n", | ||
"\n", | ||
"from treeffuser import Treeffuser" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We simulate a non-linear, bimodal response of $y$ given $x$, where the two modes follow two different response functions: one is a sine function and the other is a cosine function over $x$." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"seed = 0 # fixing the random seed for reproducibility\n", | ||
"n = 5000 # number of data points\n", | ||
"\n", | ||
"rng = np.random.default_rng(seed=seed)\n", | ||
"x = rng.uniform(0, 2 * np.pi, size=n) # x values in the range [0, 2π)\n", | ||
"z = rng.integers(0, 2, size=n) # response function assignments\n", | ||
"\n", | ||
"y = z * np.sin(x - np.pi / 2) + (1 - z) * np.cos(x)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We also introduce heteroscedastic, fat-tailed noise from a Laplace distribution, meaning the variability of $y$ increases with $x$ and may result in large outliers." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"y += rng.laplace(scale=x / 30, size=n)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Fitting Treffuser and producing samples" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Fitting Treeffuser and generating samples is very simple, as Treeffuser adheres to the `sklearn.base.BaseEstimator` class. Fitting amounts to initializing the model and calling the `fit` method, just like any `scikit-learn` estimator. Samples are then generated using the `sample` method." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model = Treeffuser(sde_initialize_from_data=True, seed=seed)\n", | ||
"model.fit(x, y)\n", | ||
"\n", | ||
"y_samples = model.sample(x, n_samples=1, seed=seed, verbose=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Plotting the samples" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We create a scatter plot to visualize both the original data and the samples produced by Treeffuser. The samples closely reflect the underlying response distributions that generated the data." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"plt.scatter(x, y, s=1, label=\"observed data\")\n", | ||
"plt.scatter(x, y_samples[0, :], s=1, alpha=0.7, label=\"Treeffuser samples\")\n", | ||
"\n", | ||
"plt.xlabel(\"$x$\")\n", | ||
"plt.ylabel(\"$y$\")\n", | ||
"\n", | ||
"legend = plt.legend(loc=\"upper center\", scatterpoints=1, bbox_to_anchor=(0.5, -0.125), ncol=2)\n", | ||
"for legend_handle in legend.legend_handles:\n", | ||
" legend_handle.set_sizes([32]) # change marker size for legend\n", | ||
"\n", | ||
"plt.tight_layout()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The samples generated by Treeffuser can be used to compute any downstream estimates of interest." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"x = np.array(np.pi).reshape((1, 1))\n", | ||
"y_samples = model.sample(x, n_samples=100, verbose=True) # y_samples.shape[0] is 100\n", | ||
"\n", | ||
"# Estimate downstream quantities of interest\n", | ||
"y_mean = y_samples.mean(axis=0) # conditional mean for each x\n", | ||
"y_std = y_samples.std(axis=0) # conditional std for each x\n", | ||
"\n", | ||
"print(f\"Mean of the samples: {y_mean}\")\n", | ||
"print(f\"Standard deviation of the samples: {y_std} \")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"For convenience, we also provide a class Samples that can estimate standard quantities." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from treeffuser.samples import Samples\n", | ||
"\n", | ||
"y_samples = Samples(y_samples)\n", | ||
"y_mean = y_samples.sample_mean() # same as before\n", | ||
"y_std = y_samples.sample_std() # same as before\n", | ||
"y_quantiles = y_samples.sample_quantile(q=[0.05, 0.95]) # conditional quantiles for each x\n", | ||
"\n", | ||
"print(f\"Mean of the samples: {y_mean}\")\n", | ||
"print(f\"Standard deviation of the samples: {y_std} \")\n", | ||
"print(f\"5th and 95th quantiles of the samples: {y_quantiles.reshape(-1)}\")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Sphinx build info version 1 | ||
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. | ||
config: e4ac402ca58269d1b18eda45a326b05c | ||
tags: 645f666f9bcd5a90fca523b33c5a78b7 |
Empty file.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.