From 25856e72d7a28ab10cd2346fb8a604fcfc10327d Mon Sep 17 00:00:00 2001 From: jsadler2 Date: Mon, 14 Feb 2022 15:54:17 -0500 Subject: [PATCH 1/5] [#45] code for looking at model states --- .../models/0_baseline_LSTM/analyze_states.smk | 68 +++++++++++++++++++ 2a_model/src/models/0_baseline_LSTM/model.py | 31 +++++++++ 2 files changed, 99 insertions(+) create mode 100644 2a_model/src/models/0_baseline_LSTM/analyze_states.smk diff --git a/2a_model/src/models/0_baseline_LSTM/analyze_states.smk b/2a_model/src/models/0_baseline_LSTM/analyze_states.smk new file mode 100644 index 00000000..5e31605d --- /dev/null +++ b/2a_model/src/models/0_baseline_LSTM/analyze_states.smk @@ -0,0 +1,68 @@ +from model import LSTMModelStates +from river_dl.postproc_utils import prepped_array_to_df +import numpy as np +import matplotlib.pyplot as plt +import pandas as pd + +code_dir = '../river-dl' +# if using river_dl installed with pip this is not needed +import sys +sys.path.insert(0, code_dir) + +out_dir = "../../../out/models/0_baseline_LSTM/analyze_states" +in_dir = "../../../out/models/0_baseline_LSTM" + + +def get_site_ids(): + df = pd.read_csv(f"{in_dir}/reach_metrics.csv", dtype={"site_id": str}) + return df.site_id.unique() + + +rule all: + input: + expand("{outdir}/states_{trained_or_random}_{site_id}.png", + outdir=out_dir, + trained_or_random = ["trained", "random"], + site_id = get_site_ids()) + + +model = LSTMModelStates( + config['hidden_size'], + recurrent_dropout=config['recurrent_dropout'], + dropout=config['dropout'], + num_tasks=len(config['y_vars']) +) + + +rule write_states: + input: + f"{in_dir}/prepped.npz", + f"{in_dir}/train_weights/", + output: + "{outdir}/states_{trained_or_random}.csv" + run: + data = np.load(input[0], allow_pickle=True) + if wildcards.trained_or_random == "trained": + model.load_weights(input[1] + "/") + states = model(data['x_val']) + states_df = prepped_array_to_df(states, data["times_val"], data["ids_val"], + col_names=[f"h{i}" for i in range(10)], + spatial_idx_name="site_id") + states_df["site_id"] = states_df["site_id"].astype(str) + states_df.to_csv(output[0], index=False) + + +rule plot_states: + input: + "{outdir}/states_{trained_or_random}.csv" + output: + "{outdir}/states_{trained_or_random}_{site_id}.png" + run: + df = pd.read_csv(input[0], parse_dates=["date"], infer_datetime_format=True, dtype={"site_id": str}) + df_site = df.query(f"site_id == '{wildcards.site_id}'") + del df_site["site_id"] + df_site = df_site.set_index("date") + df_site.plot(subplots=True, figsize=(8,10)) + plt.tight_layout() + plt.savefig(output[0]) + diff --git a/2a_model/src/models/0_baseline_LSTM/model.py b/2a_model/src/models/0_baseline_LSTM/model.py index 741f1df0..53813f82 100644 --- a/2a_model/src/models/0_baseline_LSTM/model.py +++ b/2a_model/src/models/0_baseline_LSTM/model.py @@ -29,3 +29,34 @@ def call(self, inputs): prediction = self.dense(h) return prediction + +class LSTMModelStates(tf.keras.Model): + """ + LSTM model but returning states (h) instead of the predictions (y) + """ + def __init__( + self, hidden_size, num_tasks, recurrent_dropout=0, dropout=0, + ): + """ + :param hidden_size: [int] the number of hidden units + :param num_tasks: [int] number of outputs to predict + :param recurrent_dropout: [float] value between 0 and 1 for the + probability of a recurrent element to be zero + :param dropout: [float] value between 0 and 1 for the probability of an + input element to be zero + """ + super().__init__() + self.rnn_layer = layers.LSTM( + hidden_size, + return_sequences=True, + recurrent_dropout=recurrent_dropout, + dropout=dropout, + ) + self.dense = layers.Dense(num_tasks) + + @tf.function + def call(self, inputs): + h = self.rnn_layer(inputs) + prediction = self.dense(h) + return h + From 4bebd778e877864bfe63da776caa592054a76e24 Mon Sep 17 00:00:00 2001 From: jsadler2 Date: Wed, 23 Feb 2022 10:40:24 -0600 Subject: [PATCH 2/5] [#45] py code for plotting hidden states --- 3_visualize/src/plot_output_weights.py | 57 ++++++++ .../src/plot_states_with_other_vars.py | 127 ++++++++++++++++++ 2 files changed, 184 insertions(+) create mode 100644 3_visualize/src/plot_output_weights.py create mode 100644 3_visualize/src/plot_states_with_other_vars.py diff --git a/3_visualize/src/plot_output_weights.py b/3_visualize/src/plot_output_weights.py new file mode 100644 index 00000000..4def0223 --- /dev/null +++ b/3_visualize/src/plot_output_weights.py @@ -0,0 +1,57 @@ +# --- +# jupyter: +# jupytext: +# formats: ipynb,py:percent +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.13.7 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# %% +import sys +import numpy as np +import matplotlib.pyplot as plt + +# %% +sys.path.insert(0, "../../2a_model/src/models/0_baseline_LSTM/") + +# %% +from model import LSTMModel + +# %% +m = LSTMModel(10, 3) + +# %% +m.load_weights("../../2a_model/out/models/0_baseline_LSTM/train_weights/") + +# %% +data = np.load("../../2a_model/out/models/0_baseline_LSTM/prepped.npz", allow_pickle=True) + +# %% +y = m(data['x_val']) + +# %% +w = m.weights + +# %% +ax = plt.imshow(w[3].numpy()) +fig = plt.gcf() +cbar = fig.colorbar(ax) +cbar.set_label('weight value') +ax = plt.gca() +ax.set_yticks(list(range(10))) +ax.set_yticklabels(f"h{i}" for i in range(10)) +ax.set_ylabel('hidden state') +ax.set_xticks(list(range(3))) +ax.set_xticklabels(["DO_max", "DO_mean", "DO_min"], rotation=90) +ax.set_xlabel('output variable') +plt.tight_layout() +plt.savefig('../out/hidden_states/out_weights.jpg', bbox_inches='tight') + +# %% diff --git a/3_visualize/src/plot_states_with_other_vars.py b/3_visualize/src/plot_states_with_other_vars.py new file mode 100644 index 00000000..2412ead8 --- /dev/null +++ b/3_visualize/src/plot_states_with_other_vars.py @@ -0,0 +1,127 @@ +# --- +# jupyter: +# jupytext: +# formats: ipynb,py:percent +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.13.7 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# %% +import pandas as pd +import xarray as xr +import matplotlib.pyplot as plt + +# %% [markdown] +# ## load states and aux data + +# %% +df_states = pd.read_csv("../../2a_model/out/models/0_baseline_LSTM/analyze_states/rep_0/states_trained.csv", + dtype={"site_id": str}, parse_dates=["date"], infer_datetime_format=True) + +# %% +df_aux = pd.read_csv("../../1_fetch/out/daily_aux_data.csv", + dtype={"site_no": str}, parse_dates=["Date"], infer_datetime_format=True) +df_aux = df_aux.rename(columns={"site_no": "site_id", "Date":"date"}) + +# %% +site_id = "01480870" + +# %% +df_aux_site = df_aux.query(f"site_id == '{site_id}'").set_index('date') +df_states_site = df_states.query(f"site_id == '{site_id}'").set_index('date') + +# %% [markdown] +# ## load input data + +# %% +ds = xr.open_zarr("../../2a_model/out/well_observed_train_val_inputs.zarr/", consolidated=False) + +# %% +df_air_temp = ds.seg_tave_air.sel(site_id=site_id).to_dataframe() + +# %% +del df_air_temp['site_id'] +del df_aux_site['site_id'] +del df_states_site['site_id'] + +# %% +df_comb = df_states_site.join(df_aux_site).join(df_air_temp) + +# %% [markdown] +# ___ + +# %% [markdown] +# # Comparison with Flow + +# %% +axs = df_comb.loc[:, df_comb.columns.str.startswith('h')].plot(subplots=True, figsize=(16,20)) +axs = axs.ravel() +for ax in axs: + ax.legend(loc="upper left") + ax_twin = ax.twinx() + df_comb.Flow.plot(ax=ax_twin, color="black", alpha=0.6) + ax_twin.set_ylabel('flow [cfs]') + plt.tight_layout() + plt.savefig("../out/states_with_flow.jpg") + +# %% +axs = df_comb.loc[:, df_comb.columns.str.startswith('h0')].plot(subplots=True, figsize=(20,5)) +axs = axs.ravel() +for ax in axs: + ax.legend(loc="upper left") + ax_twin = ax.twinx() + df_comb.Flow.plot(ax=ax_twin, color="darkgray") + ax_twin.set_ylabel('flow [cfs]') + + +# %% +def plot_one_state_w_flow(df_comb, state, color): + axs = df_comb.loc["2018", df_comb.columns.str.startswith(state)].plot(subplots=True, figsize=(20,5), + color=color, fontsize=20) + axs = axs.ravel() + for ax in axs: + ax.legend(loc="upper left", fontsize=20) + ax_twin = ax.twinx() + df_comb.loc["2018", "Flow"].plot(ax=ax_twin, color="black", alpha=0.6, fontsize=20) + ax_twin.set_ylabel('flow [cfs]', fontsize=20) + ax.set_xlabel('date', fontsize=20) + plt.tight_layout() + plt.savefig(f"../out/{state}_2018_w_flow.jpg") + + +# %% +plot_one_state_w_flow(df_comb, "h0", color="#1f77b4") + +# %% +df_comb.plot.scatter('h0', 'Flow', alpha=0.5) +plt.tight_layout() +plt.savefig("../out/flow_h0_scatter.jpg") + +# %% +plot_one_state_w_flow(df_comb, "h1", "#ff7f0e") + +# %% [markdown] +# # Comparison with Temperature + +# %% +axs = df_comb.loc[:, df_comb.columns.str.startswith('h')].plot(subplots=True, figsize=(16,20)) +axs = axs.ravel() +for ax in axs: + ax.legend(loc="upper left") + ax_twin = ax.twinx() + df_comb.seg_tave_air.plot(ax=ax_twin, color="darkgray") + ax_twin.set_ylabel('avg air temp [degC]') + plt.tight_layout() + plt.savefig("../out/states_w_air_temp.jpg") + +# %% +df_comb.tail() + +# %% From c56f35ea3c84a9dd2a1115952454079888d96918 Mon Sep 17 00:00:00 2001 From: jsadler2 Date: Wed, 23 Feb 2022 10:47:36 -0600 Subject: [PATCH 3/5] [#45] add replicates to baseline model --- 2a_model/src/models/0_baseline_LSTM/Snakefile | 5 +- .../models/0_baseline_LSTM/analyze_states.smk | 61 +++++++++++++++---- .../src/models/0_baseline_LSTM/config.yml | 1 + 3 files changed, 54 insertions(+), 13 deletions(-) diff --git a/2a_model/src/models/0_baseline_LSTM/Snakefile b/2a_model/src/models/0_baseline_LSTM/Snakefile index 0ec50458..8b48256b 100644 --- a/2a_model/src/models/0_baseline_LSTM/Snakefile +++ b/2a_model/src/models/0_baseline_LSTM/Snakefile @@ -17,9 +17,10 @@ loss_function = lf.multitask_rmse(config['lambdas']) rule all: input: - expand("{outdir}/{metric_type}_metrics.csv", + expand("{outdir}/rep_{rep}/{metric_type}_metrics.csv", outdir=out_dir, metric_type=['overall', 'reach'], + rep=list(range(config['num_replicates'])), ) @@ -124,7 +125,7 @@ def filter_predictions(all_preds_file, partition, out_file): df_preds_val_sites = df_preds[df_preds.site_id.isin(config['validation_sites'])] - if partition == "train": + if partition == "trn": df_preds_filt = df_preds_trn_sites[(df_preds_trn_sites.date >= config['train_start_date'][0]) & (df_preds_trn_sites.date < config['train_end_date'][0])] elif partition == "val": diff --git a/2a_model/src/models/0_baseline_LSTM/analyze_states.smk b/2a_model/src/models/0_baseline_LSTM/analyze_states.smk index 5e31605d..56756b6e 100644 --- a/2a_model/src/models/0_baseline_LSTM/analyze_states.smk +++ b/2a_model/src/models/0_baseline_LSTM/analyze_states.smk @@ -1,29 +1,34 @@ +code_dir = '../river-dl' +import sys +sys.path.insert(0, code_dir) +# if using river_dl installed with pip this is not needed + from model import LSTMModelStates from river_dl.postproc_utils import prepped_array_to_df import numpy as np import matplotlib.pyplot as plt import pandas as pd -code_dir = '../river-dl' -# if using river_dl installed with pip this is not needed -import sys -sys.path.insert(0, code_dir) out_dir = "../../../out/models/0_baseline_LSTM/analyze_states" in_dir = "../../../out/models/0_baseline_LSTM" def get_site_ids(): - df = pd.read_csv(f"{in_dir}/reach_metrics.csv", dtype={"site_id": str}) + df = pd.read_csv(f"{in_dir}/rep_0/reach_metrics.csv", dtype={"site_id": str}) return df.site_id.unique() rule all: input: - expand("{outdir}/states_{trained_or_random}_{site_id}.png", + expand("{outdir}/rep_{rep}/states_{trained_or_random}_{site_id}.png", outdir=out_dir, + rep=list(range(6)), trained_or_random = ["trained", "random"], - site_id = get_site_ids()) + site_id = get_site_ids()), + expand("{outdir}/rep_{rep}/output_weights.jpg", + outdir=out_dir, + rep=list(range(6))), model = LSTMModelStates( @@ -36,10 +41,10 @@ model = LSTMModelStates( rule write_states: input: - f"{in_dir}/prepped.npz", - f"{in_dir}/train_weights/", + f"{in_dir}/rep_{{rep}}/prepped.npz", + f"{in_dir}/rep_{{rep}}/train_weights/", output: - "{outdir}/states_{trained_or_random}.csv" + "{outdir}/rep_{rep}/states_{trained_or_random}.csv" run: data = np.load(input[0], allow_pickle=True) if wildcards.trained_or_random == "trained": @@ -62,7 +67,41 @@ rule plot_states: df_site = df.query(f"site_id == '{wildcards.site_id}'") del df_site["site_id"] df_site = df_site.set_index("date") - df_site.plot(subplots=True, figsize=(8,10)) + axs = df_site.plot(subplots=True, figsize=(8,10)) + for ax in axs.flatten(): + ax.legend(loc = "upper left") + plt.suptitle(wildcards.site_id) plt.tight_layout() plt.savefig(output[0]) + +rule plot_output_weights: + input: + f"{in_dir}/rep_{{rep}}/prepped.npz", + f"{in_dir}/rep_{{rep}}/train_weights/", + output: + "{outdir}/rep_{rep}/output_weights.jpg" + run: + data = np.load(input[0], allow_pickle=True) + m = LSTMModelStates( + config['hidden_size'], + recurrent_dropout=config['recurrent_dropout'], + dropout=config['dropout'], + num_tasks=len(config['y_vars']) + ) + m.load_weights(input[1] + "/") + m(data['x_val']) + w = m.weights + ax = plt.imshow(w[3].numpy()) + fig = plt.gcf() + cbar = fig.colorbar(ax) + cbar.set_label('weight value') + ax = plt.gca() + ax.set_yticks(list(range(10))) + ax.set_yticklabels(f"h{i}" for i in range(10)) + ax.set_ylabel('hidden state') + ax.set_xticks(list(range(3))) + ax.set_xticklabels(["DO_max", "DO_mean", "DO_min"], rotation=90) + ax.set_xlabel('output variable') + plt.tight_layout() + plt.savefig(output[0], bbox_inches='tight') diff --git a/2a_model/src/models/0_baseline_LSTM/config.yml b/2a_model/src/models/0_baseline_LSTM/config.yml index 411c1ae8..f950940a 100644 --- a/2a_model/src/models/0_baseline_LSTM/config.yml +++ b/2a_model/src/models/0_baseline_LSTM/config.yml @@ -6,6 +6,7 @@ x_vars: ['seg_ccov', 'seg_rain', 'seg_slope', 'seg_tave_air', 'hru_slope', 'hru_ seed: False #random seed for training False==No seed, otherwise specify the seed +num_replicates: 6 y_vars: ['do_min', 'do_mean', 'do_max'] From 08febe227f98899186c2a3d177a6c13da01c00d9 Mon Sep 17 00:00:00 2001 From: jsadler2 Date: Wed, 23 Feb 2022 15:11:57 -0600 Subject: [PATCH 4/5] [#45], [#48] pull aux (flow, water temp) data --- 1_fetch.R | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/1_fetch.R b/1_fetch.R index 7754eecd..527729d0 100644 --- a/1_fetch.R +++ b/1_fetch.R @@ -49,6 +49,26 @@ p1_targets_list <- list( p1_daily_data, get_daily_nwis_data(p1_nwis_sites_daily,pcode_select,stat_cd_select,start_date=earliest_date,end_date=dummy_date), pattern = map(p1_nwis_sites_daily)), + + + # Download NWIS daily data for other parameters (flow, temperature, SC) (see codes below) + tar_target( + p1_daily_aux_data, + dataRetrieval::readNWISdv( + siteNumbers = p1_nwis_sites_daily$site_no, + parameterCd=c("00060", "00010", "00095"), + statCd=stat_cd_select, + startDate=earliest_date, + endDate=dummy_date) %>% + dataRetrieval::renameNWISColumns() %>% + select(!starts_with("..2..")), + pattern = map(p1_nwis_sites_daily)), + + # Save daily aux data to csv + tar_target( + p1_daily_aux_csv, + write_to_csv(p1_daily_aux_data, outfile="1_fetch/out/daily_aux_data.csv"), + format = "file"), # Save NWIS daily data tar_target( @@ -181,6 +201,5 @@ p1_targets_list <- list( read_csv(p1_ntw_adj_matrix_csv,show_col_types = FALSE) ) - ) From 026781c379979033728d81c8b6756a998a880a8c Mon Sep 17 00:00:00 2001 From: jsadler2 Date: Wed, 23 Feb 2022 15:12:52 -0600 Subject: [PATCH 5/5] [#48] plot water temp vs do preds/obs --- 3_visualize/src/temp_vs_do_preds.py | 78 +++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 3_visualize/src/temp_vs_do_preds.py diff --git a/3_visualize/src/temp_vs_do_preds.py b/3_visualize/src/temp_vs_do_preds.py new file mode 100644 index 00000000..f197144f --- /dev/null +++ b/3_visualize/src/temp_vs_do_preds.py @@ -0,0 +1,78 @@ +# --- +# jupyter: +# jupytext: +# formats: ipynb,py:percent +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.13.7 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# %% +import pandas as pd +import seaborn as sns +import xarray as xr +import matplotlib.pyplot as plt + +# %% +df = pd.read_csv("../../2_process/out/daily_water_temp.csv", dtype={"site_no": str}, parse_dates=["Date"], infer_datetime_format=True) +df = df.rename(columns={"site_no":"site_id", "Date":"date"}) + +# %% +df_aux = pd.read_csv("../../1_fetch/out/daily_aux_data.csv", + dtype={"site_no": str}, parse_dates=["Date"], infer_datetime_format=True) + +# %% +df_preds = pd.read_feather("../../2a_model/out/models/0_baseline_LSTM/rep_0/val_times_preds.feather") + +# %% +df_aux = df_aux.rename(columns={"site_no": "site_id", "Date": "date"}).set_index(["site_id", "date"]) + +# %% +df.set_index(['site_id', 'date'], inplace=True) +df_preds.set_index(['site_id', 'date'], inplace=True) +df_comb = df_preds.join(df_aux) + +# %% +temp_do_pair_counts = (df_comb['Wtemp'] + df_comb['do_mean']).groupby('site_id').count() + +# %% +temp_do_pair_counts + +# %% +sites_w_temp_do_pairs = temp_do_pair_counts[temp_do_pair_counts > 0].index + +# %% +df_comb = df_comb.reset_index() +df_comb = df_comb[df_comb['site_id'].isin(sites_w_temp_do_pairs)] + +# %% +ds_obs_do = xr.open_zarr("../../2a_model/out/well_observed_train_val_do.zarr/", consolidated=False) + +# %% +df_obs_do = ds_obs_do.do_mean.to_dataframe() + +# %% +df_comb = df_comb.set_index(['site_id', 'date']).join(df_obs_do, lsuffix="_pred", rsuffix="_obs").reset_index() +df_comb = df_comb.rename(columns = {"do_mean_pred": "pred", "do_mean_obs": "obs"}) + +# %% +df_comb.columns + +# %% +df_mlt = df_comb.melt(id_vars=['site_id', 'date', 'Wtemp'], value_vars=['pred', 'obs'], var_name="Pred or obs") + +# %% +sns.set(font_scale=1.5, style="whitegrid") +fg = sns.relplot(x="Wtemp", y="value", col="site_id", data=df_mlt, hue="Pred or obs", col_wrap=3, kind='scatter', alpha=0.5) +fg.set_xlabels("Observed daily mean \n water temperature (deg C)") +fg.set_ylabels("Predicted or observed daily \n mean DO concentration (mg/l)") +# plt.tight_layout() +plt.savefig("../out/do_preds_vs_temp.png", dpi=300) + +# %%