Skip to content

Commit

Permalink
[USGS-R#184] centralizing util fxns in plot_utils,
Browse files Browse the repository at this point in the history
revising plotting pred performance
  • Loading branch information
jsadler2 committed Jul 31, 2023
1 parent d819ed8 commit 5670265
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 168 deletions.
51 changes: 23 additions & 28 deletions 3_visualize/src/python_scripts/plot_func_performance.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:light
# text_representation:
# extension: .py
# format_name: light
Expand All @@ -12,12 +13,22 @@
# name: python3
# ---

import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import plot_utils

# +
run_id = 215

outdir = f"../../out/func_perf/{run_id}"

if not os.path.exists(outdir):
os.makedirs(outdir)


# -

def get_diff_df(df, it_metric):
"""
Expand All @@ -26,7 +37,8 @@ def get_diff_df(df, it_metric):
it_metric : str
which IT metric you want the difference for (e.g., 'TE1')
"""
df_piv = df.pivot(columns='model', index=['sink', 'replicate', 'site'], values=[it_metric])
print(df.head())
df_piv = df.pivot(columns='model', index=['holdout', 'sink', 'replicate', 'site'], values=[it_metric])

df_piv.columns = df_piv.columns.get_level_values(1)

Expand All @@ -48,38 +60,21 @@ def get_diff_df(df, it_metric):

df_diff_long = df_diff.reset_index().melt(id_vars=df_diff.index.names)

# df_diff_long = df_diff_long.replace({"0_baseline_LSTM": "baseline", "2_multitask_dense": "multitask dense"})
df_diff_long = df_diff_long.rename(columns={'site': 'site_id'})

d = plot_utils.make_holdout_id_col(df_diff_long)

df_diff_long['holdout_id'].unique()

######## Barplot by site ######################################################
diff_temporal = df_diff_long[df_diff_long['holdout_id'] == 'temporal']
plt.rcParams.update({'font.size': 14})
g = sns.catplot(x='site', y='value', hue='model', col='sink', kind='bar',
data=df_diff_long, legend=False,
col_order=['do_min', 'do_mean', 'do_max'],
hue_order=plot_utils.model_order
g = sns.catplot(x='site_id', y='value', hue='model', col='sink', kind='bar',
data=diff_temporal, legend=False,
col_order=['do_min', 'do_max'],
hue_order=plot_utils.model_order,
col_wrap=1, aspect=3
)
g.set_xticklabels(rotation=90)
g.set_ylabels('Deviation from \noptimal functional performance')
g.set_titles('{col_name}')
for site_id, ax in g.axes_dict.items():
plot_utils.mark_val_sites(ax)

plt.legend(loc="lower left", bbox_to_anchor=(1.05, 0), title='Model')
plt.tight_layout()
plt.savefig("../../out/func_perf/func_performance_site_tmmx.png")
plt.clf()

######## Barplot overall ######################################################
fig, ax = plt.subplots(figsize=(6,4))

ax = sns.barplot(x='sink', y='value', hue='model', data=df_diff_long,
order=['do_min', 'do_mean', 'do_max'], ax=ax,
hue_order=plot_utils.model_order)

ax.set_ylabel('Deviation from \noptimal functional performance')
ax.set_xlabel('')
plt.legend(loc="lower right", title='Model')
plt.tight_layout()
plt.savefig("../../out/func_perf/func_performance_overall_tmmx.png")


227 changes: 87 additions & 140 deletions 3_visualize/src/python_scripts/plot_pred_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
# ---

import pandas as pd
import re
import os
import seaborn as sns
import matplotlib.pyplot as plt
import plot_utils
from plot_utils import read_and_filter_df
from plot_utils import read_and_filter_df, make_holdout_id_col, filter_out_urban_spatial, replacements, model_labels
import numpy as np
import seaborn.objects as so

Expand All @@ -32,202 +33,148 @@
os.makedirs(outdir)
# -


models = [
"0_baseline_LSTM",
"1_metab_multitask",
"1a_multitask_gpp_er",
"2_multitask_dense",
]

df_comb_reach = read_and_filter_df("reach", "val")
df_comb_reach = df_comb_reach.replace(
{
"0_baseline_LSTM": "Baseline",
"1a_multitask_do_gpp_er": "Metab Multitask - GPP, ER",
"1_metab_multitask": "Metab Multitask",
"2_multitask_dense": "Metab Dependent",
}
)

df_comb_reach.holdout.unique()

df_comb_reach = df_comb_reach.replace(replacements)

test_sites_urban = ["01475530", "01475548"]

# -

def define_group(row):
if row["holdout"] != "temporal" and row["holdout"] != "1_urban":
return "spatial non-urban"
elif row["holdout"] == "temporal":
if row['site_id'] in test_sites_urban:
return "temporal urban"
else:
return "temporal non-urban"
elif row["holdout"] == '1_urban':
return "spatial one-urban"


df_comb_reach["holdout_id"] = df_comb_reach.apply(define_group, axis=1)
df_comb_reach = make_holdout_id_col(df_comb_reach)

df_comb_reach.holdout_id.unique()
df_reach_filt = filter_out_urban_spatial(df_comb_reach)


def plot_by_site_or_holdout(data, x, kind, outfile):
def plot_by_site_or_holdout(data, x, kind, outfile,
col_order=['do_min', 'do_max'],
order=None):
plt.rcParams.update({'font.size': 14})
g = sns.catplot(
x=x,
y="rmse",
col="variable",
data=data,
hue="model_id",
kind=kind,
legend=False,
errorbar="sd",
col_order=["do_min", "do_max"],
dodge=True
col_order=col_order,
dodge=True,
hue_order=model_labels,
order=order,
)
g.set_xticklabels(rotation=90)
g.set_xticklabels(rotation=45)
g.set_titles('{col_name}')
# g.set_ylabels(

for i, ax in enumerate(g.axes.flatten()):
ax.grid()
ax.set_axisbelow(True)

g.axes.flatten()[0].set_ylabel("RMSE (mg O2/l)")
ax.legend(bbox_to_anchor=(1.05, 0.55))
plt.tight_layout()
plt.show()
plt.savefig(os.path.join(outdir, outfile), dpi=300)
plt.clf()
sns.move_legend(g, loc='lower left', bbox_to_anchor=(0.8, 0.1))
# plt.tight_layout()
plt.savefig(os.path.join(outdir, outfile), bbox_inches='tight', dpi=300)
return g
# plt.show()
# plt.clf()


# +
# g.set_xlabels?
# -

######## Barplot by site (temporal)############################################
df_comb_reach_temporal = df_comb_reach[df_comb_reach['holdout_id'] == 'temporal non-urban']
plot_by_site_or_holdout(df_comb_reach_temporal, "site_id", "bar", "val_results_by_site.jpg")
df_comb_reach_temporal = df_comb_reach[df_comb_reach["holdout_id"] == "temporal"]
plot_by_site_or_holdout(
df_comb_reach_temporal, "site_id", "bar", "val_results_by_site.jpg"
)

df_temp_spatial = df_comb_reach[
df_comb_reach["site_id"].isin(df_comb_reach_temporal.site_id.unique())
]

sns.catplot(
x="holdout_id",
y="rmse",
row="site_id",
col="variable",
hue="model_id",
data=df_temp_spatial,
kind="bar",
col_order=["do_min", "do_max"],
)
plt.savefig(os.path.join(outdir, "temporal_vs_spatial_holdouts.jpg"), dpi=300)

######## stripplot by site (temporal)############################################
df_comb_reach_temporal = df_comb_reach[df_comb_reach['holdout_id'] == 'temporal non-urban']
plot_by_site_or_holdout(df_comb_reach_temporal, "site_id", "strip", "val_results_by_site_strip.jpg")
######## stripplot by site (temporal)###########################################
df_comb_reach_temporal = df_comb_reach[df_comb_reach["holdout_id"] == "temporal"]
plot_by_site_or_holdout(
df_comb_reach_temporal, "site_id", "strip", "val_results_by_site_strip.jpg"
)


######## stripplot by site (temporal)############################################
df_comb_reach_spatial = df_comb_reach[df_comb_reach['holdout_id'] == 'spatial non-urban']
plot_by_site_or_holdout(df_comb_reach_spatial, "site_id", "strip", "val_results_by_site_strip.jpg")
######## stripplot by site (spatial)############################################
df_comb_reach_spatial = df_comb_reach[df_comb_reach["holdout_id"] == "spatial similar"]
plot_by_site_or_holdout(
df_comb_reach_spatial, "site_id", "strip", "val_results_by_site_strip_spatial.jpg"
)


######## Barplot by holdout ######################################################
plot_by_site_or_holdout(df_comb_reach, "holdout_id", "bar", "val_results_by_holdout.jpg")
g = plot_by_site_or_holdout(
df_reach_filt,
"holdout_id",
"bar",
"val_results_by_holdout.jpg",
order=["temporal", 'spatial similar', 'spatial dissimilar'],
col_order=['do_min', 'do_max']
)
for ax in g.axes.flatten():
ax.set_ylabel('')


# +
######## Stripplot by holdout ####################################################
plot_by_site_or_holdout(df_comb_reach, "holdout_id", "strip", "val_results_by_holdout_strip.jpg")


######## Barplot overall ######################################################
df_comb = read_and_filter_df("overall", "val")

# +
fig, ax = plt.subplots(figsize=(17, 6))
ax = sns.barplot(
x="variable", y="rmse", data=df_comb, hue="model_id", ax=ax
) # , hue_order=['0_baseline_LSTM', '1a_multitask_do_gpp_er', '1_metab_multitask', '2_multitask_dense'])
for c in ax.containers:
ax.bar_label(c, label_type="center", fmt="%.2f")

ax.set_xlabel("")
plt.legend(loc="lower left", bbox_to_anchor=(1.05, 0), title="Model")
plt.tight_layout()
plt.savefig(os.path.join(outdir, "val_results_overall.png"), dpi=300)


######## Barplot calculating site metrics then averaging #########################################
fig, ax = plt.subplots(figsize=(15, 4))
ax = sns.barplot(
x="variable",
y="rmse",
hue="model_id",
data=df_comb_reach,
order=["do_min", "do_max"],
plot_by_site_or_holdout(
df_reach_filt, "holdout_id", "strip", "val_results_by_holdout_strip.jpg"
)

for c in ax.containers:
ax.bar_label(c, label_type="center", fmt="%.2f")

ax.set_ylabel("RMSE (mg O2/l)")

plt.legend(loc="lower left", bbox_to_anchor=(1.05, 0), title="Model")
plt.tight_layout()
plt.savefig(os.path.join(outdir, "val_results_overall_avg_across_sites.jpg"), dpi=300)

# -

# +
######## Barplot calculating site metrics then median-ing #########################################
fig, ax = plt.subplots(figsize=(17, 3))
ax = sns.barplot(
x="variable",
y="rmse",
hue="model_id",
data=df_comb_reach,
hue_order=None,
estimator=np.median,
)
month_order = [10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9]

for c in ax.containers:
ax.bar_label(c, label_type="center", fmt="%.2f")
df_comb_month = read_and_filter_df("month_reach", "val")

plt.legend(loc="lower left", bbox_to_anchor=(1.05, 0), title="Model")
plt.savefig(os.path.join(outdir, "val_results_overall_median_across_sites.jpg"), dpi=300)
# -
df_comb_month = make_holdout_id_col(df_comb_month)
df_comb_month = df_comb_month.replace(replacements)

df_comb_month = read_and_filter_df("month", "val")
df_comb_month = df_comb_month[df_comb_month['holdout_id'] == 'temporal']

month_order = [9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8]

# +
######## Barplot by month #########################################
g = sns.catplot(
######## Lineplot by month #########################################
g = sns.relplot(
x="date",
y="rmse",
col="variable",
data=df_comb_month,
hue="model_id",
kind="strip",
legend=False,
ci="sd",
dodge=True,
# hue_order=models,
kind="line",
legend=True,
# ci="sd",
# dodge=True,
# order=month_order,
hue_order=model_labels,
col_order=['do_min', 'do_max']
)

for ax in g.axes.flatten():
ax.grid()
ax.set_axisbelow(True)
ax.set_xticks(list(range(1,13)))

sns.move_legend(g, loc='lower left', bbox_to_anchor=(0.83, 0.1))

g.set_xlabels("month")
plt.legend(bbox_to_anchor=(1.05, 0.55))
plt.tight_layout()
plt.savefig(os.path.join(outdir, "val_results_by_month_strip.jpg"), dpi=300)
plt.clf()
# plt.tight_layout()
plt.savefig(os.path.join(outdir, "val_results_by_month_line.jpg"), dpi=300)
# -

df_comb_month.sort_values("rmse").iloc[-1]

df_comb_month[df_comb_month["date"] == 12].sort_values("rmse").iloc[-1]

df_2 = df_comb_month[
(df_comb_month["rep_id"] == 2) & (df_comb_month["model_id"].str.startswith("2"))
]
df_2_pg = df_comb_month[df_comb_month["model_id"].str.startswith("2")]

sns.barplot(x="date", y="rmse", hue="variable", data=df_2)

g = sns.catplot(
x="date", y="rmse", hue="rep_id", col="variable", data=df_2_pg, col_wrap=1, aspect=2
)
for ax in g.axes.flatten():
ax.grid()
ax.set_axisbelow(True)
g.set_xlabels("month")
plt.tight_layout()
plt.savefig(os.path.join(outdir, "2_monthly_performance.jpg"), dpi=300)
Loading

0 comments on commit 5670265

Please sign in to comment.