Skip to content

Commit

Permalink
[USGS-R#184] initial commit FP bar plots
Browse files Browse the repository at this point in the history
  • Loading branch information
jsadler2 committed Sep 1, 2023
1 parent e3c73bc commit a6713ea
Showing 1 changed file with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions 3_visualize/src/python_scripts/plot_func_performance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# ---
# jupyter:
# jupytext:
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.14.4
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt


def get_diff_df(df, it_metric):
"""
Parameters
---
it_metric : str
which IT metric you want the difference for (e.g., 'TE1')
"""
df_piv = df.pivot(columns='model', index=['sink', 'replicate', 'site'], values=[it_metric])

df_piv.columns = df_piv.columns.get_level_values(1)

df_diff_vals = df_piv.values - df_piv[['observed']].values

df_diff = pd.DataFrame(df_diff_vals, index=df_piv.index, columns=df_piv.columns)

del df_diff['observed']

return df_diff


df = pd.read_csv("../../../2a_model/out/models/combined_FP_metrics.csv")

df_diff = get_diff_df(df, 'TE1')

df_diff_long = df_diff.reset_index().melt(id_vars=df_diff.index.names)

df_diff_long = df_diff_long.replace({"0_baseline_LSTM": "baseline", "2_multitask_dense": "multitask dense"})

plt.rcParams.update({'font.size': 14})
g = sns.catplot(x='site', y='value', hue='model', col='sink', kind='bar', data=df_diff_long, legend=False, col_order=['do_min', 'do_mean', 'do_max'])
g.set_xticklabels(rotation=90)
g.set_ylabels('Deviation from \noptimal functional performance')
g.set_titles('{col_name}')
plt.legend(loc="lower left", bbox_to_anchor=(1.05, 0), title='Model')
plt.tight_layout()
plt.savefig("../../out/func_perf/func_performance_site_tmmx.png")
plt.clf()

fig, ax = plt.subplots(figsize=(6,4))
ax = sns.barplot(x='sink', y='value', hue='model', data=df_diff_long, order=['do_min', 'do_mean', 'do_max'], ax=ax)
ax.set_ylabel('Deviation from \noptimal functional performance')
ax.set_xlabel('')
plt.tight_layout()
plt.savefig("../../out/func_perf/func_performance_overall_tmmx.png")


0 comments on commit a6713ea

Please sign in to comment.