Skip to content
This repository has been archived by the owner on May 28, 2024. It is now read-only.

analysis of baseline states #53

Merged
merged 5 commits into from
Mar 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion 1_fetch.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,26 @@ p1_targets_list <- list(
p1_daily_data,
get_daily_nwis_data(p1_nwis_sites_daily,pcode_select,stat_cd_select,start_date=earliest_date,end_date=dummy_date),
pattern = map(p1_nwis_sites_daily)),


# Download NWIS daily data for other parameters (flow, temperature, SC) (see codes below)
tar_target(
  name = p1_daily_aux_data,
  command = dataRetrieval::readNWISdv(
    siteNumbers = p1_nwis_sites_daily$site_no,
    # NWIS parameter codes requested for each site (flow, temperature, SC)
    parameterCd = c("00060", "00010", "00095"),
    statCd = stat_cd_select,
    startDate = earliest_date,
    endDate = dummy_date
  ) %>%
    dataRetrieval::renameNWISColumns() %>%
    # drop every column whose name starts with "..2.."
    select(!starts_with("..2..")),
  # one dynamic branch per element of p1_nwis_sites_daily
  pattern = map(p1_nwis_sites_daily)
),

# Save daily aux data to csv
tar_target(
  name = p1_daily_aux_csv,
  command = write_to_csv(
    p1_daily_aux_data,
    outfile = "1_fetch/out/daily_aux_data.csv"
  ),
  # tracked by targets as a file on disk rather than an R object
  format = "file"
),

# Save NWIS daily data
tar_target(
Expand Down Expand Up @@ -181,6 +201,5 @@ p1_targets_list <- list(
read_csv(p1_ntw_adj_matrix_csv,show_col_types = FALSE)
)


)

5 changes: 3 additions & 2 deletions 2a_model/src/models/0_baseline_LSTM/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ loss_function = lf.multitask_rmse(config['lambdas'])

