Comparison notebook between deep learning and regression-based forecasts #8

Merged
merged 14 commits into from
Jun 30, 2023
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
## Recipe for building and executing your workflow with `s2spy` suite.
This repo provides several tutorial notebooks showing how [`s2spy`](https://github.com/AI4S2S/s2spy) and [`lilio`](https://github.com/AI4S2S/lilio) can facilitate your data-driven (sub)seasonal (S2S) forecasts workflow.

<img src="./assets/concept_test_case.png" alt="usecase" width="600"/>

## Basic workflow
Here is an example of a basic data-driven S2S forecasts workflow for regression modelling with [`s2spy`](https://github.com/AI4S2S/s2spy) and [`lilio`](https://github.com/AI4S2S/lilio).

Expand All @@ -25,11 +27,9 @@ This workflow is illustrated below:

Similarly, you can adapt this recipe to your deep learning workflow with a few changes. You can find several examples in the next section.

## install dependencies

## Tutorial notebooks

The tutorial notebooks include a case study in which we attempt to predict surface temperature over US using the SST over Pacific. We use processed ERA5 fields to perform data-driven forecasts. More details about the data can be found in this [README.md](./data/README.md).
The tutorial notebooks include a case study in which we attempt to predict surface temperature over US using the SST over Pacific. We use processed ERA5 fields to perform data-driven forecasts. More details about the data can be found in this [README.md](./data/README.md).

Before playing with these notebooks, please make sure that you have all the dependent packages installed. You can install the dependencies by going to this repo and running the following command:
```sh
Expand All @@ -43,3 +43,4 @@ Predict surface temperature over US with SST over Pacific with [`s2spy`](https:/
- [Data-driven S2S forecasts using LSTM network](./workflow/pred_temperature_LSTM.ipynb)
- [Data-driven S2S forecasts using autoencoder network](./workflow/pred_temperature_autoencoder.ipynb)
- [Data-driven S2S forecasts using transformer with multi-head attention](./workflow/pred_temperature_transformer.ipynb)
- [Data-driven S2S forecasts using LSTM network with linear regression as baseline](./workflow/comp_pred_ridge_and_LSTM.ipynb)
Binary file added assets/concept_test_case.png
2,196 changes: 2,196 additions & 0 deletions workflow/comp_pred_ridge_and_LSTM.ipynb

Large diffs are not rendered by default.

20 changes: 15 additions & 5 deletions workflow/pred_temperature_LSTM.ipynb
Expand Up @@ -9,11 +9,19 @@
"This notebook serves as an example of a basic workflow of data driven forecasting using deep learning with `s2spy` & `lilio` packages. <br>\n",
"We will predict temperature in US at seasonal time scales using ERA5 dataset with LSTM network. <br>\n",
"\n",
"<img src=\"../assets/concept_test_case.png\" alt=\"usecase\" width=\"500\"/>"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"This recipe includes the following steps:\n",
"- Define a calendar (`lilio`)\n",
"- Download/load input data (`era5cli`) (TBA)\n",
"- Download/load input data (`era5cli`) (test data, accessible via `era5cli`)\n",
"- Map the calendar to the data (`lilio`)\n",
"- Train-validate-test split (60%/20%/20%) (`torch`)\n",
"- Train-validate-test split (60%/20%/20%)\n",
"- Preprocessing based on the training set (`s2spy`)\n",
"- Resample data to the calendar (`lilio`)\n",
"- Create LSTM model (`torch`)\n",
Expand All @@ -29,7 +37,7 @@
"source": [
"The workflow is illustrated below:\n",
"\n",
"![Transformer](../assets/dl.PNG)"
"<img src=\"../assets/dl.PNG\" alt=\"Transformer\" width=\"900\"/>"
]
},
{
Expand Down Expand Up @@ -438,9 +446,11 @@
"#### Build LSTM model\n",
"Build a LSTM model with `nn.LSTM` module.\n",
"\n",
"The architecture of the autoencoder used here is shown in the figure below. (source of image: https://colah.github.io/posts/2015-08-Understanding-LSTMs/)\n",
"The architecture of the autoencoder used here is shown in the figure below.\n",
"\n",
"<img src=\"../assets/lstm.png\" alt=\"LSTM\" width=\"500\"/>\n",
"\n",
"![lstm](../assets/lstm.png)"
"(source of image: https://colah.github.io/posts/2015-08-Understanding-LSTMs/)"
]
},
{
Expand Down
16 changes: 12 additions & 4 deletions workflow/pred_temperature_autoencoder.ipynb
Expand Up @@ -9,11 +9,19 @@
"This notebook serves as an example of a basic workflow of data driven forecasting using deep learning with `s2spy` & `lilio` packages. <br>\n",
"We will predict temperature in US at seasonal time scales using ERA5 dataset with multi-head attention autoencoder. <br>\n",
"\n",
"<img src=\"../assets/concept_test_case.png\" alt=\"usecase\" width=\"500\"/>"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"This recipe includes the following steps:\n",
"- Define a calendar (`lilio`)\n",
"- Download/load input data (`era5cli`) (TBA)\n",
"- Download/load input data (test data, accessible via `era5cli`)\n",
"- Map the calendar to the data (`lilio`)\n",
"- Train-validate-test split (60%/20%/20%) (`torch`)\n",
"- Train-validate-test split (60%/20%/20%)\n",
"- Preprocessing based on the training set (`s2spy`)\n",
"- Resample data to the calendar (`lilio`)\n",
"- Create autoencoder model (`torch`)\n",
Expand All @@ -29,7 +37,7 @@
"source": [
"The workflow is illustrated below:\n",
"\n",
"![Transformer](../assets/dl.PNG)"
"<img src=\"../assets/dl.PNG\" alt=\"Transformer\" width=\"900\"/>"
]
},
{
Expand Down Expand Up @@ -576,7 +584,7 @@
"\n",
"The architecture of the autoencoder used here is shown in the figure below. This structure is very similar to the famous language model called BERT. For more details about the full transformer network structure, check the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805).\n",
"\n",
"![architecture](../assets/bert.png)"
"<img src=\"../assets/bert.png\" alt=\"BERT\" width=\"500\"/>"
]
},
{
Expand Down
18 changes: 13 additions & 5 deletions workflow/pred_temperature_ridge.ipynb
Expand Up @@ -9,9 +9,17 @@
"This notebook serves as an example of a basic workflow of data driven forecasting using machine learning with `s2spy` & `lilio` packages. <br>\n",
"We will predict temperature in US at seasonal time scales using ERA5 dataset with linear regression (Ridge). <br>\n",
"\n",
"<img src=\"../assets/concept_test_case.png\" alt=\"usecase\" width=\"500\"/>"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"This recipe includes the following steps:\n",
"- Define a calendar (`lilio`)\n",
"- Download/load input data (`era5cli`) (TBA)\n",
"- Download/load input data (test data, accessible via `era5cli`)\n",
"- Map the calendar to the data (`lilio`)\n",
"- Train-test split (70%/30%)\n",
"- Preprocessing based on the training set (`s2spy`)\n",
Expand All @@ -26,7 +34,7 @@
"source": [
"The workflow is illustrated below:\n",
"\n",
"![Ridge](../assets/regression.PNG)"
"<img src=\"../assets/regression.PNG\" alt=\"Ridge\" width=\"900\"/>"
]
},
{
Expand Down Expand Up @@ -488,16 +496,16 @@
" clusters_test = rgdr.transform(x_test)\n",
" # train model\n",
" ridge = Ridge(alpha=1.0)\n",
" model = ridge.fit(clusters_train.isel(i_interval=0), y_train.isel(i_interval=1))\n",
" model = ridge.fit(clusters_train.isel(i_interval=0), y_train.sel(i_interval=1))\n",
" # save model\n",
" models.append(model)\n",
" # predict and save results\n",
" prediction = model.predict(clusters_test.isel(i_interval=0))\n",
" predictions.append(prediction)\n",
" # calculate and save rmse\n",
" rmse_train.append(mean_squared_error(y_train.isel(i_interval=1),\n",
" rmse_train.append(mean_squared_error(y_train.sel(i_interval=1),\n",
" model.predict(clusters_train.isel(i_interval=0))))\n",
" rmse_test.append(mean_squared_error(y_test.isel(i_interval=1),\n",
" rmse_test.append(mean_squared_error(y_test.sel(i_interval=1),\n",
" prediction))"
]
},
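The hunk above replaces `y_train.isel(i_interval=1)` with `y_train.sel(i_interval=1)` (and likewise for `y_test`). In xarray, `isel` selects by integer position while `sel` selects by coordinate label, so the two agree only when the label `1` happens to sit at position 1. The sketch below illustrates the difference on a small hypothetical array; the dimension names and the interval labels `[-2, -1, 1]` are assumptions made for illustration, not values taken from the notebook.

```python
import numpy as np
import xarray as xr

# Hypothetical resampled target: two anchor years, three intervals.
# The labels [-2, -1, 1] are an assumption for illustration only.
da = xr.DataArray(
    np.array([[10.0, 20.0, 30.0],
              [11.0, 21.0, 31.0]]),
    dims=("anchor_year", "i_interval"),
    coords={"anchor_year": [2000, 2001], "i_interval": [-2, -1, 1]},
)

# Positional selection: the entry at index 1 along i_interval (label -1).
by_position = da.isel(i_interval=1)

# Label-based selection: the entry whose i_interval label equals 1.
by_label = da.sel(i_interval=1)

print(by_position.values, by_position["i_interval"].item())
print(by_label.values, by_label["i_interval"].item())
```

With these assumed labels the two calls return different columns, which is why selecting the target interval by label is the more robust choice when the number of precursor intervals may change.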
Expand Down
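A note on the ridge hunk in this file: the lists are named `rmse_train` and `rmse_test`, but the values appended come from `mean_squared_error(...)`. Assuming that name refers to scikit-learn's function (the import is not visible in this hunk), the default return value is the mean squared error, not its square root. The sketch below uses plain-Python stand-ins, `mse` and `rmse`, which are hypothetical helpers and not part of the notebook, to show how the two quantities differ.

```python
import math


def mse(y_true, y_pred):
    """Mean squared error of two equally long sequences (illustrative helper)."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def rmse(y_true, y_pred):
    """Root mean squared error: the square root of the MSE."""
    return math.sqrt(mse(y_true, y_pred))


# Hypothetical observations and predictions, for illustration only.
y_true = [1.0, 2.0, 3.0]
y_pred = [1.0, 2.0, 6.0]

mse_value = mse(y_true, y_pred)    # (0 + 0 + 9) / 3
rmse_value = rmse(y_true, y_pred)  # square root of the value above

print(mse_value, rmse_value)
```

If RMSE is the intended metric, taking the square root of each appended value (or renaming the lists to `mse_train` and `mse_test`) would make the variable names match their contents.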
16 changes: 12 additions & 4 deletions workflow/pred_temperature_transformer.ipynb
Expand Up @@ -9,11 +9,19 @@
"This notebook serves as an example of a basic workflow of data driven forecasting using deep learning with `s2spy` & `lilio` packages. <br>\n",
"We will predict temperature in US at seasonal time scales using ERA5 dataset with multi-head attention transformer. <br>\n",
"\n",
"<img src=\"../assets/concept_test_case.png\" alt=\"usecase\" width=\"500\"/>"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"This recipe includes the following steps:\n",
"- Define a calendar (`lilio`)\n",
"- Download/load input data (`era5cli`) (TBA)\n",
"- Download/load input data (test data, accessible via `era5cli`)\n",
"- Map the calendar to the data (`lilio`)\n",
"- Train-validate-test split (60%/20%/20%) (`torch`)\n",
"- Train-validate-test split (60%/20%/20%)\n",
"- Preprocessing based on the training set (`s2spy`)\n",
"- Resample data to the calendar (`lilio`)\n",
"- Create transformer model (`torch`)\n",
Expand All @@ -29,7 +37,7 @@
"source": [
"The workflow is illustrated below:\n",
"\n",
"![Transformer](../assets/dl.PNG)"
"<img src=\"../assets/dl.PNG\" alt=\"Transformer\" width=\"900\"/>"
]
},
{
Expand Down Expand Up @@ -577,7 +585,7 @@
"\n",
"The architecture of the transformer is illustrated in the figure below, which is from the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).\n",
"\n",
"![architecture](../assets/transformer.webp)"
"<img src=\"../assets/transformer.webp\" alt=\"Transformer\" width=\"500\"/>"
]
},
{
Expand Down