Skip to content

Commit

Permalink
[USGS-R#45] add replicates to baseline model
Browse files Browse the repository at this point in the history
  • Loading branch information
jsadler2 committed Feb 23, 2022
1 parent 322d8cf commit 99d9adb
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 15 deletions.
7 changes: 3 additions & 4 deletions 2a_model/src/models/0_baseline_LSTM/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ import os
import tensorflow as tf
import numpy as np
import pandas as pd
import inspect

import river_dl
from river_dl.preproc_utils import asRunConfig
from river_dl.preproc_utils import prep_all_data
from river_dl.evaluate import combined_metrics
Expand All @@ -24,9 +22,10 @@ loss_function = lf.multitask_rmse(config['lambdas'])

rule all:
input:
expand("{outdir}/{metric_type}_metrics.csv",
expand("{outdir}/rep_{rep}/{metric_type}_metrics.csv",
outdir=out_dir,
metric_type=['overall', 'reach'],
rep=list(range(config['num_replicates'])),
)


Expand Down Expand Up @@ -131,7 +130,7 @@ def filter_predictions(all_preds_file, partition, out_file):
df_preds_val_sites = df_preds[df_preds.site_id.isin(config['validation_sites'])]


if partition == "train":
if partition == "trn":
df_preds_filt = df_preds_trn_sites[(df_preds_trn_sites.date >= config['train_start_date'][0]) &
(df_preds_trn_sites.date < config['train_end_date'][0])]
elif partition == "val":
Expand Down
61 changes: 50 additions & 11 deletions 2a_model/src/models/0_baseline_LSTM/analyze_states.smk
Original file line number Diff line number Diff line change
@@ -1,29 +1,34 @@
code_dir = '../river-dl'
import sys
sys.path.insert(0, code_dir)
# if using river_dl installed with pip this is not needed

from model import LSTMModelStates
from river_dl.postproc_utils import prepped_array_to_df
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

code_dir = '../river-dl'
# if using river_dl installed with pip this is not needed
import sys
sys.path.insert(0, code_dir)

out_dir = "../../../out/models/0_baseline_LSTM/analyze_states"
in_dir = "../../../out/models/0_baseline_LSTM"


def get_site_ids():
df = pd.read_csv(f"{in_dir}/reach_metrics.csv", dtype={"site_id": str})
df = pd.read_csv(f"{in_dir}/rep_0/reach_metrics.csv", dtype={"site_id": str})
return df.site_id.unique()


rule all:
input:
expand("{outdir}/states_{trained_or_random}_{site_id}.png",
expand("{outdir}/rep_{rep}/states_{trained_or_random}_{site_id}.png",
outdir=out_dir,
rep=list(range(6)),
trained_or_random = ["trained", "random"],
site_id = get_site_ids())
site_id = get_site_ids()),
expand("{outdir}/rep_{rep}/output_weights.jpg",
outdir=out_dir,
rep=list(range(6))),


model = LSTMModelStates(
Expand All @@ -36,10 +41,10 @@ model = LSTMModelStates(

rule write_states:
input:
f"{in_dir}/prepped.npz",
f"{in_dir}/train_weights/",
f"{in_dir}/rep_{{rep}}/prepped.npz",
f"{in_dir}/rep_{{rep}}/train_weights/",
output:
"{outdir}/states_{trained_or_random}.csv"
"{outdir}/rep_{rep}/states_{trained_or_random}.csv"
run:
data = np.load(input[0], allow_pickle=True)
if wildcards.trained_or_random == "trained":
Expand All @@ -62,7 +67,41 @@ rule plot_states:
df_site = df.query(f"site_id == '{wildcards.site_id}'")
del df_site["site_id"]
df_site = df_site.set_index("date")
df_site.plot(subplots=True, figsize=(8,10))
axs = df_site.plot(subplots=True, figsize=(8,10))
for ax in axs.flatten():
ax.legend(loc = "upper left")
plt.suptitle(wildcards.site_id)
plt.tight_layout()
plt.savefig(output[0])


rule plot_output_weights:
input:
f"{in_dir}/rep_{{rep}}/prepped.npz",
f"{in_dir}/rep_{{rep}}/train_weights/",
output:
"{outdir}/rep_{rep}/output_weights.jpg"
run:
data = np.load(input[0], allow_pickle=True)
m = LSTMModelStates(
config['hidden_size'],
recurrent_dropout=config['recurrent_dropout'],
dropout=config['dropout'],
num_tasks=len(config['y_vars'])
)
m.load_weights(input[1] + "/")
m(data['x_val'])
w = m.weights
ax = plt.imshow(w[3].numpy())
fig = plt.gcf()
cbar = fig.colorbar(ax)
cbar.set_label('weight value')
ax = plt.gca()
ax.set_yticks(list(range(10)))
ax.set_yticklabels(f"h{i}" for i in range(10))
ax.set_ylabel('hidden state')
ax.set_xticks(list(range(3)))
ax.set_xticklabels(["DO_max", "DO_mean", "DO_min"], rotation=90)
ax.set_xlabel('output variable')
plt.tight_layout()
plt.savefig(output[0], bbox_inches='tight')
1 change: 1 addition & 0 deletions 2a_model/src/models/0_baseline_LSTM/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ x_vars: ['seg_ccov', 'seg_rain', 'seg_slope', 'seg_tave_air', 'hru_slope', 'hru_


seed: False #random seed for training False==No seed, otherwise specify the seed
num_replicates: 6

y_vars: ['do_min', 'do_mean', 'do_max']

Expand Down

0 comments on commit 99d9adb

Please sign in to comment.