rule all:
input:
expand("{outdir}/{metric_type}_metrics.csv",
expand("{outdir}/rep_{rep}/{metric_type}_metrics.csv",
outdir=out_dir,
metric_type=['overall', 'reach'],
rep=list(range(config['num_replicates'])),
)


Expand Down Expand Up @@ -124,7 +125,7 @@ def filter_predictions(all_preds_file, partition, out_file):
df_preds_val_sites = df_preds[df_preds.site_id.isin(config['validation_sites'])]


if partition == "train":
if partition == "trn":
df_preds_filt = df_preds_trn_sites[(df_preds_trn_sites.date >= config['train_start_date'][0]) &
(df_preds_trn_sites.date < config['train_end_date'][0])]
elif partition == "val":
Expand Down
107 changes: 107 additions & 0 deletions 2a_model/src/models/0_baseline_LSTM/analyze_states.smk
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
code_dir = '../river-dl'
import sys
sys.path.insert(0, code_dir)
# if using river_dl installed with pip this is not needed

from model import LSTMModelStates
from river_dl.postproc_utils import prepped_array_to_df
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd


out_dir = "../../../out/models/0_baseline_LSTM/analyze_states"
in_dir = "../../../out/models/0_baseline_LSTM"


def get_site_ids():
    """
    Return the unique site ids found in the reach metrics of replicate 0.

    The `site_id` column is read as a string column.

    :return: array of unique site id strings
    """
    metrics_path = f"{in_dir}/rep_0/reach_metrics.csv"
    reach_metrics = pd.read_csv(metrics_path, dtype={"site_id": str})
    return reach_metrics["site_id"].unique()


# Replicate indices to analyze. Taken from `num_replicates` in the config so
# this workflow stays in step with the training workflow; falls back to the
# previously hard-coded 6 replicates when the key is absent.
replicates = list(range(config.get('num_replicates', 6)))


# Default target: one hidden-state figure per replicate, per model variant
# (trained / random weights) and per site, plus one output-weights figure per
# replicate.
rule all:
    input:
        expand("{outdir}/rep_{rep}/states_{trained_or_random}_{site_id}.png",
                outdir=out_dir,
                rep=replicates,
                trained_or_random=["trained", "random"],
                site_id=get_site_ids()),
        expand("{outdir}/rep_{rep}/output_weights.jpg",
                outdir=out_dir,
                rep=replicates),


def _build_states_model():
    """
    Create a new, untrained LSTMModelStates configured from `config`.

    A new instance is built for every job instead of sharing one module-level
    model: a shared instance keeps whatever weights were loaded by an earlier
    "trained" job, so a later "random" job (or a job for another replicate)
    running in the same process would not start from freshly initialized
    weights.

    :return: [LSTMModelStates] model with freshly initialized weights
    """
    return LSTMModelStates(
        config['hidden_size'],
        recurrent_dropout=config['recurrent_dropout'],
        dropout=config['dropout'],
        num_tasks=len(config['y_vars'])
    )


# Write the hidden states produced for the validation inputs of one replicate,
# either with the trained weights ("trained") or with the untouched initial
# weights ("random").
rule write_states:
    input:
        f"{in_dir}/rep_{{rep}}/prepped.npz",
        f"{in_dir}/rep_{{rep}}/train_weights/",
    output:
        "{outdir}/rep_{rep}/states_{trained_or_random}.csv"
    run:
        data = np.load(input[0], allow_pickle=True)
        states_model = _build_states_model()
        if wildcards.trained_or_random == "trained":
            states_model.load_weights(input[1] + "/")
        states = states_model(data['x_val'])
        # one column per hidden unit, so the count follows the configured
        # hidden size rather than a fixed 10
        state_cols = [f"h{i}" for i in range(config['hidden_size'])]
        states_df = prepped_array_to_df(states, data["times_val"], data["ids_val"],
                                        col_names=state_cols,
                                        spatial_idx_name="site_id")
        states_df["site_id"] = states_df["site_id"].astype(str)
        states_df.to_csv(output[0], index=False)


# Plot every column of the states csv for a single site as stacked time-series
# subplots (one subplot per column) and save the figure.
rule plot_states:
    input:
        "{outdir}/states_{trained_or_random}.csv"
    output:
        "{outdir}/states_{trained_or_random}_{site_id}.png"
    run:
        df = pd.read_csv(input[0], parse_dates=["date"], infer_datetime_format=True, dtype={"site_id": str})
        # boolean mask instead of a string-built query, so the site id is
        # compared as data and never parsed as part of an expression
        df_site = (df[df["site_id"] == wildcards.site_id]
                   .drop(columns="site_id")
                   .set_index("date"))
        axes = df_site.plot(subplots=True, figsize=(8,10)).flatten()
        fig = axes[0].get_figure()
        for ax in axes:
            ax.legend(loc="upper left")
        fig.suptitle(wildcards.site_id)
        fig.tight_layout()
        fig.savefig(output[0])
        # pyplot keeps figures alive until they are closed; this rule runs
        # once per site/variant/replicate, so release each figure when done
        plt.close(fig)


# Plot the weights connecting the hidden states (rows) to the output variables
# (columns) of the trained model of one replicate as a heat map.
rule plot_output_weights:
    input:
        f"{in_dir}/rep_{{rep}}/prepped.npz",
        f"{in_dir}/rep_{{rep}}/train_weights/",
    output:
        "{outdir}/rep_{rep}/output_weights.jpg"
    run:
        data = np.load(input[0], allow_pickle=True)
        m = LSTMModelStates(
            config['hidden_size'],
            recurrent_dropout=config['recurrent_dropout'],
            dropout=config['dropout'],
            num_tasks=len(config['y_vars'])
        )
        m.load_weights(input[1] + "/")
        # run the model once on the validation inputs before reading m.weights
        m(data['x_val'])
        # NOTE(review): index 3 is expected to be the dense (output) kernel,
        # coming after the three LSTM variables -- confirm if layers change
        dense_kernel = m.weights[3].numpy()
        n_hidden, n_outputs = dense_kernel.shape
        # dedicated figure: never draw onto whatever pyplot figure happens to
        # be current from a previous job in the same process
        fig, ax = plt.subplots()
        img = ax.imshow(dense_kernel)
        cbar = fig.colorbar(img)
        cbar.set_label('weight value')
        ax.set_yticks(list(range(n_hidden)))
        ax.set_yticklabels([f"h{i}" for i in range(n_hidden)])
        ax.set_ylabel('hidden state')
        ax.set_xticks(list(range(n_outputs)))
        # label the columns from the configured y_vars (the model is built with
        # num_tasks=len(config['y_vars'])) instead of a hard-coded list whose
        # order did not match the config
        ax.set_xticklabels(list(config['y_vars']), rotation=90)
        ax.set_xlabel('output variable')
        fig.tight_layout()
        fig.savefig(output[0], bbox_inches='tight')
        plt.close(fig)
1 change: 1 addition & 0 deletions 2a_model/src/models/0_baseline_LSTM/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ x_vars: ['seg_ccov', 'seg_rain', 'seg_slope', 'seg_tave_air', 'hru_slope', 'hru_


seed: False #random seed for training False==No seed, otherwise specify the seed
num_replicates: 6

y_vars: ['do_min', 'do_mean', 'do_max']

Expand Down
31 changes: 31 additions & 0 deletions 2a_model/src/models/0_baseline_LSTM/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,34 @@ def call(self, inputs):
prediction = self.dense(h)
return prediction


class LSTMModelStates(tf.keras.Model):
    """
    Variant of the LSTM model whose forward pass yields the hidden states (h)
    rather than the predictions (y)
    """

    def __init__(
        self, hidden_size, num_tasks, recurrent_dropout=0, dropout=0,
    ):
        """
        :param hidden_size: [int] number of hidden units in the LSTM layer
        :param num_tasks: [int] number of output variables of the dense layer
        :param recurrent_dropout: [float] probability (between 0 and 1) of a
        recurrent element being set to zero
        :param dropout: [float] probability (between 0 and 1) of an input
        element being set to zero
        """
        super().__init__()
        lstm_options = dict(
            return_sequences=True,
            recurrent_dropout=recurrent_dropout,
            dropout=dropout,
        )
        # attribute names are left as-is: weights are loaded into this class
        self.rnn_layer = layers.LSTM(hidden_size, **lstm_options)
        self.dense = layers.Dense(num_tasks)

    @tf.function
    def call(self, inputs):
        """
        :param inputs: input sequences passed to the LSTM layer
        :return: the LSTM output at every time step (the hidden states)
        """
        hidden_states = self.rnn_layer(inputs)
        # The dense layer is still applied although its output is discarded.
        # NOTE(review): presumably kept so the dense layer's variables are
        # created alongside the LSTM's -- confirm before removing.
        self.dense(hidden_states)
        return hidden_states

57 changes: 57 additions & 0 deletions 3_visualize/src/plot_output_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.13.7
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---

# %%
import sys
import numpy as np
import matplotlib.pyplot as plt

# %%
sys.path.insert(0, "../../2a_model/src/models/0_baseline_LSTM/")

# %%
from model import LSTMModel

# %%
m = LSTMModel(10, 3)

# %%
m.load_weights("../../2a_model/out/models/0_baseline_LSTM/train_weights/")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little confused where these weights come from

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are weights produced by the snakemake workflow:

directory("{outdir}/train_weights/"),


# %%
data = np.load("../../2a_model/out/models/0_baseline_LSTM/prepped.npz", allow_pickle=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also confused here as to what you are loading: are these input data to use for plotting?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes these are input data. They are used to just run the model once so I can get the weight values. ... now that I think of it, I'm not 100% sure this step is necessary.


# %%
# Run the model once on the validation inputs. A Keras model creates its layer
# variables on the first call, so m.weights is only populated after this.
y = m(data['x_val'])

# %%
w = m.weights

# %%
# NOTE(review): index 3 is expected to be the dense (output) kernel, coming
# after the three LSTM variables -- confirm if the model layers change.
dense_kernel = w[3].numpy()
n_hidden = dense_kernel.shape[0]
# Column labels in the order of y_vars in the model's config.yml
# (do_min, do_mean, do_max); the previous labels listed them in reverse.
# Assumes the model outputs follow the y_vars order -- TODO confirm.
output_labels = ["DO_min", "DO_mean", "DO_max"]

ax = plt.imshow(dense_kernel)
fig = plt.gcf()
cbar = fig.colorbar(ax)
cbar.set_label('weight value')
ax = plt.gca()
ax.set_yticks(list(range(n_hidden)))
ax.set_yticklabels([f"h{i}" for i in range(n_hidden)])
ax.set_ylabel('hidden state')
ax.set_xticks(list(range(len(output_labels))))
ax.set_xticklabels(output_labels, rotation=90)
ax.set_xlabel('output variable')
plt.tight_layout()
plt.savefig('../out/hidden_states/out_weights.jpg', bbox_inches='tight')

# %%
127 changes: 127 additions & 0 deletions 3_visualize/src/plot_states_with_other_vars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.13.7
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---

# %%
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

# %% [markdown]
# ## load states and aux data

# %%
df_states = pd.read_csv("../../2a_model/out/models/0_baseline_LSTM/analyze_states/rep_0/states_trained.csv",
dtype={"site_id": str}, parse_dates=["date"], infer_datetime_format=True)

# %%
df_aux = pd.read_csv("../../1_fetch/out/daily_aux_data.csv",
dtype={"site_no": str}, parse_dates=["Date"], infer_datetime_format=True)
df_aux = df_aux.rename(columns={"site_no": "site_id", "Date":"date"})

# %%
site_id = "01480870"

# %%
df_aux_site = df_aux.query(f"site_id == '{site_id}'").set_index('date')
df_states_site = df_states.query(f"site_id == '{site_id}'").set_index('date')

# %% [markdown]
# ## load input data

# %%
ds = xr.open_zarr("../../2a_model/out/well_observed_train_val_inputs.zarr/", consolidated=False)

# %%
df_air_temp = ds.seg_tave_air.sel(site_id=site_id).to_dataframe()

# %%
del df_air_temp['site_id']
del df_aux_site['site_id']
del df_states_site['site_id']

# %%
df_comb = df_states_site.join(df_aux_site).join(df_air_temp)

# %% [markdown]
# ___

# %% [markdown]
# # Comparison with Flow

# %%
axs = df_comb.loc[:, df_comb.columns.str.startswith('h')].plot(subplots=True, figsize=(16,20))
axs = axs.ravel()
for ax in axs:
ax.legend(loc="upper left")
ax_twin = ax.twinx()
df_comb.Flow.plot(ax=ax_twin, color="black", alpha=0.6)
ax_twin.set_ylabel('flow [cfs]')
plt.tight_layout()
plt.savefig("../out/states_with_flow.jpg")

# %%
axs = df_comb.loc[:, df_comb.columns.str.startswith('h0')].plot(subplots=True, figsize=(20,5))
axs = axs.ravel()
for ax in axs:
ax.legend(loc="upper left")
ax_twin = ax.twinx()
df_comb.Flow.plot(ax=ax_twin, color="darkgray")
ax_twin.set_ylabel('flow [cfs]')


# %%
def plot_one_state_w_flow(df_comb, state, color, year="2018"):
    """
    Plot the hidden-state column(s) whose name starts with `state` for one
    year, with flow overlaid on a secondary y-axis, and save the figure to
    ../out/{state}_{year}_w_flow.jpg

    :param df_comb: [DataFrame] date-indexed frame holding the state columns
    and a "Flow" column
    :param state: [str] column-name prefix of the state(s) to plot, e.g. "h0"
    :param color: color used for the state line(s)
    :param year: [str] year used to select rows of the date index; defaults to
    "2018", the year that was previously hard-coded
    """
    state_cols = df_comb.columns.str.startswith(state)
    axs = df_comb.loc[year, state_cols].plot(subplots=True, figsize=(20,5),
                                             color=color, fontsize=20)
    axs = axs.ravel()
    for ax in axs:
        ax.legend(loc="upper left", fontsize=20)
        # flow goes on a twin axis because it is on a different scale
        ax_twin = ax.twinx()
        df_comb.loc[year, "Flow"].plot(ax=ax_twin, color="black", alpha=0.6, fontsize=20)
        ax_twin.set_ylabel('flow [cfs]', fontsize=20)
        ax.set_xlabel('date', fontsize=20)
    plt.tight_layout()
    plt.savefig(f"../out/{state}_{year}_w_flow.jpg")


# %%
plot_one_state_w_flow(df_comb, "h0", color="#1f77b4")

# %%
df_comb.plot.scatter('h0', 'Flow', alpha=0.5)
plt.tight_layout()
plt.savefig("../out/flow_h0_scatter.jpg")

# %%
plot_one_state_w_flow(df_comb, "h1", "#ff7f0e")

# %% [markdown]
# # Comparison with Temperature

# %%
axs = df_comb.loc[:, df_comb.columns.str.startswith('h')].plot(subplots=True, figsize=(16,20))
axs = axs.ravel()
for ax in axs:
ax.legend(loc="upper left")
ax_twin = ax.twinx()
df_comb.seg_tave_air.plot(ax=ax_twin, color="darkgray")
ax_twin.set_ylabel('avg air temp [degC]')
plt.tight_layout()
plt.savefig("../out/states_w_air_temp.jpg")

# %%
df_comb.tail()

# %%
Loading