Skip to content
This repository has been archived by the owner on May 28, 2024. It is now read-only.

Commit

Permalink
Merge pull request #53 from jsadler2/45-baseline-states
Browse files Browse the repository at this point in the history
analysis of baseline states
  • Loading branch information
jsadler2 committed Mar 4, 2022
2 parents cee3da5 + 026781c commit 7c99e50
Show file tree
Hide file tree
Showing 8 changed files with 424 additions and 3 deletions.
21 changes: 20 additions & 1 deletion 1_fetch.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,26 @@ p1_targets_list <- list(
p1_daily_data,
get_daily_nwis_data(p1_nwis_sites_daily,pcode_select,stat_cd_select,start_date=earliest_date,end_date=dummy_date),
pattern = map(p1_nwis_sites_daily)),


# Download NWIS daily values for auxiliary parameters (flow, temperature, SC).
# Parameter codes requested below:
#   00060 = flow (discharge), 00010 = temperature, 00095 = SC (specific conductance)
# The target is dynamically branched over p1_nwis_sites_daily (see `pattern`).
tar_target(
p1_daily_aux_data,
dataRetrieval::readNWISdv(
siteNumbers = p1_nwis_sites_daily$site_no,
parameterCd=c("00060", "00010", "00095"),
statCd=stat_cd_select,
startDate=earliest_date,
endDate=dummy_date) %>%
# give the parameter/statistic coded columns readable names
dataRetrieval::renameNWISColumns() %>%
# drop every column whose name starts with "..2.."
# NOTE(review): presumably these are duplicate series reported by some sites - confirm
select(!starts_with("..2..")),
pattern = map(p1_nwis_sites_daily)),

# Save daily aux data (p1_daily_aux_data) to csv.
# format = "file" makes targets track the file on disk, so the command has to
# return the path that was written.
# NOTE(review): assumes write_to_csv() returns `outfile` - confirm in its definition
tar_target(
p1_daily_aux_csv,
write_to_csv(p1_daily_aux_data, outfile="1_fetch/out/daily_aux_data.csv"),
format = "file"),

# Save NWIS daily data
tar_target(
Expand Down Expand Up @@ -181,6 +201,5 @@ p1_targets_list <- list(
read_csv(p1_ntw_adj_matrix_csv,show_col_types = FALSE)
)


)

5 changes: 3 additions & 2 deletions 2a_model/src/models/0_baseline_LSTM/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ loss_function = lf.multitask_rmse(config['lambdas'])

