Skip to content

Commit

Permalink
[USGS-R#184] pred perf plot uses combined files
Browse files Browse the repository at this point in the history
  • Loading branch information
jsadler2 committed Jul 31, 2023
1 parent a411fa7 commit bbf0c28
Showing 1 changed file with 18 additions and 61 deletions.
79 changes: 18 additions & 61 deletions 3_visualize/src/python_scripts/plot_pred_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,99 +19,56 @@
validation_sites = ["01472104", "01473500", "01481500"]
test_sites = ["01475530", "01475548"]

models = ["0_baseline_LSTM", "2_multitask_dense"]


def read_and_combine_dfs(model_ids, metric_type, partition, parent_dir):
f_name = "{p}{model}/exp_{metric_type}_metrics.csv"
dfs = []
for m in model_ids:
df = pd.read_csv(f_name.format(p=parent_dir, model=m, metric_type=metric_type), dtype={'site_id':str})
df['model'] = m
dfs.append(df)

df_comb = pd.concat(dfs)
def read_and_filter_df(metric_type, partition):
f_name = f"../../../2a_model/out/models/combined_{metric_type}_metrics.csv"
df_comb = pd.read_csv(f_name, dtype={"site_id": str})
df_comb = df_comb[df_comb['partition'] == partition]
df_comb = df_comb[df_comb['variable'].str.startswith('do')]
df_comb = df_comb[df_comb['rmse'].notna()]
return df_comb


df_comb_reach_new = read_and_combine_dfs(models, 'reach', 'val', "../../")
df_comb_reach_new['type'] = 'new inputs'

df_comb_reach_old = read_and_combine_dfs(models, "reach", 'val', "archive_221215/")
df_comb_reach_old['type'] = 'old inputs'

df_comb_reach = pd.concat([df_comb_reach_new, df_comb_reach_old])

# +
g = sns.catplot(x='site_id', y='rmse', row='model', col='variable', data=df_comb_reach, hue='type', kind='bar', legend=False, ci='sd')
g.set_xticklabels(rotation=90)
for ax in g.axes.flatten():
ax.grid()
ax.set_axisbelow(True)

plt.legend(bbox_to_anchor=(1.05, .55))
plt.tight_layout()
plt.savefig("figs/val_results_new_inputs.png")
df_comb_reach = read_and_filter_df("reach", "val")

# +
g = sns.catplot(x='site_id', y='rmse', col='variable', data=df_comb_reach_new, hue='model', kind='bar', legend=False, ci='sd')
g = sns.catplot(x='site_id', y='rmse', col='variable', data=df_comb_reach, hue='model_id', kind='bar', legend=False, ci='sd')
g.set_xticklabels(rotation=90)
for ax in g.axes.flatten():
ax.grid()
ax.set_axisbelow(True)

plt.legend(bbox_to_anchor=(1.05, .55))
plt.tight_layout()
plt.savefig("figs/val_results_new_inputs.png")
plt.savefig("../../out/pred_perf/val_results_by_site.png")
plt.clf()
# -

g=sns.catplot(x='site_id', y='rmse', hue='model', col='variable', col_wrap=3, data=df_comb_reach, dodge=True, legend=False)
g=sns.catplot(x='site_id', y='rmse', hue='model_id', col='variable', col_wrap=3, data=df_comb_reach, dodge=True, legend=False)
g.set_xticklabels(rotation=90)
for ax in g.axes.flatten():
ax.grid()
plt.legend(bbox_to_anchor=(1.05, .55))
plt.tight_layout()
plt.savefig('val_results_strip.png')

# +
g=sns.catplot(x='variable', y='rmse', hue='model', col='site_id', col_wrap=3, data=df_comb_reach, dodge=True, legend=False)
for site_id, ax in g.axes_dict.items():
ax.grid()
if site_id in validation_sites:
ax.text(1, 3.2, "**Validation Site**", ha='center')


plt.legend(bbox_to_anchor=(1.5, 1.15))
plt.legend(bbox_to_anchor=(1.05, .55))
plt.tight_layout()
plt.savefig('val_results_strip.png')
# -

df_comb = read_and_combine_dfs(models, 'overall', 'val', "archive_221215/")

df_comb_new = read_and_combine_dfs(models, 'overall', 'val', "./")
df_comb_new['type'] = 'new inputs'
plt.savefig('../../out/pred_perf/val_results_by_site_strip.png')
plt.clf()

df_comb_old = read_and_combine_dfs(models, "overall", 'val', "archive_221215/")
df_comb_old['type'] = 'old inputs'

df_comb = pd.concat([df_comb_new, df_comb_old])
# -

g = sns.catplot(x='variable', y='rmse', data=df_comb, hue='type', col='model', kind="bar")
for ax in g.axes.flatten():
ax.bar_label(ax.containers[0], label_type="center")
ax.bar_label(ax.containers[1], label_type="center")
plt.savefig("figs/val_results_overall_new_inputs.png")
df_comb = read_and_filter_df('overall', 'val')

# +
ax = sns.barplot(x='variable', y='rmse', data=df_comb_new, hue='model')
ax.bar_label(ax.containers[0], label_type="center")
ax.bar_label(ax.containers[1], label_type="center")
fig, ax = plt.subplots(figsize=(6,4))
ax = sns.barplot(x='variable', y='rmse', data=df_comb, hue='model_id', ax=ax)
ax.bar_label(ax.containers[0], label_type="center", fmt='%.2f')
ax.bar_label(ax.containers[1], label_type="center", fmt='%.2f')

# plt.tight_layout()
plt.savefig("figs/val_results_overall.png")
plt.savefig("../../out/pred_perf/val_results_overall.png")
# -


0 comments on commit bbf0c28

Please sign in to comment.