rule all:
input:
expand("{outdir}/{metric_type}_metrics.csv",
expand("{outdir}/rep_{rep}/{metric_type}_metrics.csv",
outdir=out_dir,
metric_type=['overall', 'reach'],
rep=list(range(config['num_replicates'])),
)


Expand Down Expand Up @@ -124,7 +125,7 @@ def filter_predictions(all_preds_file, partition, out_file):
df_preds_val_sites = df_preds[df_preds.site_id.isin(config['validation_sites'])]


if partition == "train":
if partition == "trn":
df_preds_filt = df_preds_trn_sites[(df_preds_trn_sites.date >= config['train_start_date'][0]) &
(df_preds_trn_sites.date < config['train_end_date'][0])]
elif partition == "val":
Expand Down
107 changes: 107 additions & 0 deletions 2a_model/src/models/0_baseline_LSTM/analyze_states.smk
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
code_dir = '../river-dl'
import sys
sys.path.insert(0, code_dir)
# if using river_dl installed with pip this is not needed

from model import LSTMModelStates
from river_dl.postproc_utils import prepped_array_to_df
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd


out_dir = "../../../out/models/0_baseline_LSTM/analyze_states"
in_dir = "../../../out/models/0_baseline_LSTM"


def get_site_ids():
    """
    Read the reach-level metrics of replicate 0 and return the distinct sites

    :return: [array of str] unique values found in the `site_id` column
    """
    metrics_file = f"{in_dir}/rep_0/reach_metrics.csv"
    # site ids are read as strings so that leading zeros are not lost
    reach_metrics = pd.read_csv(metrics_file, dtype={"site_id": str})
    return reach_metrics["site_id"].unique()


# Final targets of the states analysis:
#   - one hidden-state plot per replicate, weight type (trained/random) and site
#   - one output-weights heat map per replicate
# The replicate ids come from config['num_replicates'] (as in the training
# Snakefile) so that the analysis follows the number of models actually trained.
rule all:
    input:
        expand("{outdir}/rep_{rep}/states_{trained_or_random}_{site_id}.png",
               outdir=out_dir,
               rep=list(range(config['num_replicates'])),
               trained_or_random=["trained", "random"],
               site_id=get_site_ids()),
        expand("{outdir}/rep_{rep}/output_weights.jpg",
               outdir=out_dir,
               rep=list(range(config['num_replicates']))),


# Module-level instance of the states-returning LSTM, sized from the workflow
# config (hidden units, dropout rates, one task per entry of y_vars).
model = LSTMModelStates(
    hidden_size=config['hidden_size'],
    num_tasks=len(config['y_vars']),
    recurrent_dropout=config['recurrent_dropout'],
    dropout=config['dropout'],
)


# Run the validation inputs of one replicate through the states-returning LSTM
# and write the hidden states to csv (one column per hidden unit: h0, h1, ...).
# `trained_or_random` selects whether the trained weights of the replicate are
# loaded ("trained") or the model keeps its initial weights (any other value,
# "random" in this workflow).
rule write_states:
    input:
        f"{in_dir}/rep_{{rep}}/prepped.npz",
        f"{in_dir}/rep_{{rep}}/train_weights/",
    output:
        "{outdir}/rep_{rep}/states_{trained_or_random}.csv"
    run:
        data = np.load(input[0], allow_pickle=True)
        # Build a fresh model for every job instead of reusing a module-level
        # instance: a shared instance keeps the weights loaded by an earlier
        # "trained" job, so a later "random" job run in the same process would
        # no longer be random.
        states_model = LSTMModelStates(
            config['hidden_size'],
            recurrent_dropout=config['recurrent_dropout'],
            dropout=config['dropout'],
            num_tasks=len(config['y_vars'])
        )
        if wildcards.trained_or_random == "trained":
            states_model.load_weights(input[1] + "/")
        states = states_model(data['x_val'])
        # one column per hidden unit; the count follows the configured size
        state_names = [f"h{i}" for i in range(config['hidden_size'])]
        states_df = prepped_array_to_df(states, data["times_val"], data["ids_val"],
                                        col_names=state_names,
                                        spatial_idx_name="site_id")
        states_df["site_id"] = states_df["site_id"].astype(str)
        states_df.to_csv(output[0], index=False)


# Plot the hidden-state time series of a single site, one panel per state
# column of the states csv, and save the figure as png.
rule plot_states:
    input:
        "{outdir}/states_{trained_or_random}.csv"
    output:
        "{outdir}/states_{trained_or_random}_{site_id}.png"
    run:
        df = pd.read_csv(input[0], parse_dates=["date"], infer_datetime_format=True, dtype={"site_id": str})
        # boolean mask + drop() returns a new frame, so nothing is deleted in
        # place on a filtered view of `df`
        df_site = (
            df[df["site_id"] == wildcards.site_id]
            .drop(columns="site_id")
            .set_index("date")
        )
        axs = df_site.plot(subplots=True, figsize=(8,10))
        # work on the figure that was just created rather than on pyplot's
        # "current figure", which may belong to another job in this process
        fig = axs.flatten()[0].get_figure()
        for ax in axs.flatten():
            ax.legend(loc = "upper left")
        fig.suptitle(wildcards.site_id)
        fig.tight_layout()
        fig.savefig(output[0])
        # release the figure; otherwise open figures pile up across jobs
        plt.close(fig)


# Heat map of the dense output layer's kernel for one trained replicate:
# rows are hidden units (h0, h1, ...), columns are the output variables.
rule plot_output_weights:
    input:
        f"{in_dir}/rep_{{rep}}/prepped.npz",
        f"{in_dir}/rep_{{rep}}/train_weights/",
    output:
        "{outdir}/rep_{rep}/output_weights.jpg"
    run:
        data = np.load(input[0], allow_pickle=True)
        m = LSTMModelStates(
            config['hidden_size'],
            recurrent_dropout=config['recurrent_dropout'],
            dropout=config['dropout'],
            num_tasks=len(config['y_vars'])
        )
        m.load_weights(input[1] + "/")
        # forward pass so that the layers (and therefore their weights) exist
        m(data['x_val'])
        # kernel of the dense layer, addressed by name instead of by its
        # position in m.weights
        out_weights = m.dense.kernel.numpy()
        n_hidden, n_out = out_weights.shape

        # explicit figure/axes: pyplot's implicit current axes may still be
        # those of a figure drawn by another job in the same process
        fig, ax = plt.subplots()
        im = ax.imshow(out_weights)
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label('weight value')
        ax.set_yticks(list(range(n_hidden)))
        ax.set_yticklabels([f"h{i}" for i in range(n_hidden)])
        ax.set_ylabel('hidden state')
        ax.set_xticks(list(range(n_out)))
        # Label the columns with the configured targets. The previous
        # hard-coded labels (DO_max, DO_mean, DO_min) were in the reverse
        # order of config['y_vars'].
        # NOTE(review): assumes the model outputs follow the order of
        # config['y_vars'] - confirm against the training data prep.
        ax.set_xticklabels(config['y_vars'], rotation=90)
        ax.set_xlabel('output variable')
        fig.tight_layout()
        fig.savefig(output[0], bbox_inches='tight')
        plt.close(fig)
1 change: 1 addition & 0 deletions 2a_model/src/models/0_baseline_LSTM/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ x_vars: ['seg_ccov', 'seg_rain', 'seg_slope', 'seg_tave_air', 'hru_slope', 'hru_


seed: False #random seed for training False==No seed, otherwise specify the seed
num_replicates: 6

y_vars: ['do_min', 'do_mean', 'do_max']

Expand Down
31 changes: 31 additions & 0 deletions 2a_model/src/models/0_baseline_LSTM/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,34 @@ def call(self, inputs):
prediction = self.dense(h)
return prediction


class LSTMModelStates(tf.keras.Model):
    """
    LSTM model but returning states (h) instead of the predictions (y)

    The layers are the same as in the prediction model (an LSTM followed by a
    dense layer); only the value returned by `call` differs.
    """
    def __init__(
        self, hidden_size, num_tasks, recurrent_dropout=0, dropout=0,
    ):
        """
        :param hidden_size: [int] the number of hidden units
        :param num_tasks: [int] number of outputs to predict
        :param recurrent_dropout: [float] value between 0 and 1 for the
        probability of a recurrent element to be zero
        :param dropout: [float] value between 0 and 1 for the probability of an
        input element to be zero
        """
        super().__init__()
        self.rnn_layer = layers.LSTM(
            hidden_size,
            return_sequences=True,
            recurrent_dropout=recurrent_dropout,
            dropout=dropout,
        )
        self.dense = layers.Dense(num_tasks)

    @tf.function
    def call(self, inputs):
        """
        :param inputs: input sequences, passed straight to the LSTM layer
        :return: the LSTM output `h`; with return_sequences=True this is the
        hidden state at every time step
        """
        h = self.rnn_layer(inputs)
        # The dense layer is still called, but its result is deliberately
        # discarded: calling it builds the layer, so its kernel and bias exist
        # on this model after a forward pass (e.g. to inspect the output
        # weights of a trained model).
        self.dense(h)
        return h

57 changes: 57 additions & 0 deletions 3_visualize/src/plot_output_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.13.7
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---

# %%
import sys
import numpy as np
import matplotlib.pyplot as plt

# %%
# make the baseline LSTM module importable from this notebook
sys.path.insert(0, "../../2a_model/src/models/0_baseline_LSTM/")

# %%
from model import LSTMModel

# %%
# Model outputs are written per replicate (rep_0, rep_1, ...); use replicate 0.
rep_dir = "../../2a_model/out/models/0_baseline_LSTM/rep_0"

# Output variables in the order given by `y_vars` in the model's config.yml.
# NOTE(review): keep in sync with config.yml; assumes the model outputs follow
# that order.
y_vars = ["do_min", "do_mean", "do_max"]

# %%
m = LSTMModel(10, 3)

# %%
m.load_weights(f"{rep_dir}/train_weights/")

# %%
data = np.load(f"{rep_dir}/prepped.npz", allow_pickle=True)

# %%
# forward pass so that the layers (and their weights) are created
y = m(data['x_val'])

# %%
w = m.weights

# %%
# w[3] is plotted with hidden states as rows and output variables as columns
out_weights = w[3].numpy()
n_hidden = out_weights.shape[0]
fig, ax = plt.subplots()
im = ax.imshow(out_weights)
cbar = fig.colorbar(im, ax=ax)
cbar.set_label('weight value')
ax.set_yticks(list(range(n_hidden)))
ax.set_yticklabels([f"h{i}" for i in range(n_hidden)])
ax.set_ylabel('hidden state')
ax.set_xticks(list(range(len(y_vars))))
ax.set_xticklabels(y_vars, rotation=90)
ax.set_xlabel('output variable')
fig.tight_layout()
fig.savefig('../out/hidden_states/out_weights.jpg', bbox_inches='tight')

# %%
127 changes: 127 additions & 0 deletions 3_visualize/src/plot_states_with_other_vars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.13.7
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---

# %%
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

# %% [markdown]
# ## load states and aux data

# %%
# Hidden states of replicate 0 (trained weights); site ids are read as strings
# so that leading zeros are kept.
df_states = pd.read_csv("../../2a_model/out/models/0_baseline_LSTM/analyze_states/rep_0/states_trained.csv",
dtype={"site_id": str}, parse_dates=["date"], infer_datetime_format=True)

# %%
# Daily auxiliary data written by the fetch step.
df_aux = pd.read_csv("../../1_fetch/out/daily_aux_data.csv",
dtype={"site_no": str}, parse_dates=["Date"], infer_datetime_format=True)
# rename the key columns so they match the ones used in df_states
df_aux = df_aux.rename(columns={"site_no": "site_id", "Date":"date"})

# %%
# site that is inspected in the rest of this notebook
site_id = "01480870"

# %%
# keep only the rows of that site and index both tables by date
df_aux_site = df_aux.query(f"site_id == '{site_id}'").set_index('date')
df_states_site = df_states.query(f"site_id == '{site_id}'").set_index('date')

# %% [markdown]
# ## load input data

# %%
ds = xr.open_zarr("../../2a_model/out/well_observed_train_val_inputs.zarr/", consolidated=False)

# %%
# air temperature input of the selected site as a data frame
# NOTE(review): presumably indexed by date after selecting one site - confirm
df_air_temp = ds.seg_tave_air.sel(site_id=site_id).to_dataframe()

# %%
# site_id is not needed as a column any more; dropping it from all three tables
# also avoids overlapping column names in the joins below
del df_air_temp['site_id']
del df_aux_site['site_id']
del df_states_site['site_id']

# %%
# join on the index, keeping the rows of the states table (DataFrame.join
# defaults to a left join)
df_comb = df_states_site.join(df_aux_site).join(df_air_temp)

# %% [markdown]
# ___

# %% [markdown]
# # Comparison with Flow

# %%
# All hidden-state columns (names starting with 'h'), one panel each, with the
# flow series overlaid on a secondary y-axis.
is_state_col = df_comb.columns.str.startswith('h')
axs = df_comb.loc[:, is_state_col].plot(subplots=True, figsize=(16,20)).ravel()
for ax in axs:
    ax.legend(loc="upper left")
    ax_twin = ax.twinx()
    df_comb["Flow"].plot(ax=ax_twin, color="black", alpha=0.6)
    ax_twin.set_ylabel('flow [cfs]')
plt.tight_layout()
plt.savefig("../out/states_with_flow.jpg")

# %%
# Same view restricted to the columns starting with 'h0' (not saved to disk).
is_h0_col = df_comb.columns.str.startswith('h0')
axs = df_comb.loc[:, is_h0_col].plot(subplots=True, figsize=(20,5)).ravel()
for ax in axs:
    ax.legend(loc="upper left")
    ax_twin = ax.twinx()
    df_comb["Flow"].plot(ax=ax_twin, color="darkgray")
    ax_twin.set_ylabel('flow [cfs]')


# %%
def plot_one_state_w_flow(df_comb, state, color, year="2018"):
    """
    Plot one hidden state together with flow for a single year and save it

    The figure is written to "../out/{state}_{year}_w_flow.jpg".

    :param df_comb: [pd.DataFrame] date-indexed table holding the hidden-state
    columns and a "Flow" column
    :param state: [str] prefix of the state column(s) to plot, e.g. "h0"
    :param color: colour used for the state line(s)
    :param year: [str] year to plot, used to select rows from the date index
    (default "2018", the year that was previously hard-coded)
    """
    state_cols = df_comb.columns.str.startswith(state)
    axs = df_comb.loc[year, state_cols].plot(subplots=True, figsize=(20,5),
                                             color=color, fontsize=20)
    for ax in axs.ravel():
        ax.legend(loc="upper left", fontsize=20)
        # flow goes on a secondary y-axis because its scale differs from the
        # scale of the states
        ax_twin = ax.twinx()
        df_comb.loc[year, "Flow"].plot(ax=ax_twin, color="black", alpha=0.6, fontsize=20)
        ax_twin.set_ylabel('flow [cfs]', fontsize=20)
        ax.set_xlabel('date', fontsize=20)
    plt.tight_layout()
    plt.savefig(f"../out/{state}_{year}_w_flow.jpg")


# %%
# single-state view of h0 against flow
plot_one_state_w_flow(df_comb, "h0", color="#1f77b4")

# %%
# scatter plot of h0 against flow, saved to ../out/flow_h0_scatter.jpg
df_comb.plot.scatter('h0', 'Flow', alpha=0.5)
plt.tight_layout()
plt.savefig("../out/flow_h0_scatter.jpg")

# %%
# single-state view of h1 against flow
plot_one_state_w_flow(df_comb, "h1", "#ff7f0e")

# %% [markdown]
# # Comparison with Temperature

# %%
# All hidden-state columns, one panel each, with the average air temperature
# overlaid on a secondary y-axis.
state_mask = df_comb.columns.str.startswith('h')
axs = df_comb.loc[:, state_mask].plot(subplots=True, figsize=(16,20)).ravel()
for ax in axs:
    ax.legend(loc="upper left")
    ax_twin = ax.twinx()
    df_comb["seg_tave_air"].plot(ax=ax_twin, color="darkgray")
    ax_twin.set_ylabel('avg air temp [degC]')
plt.tight_layout()
plt.savefig("../out/states_w_air_temp.jpg")

# %%
# quick look at the last rows of the combined table
df_comb.tail()

# %%
Loading

0 comments on commit 7c99e50

Please sign in to comment.