Skip to content

Commit

Permalink
[USGS-R#184] updates to python plot scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
jsadler2 committed Mar 8, 2023
1 parent 03fc3dd commit 5d38dad
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 43 deletions.
30 changes: 25 additions & 5 deletions 3_visualize/src/python_scripts/plot_func_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import plot_utils



def get_diff_df(df, it_metric):
Expand All @@ -37,28 +39,46 @@ def get_diff_df(df, it_metric):
return df_diff


df = pd.read_csv("../../../2a_model/out/models/combined_FP_metrics.csv")
df = pd.read_csv("../../../2a_model/out/models/combined_FP_metrics.csv",
dtype={"site": str})

fp_metric = 'TE1'
df_diff = get_diff_df(df,fp_metric)

df_diff = get_diff_df(df, 'TE1')

df_diff_long = df_diff.reset_index().melt(id_vars=df_diff.index.names)

df_diff_long = df_diff_long.replace({"0_baseline_LSTM": "baseline", "2_multitask_dense": "multitask dense"})
# df_diff_long = df_diff_long.replace({"0_baseline_LSTM": "baseline", "2_multitask_dense": "multitask dense"})


######## Barplot by site ######################################################
plt.rcParams.update({'font.size': 14})
g = sns.catplot(x='site', y='value', hue='model', col='sink', kind='bar', data=df_diff_long, legend=False, col_order=['do_min', 'do_mean', 'do_max'])
g = sns.catplot(x='site', y='value', hue='model', col='sink', kind='bar',
data=df_diff_long, legend=False,
col_order=['do_min', 'do_mean', 'do_max'],
hue_order=plot_utils.model_order
)
g.set_xticklabels(rotation=90)
g.set_ylabels('Deviation from \noptimal functional performance')
g.set_titles('{col_name}')
for site_id, ax in g.axes_dict.items():
plot_utils.mark_val_sites(ax)

plt.legend(loc="lower left", bbox_to_anchor=(1.05, 0), title='Model')
plt.tight_layout()
plt.savefig("../../out/func_perf/func_performance_site_tmmx.png")
plt.clf()

######## Barplot overall ######################################################
fig, ax = plt.subplots(figsize=(6,4))
ax = sns.barplot(x='sink', y='value', hue='model', data=df_diff_long, order=['do_min', 'do_mean', 'do_max'], ax=ax)

ax = sns.barplot(x='sink', y='value', hue='model', data=df_diff_long,
order=['do_min', 'do_mean', 'do_max'], ax=ax,
hue_order=plot_utils.model_order)

ax.set_ylabel('Deviation from \noptimal functional performance')
ax.set_xlabel('')
plt.legend(loc="lower right", title='Model')
plt.tight_layout()
plt.savefig("../../out/func_perf/func_performance_overall_tmmx.png")

Expand Down
43 changes: 23 additions & 20 deletions 3_visualize/src/python_scripts/plot_output_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.13.7
# jupytext_version: 1.14.4
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
Expand All @@ -17,41 +17,44 @@
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# %%
sys.path.insert(0, "../../2a_model/src/models/0_baseline_LSTM/")
sys.path.insert(0, "../../../2a_model/src/models/2_multitask_dense/")

# %%
from model import LSTMModel
from model import LSTMModel2Dense

# %%
m = LSTMModel(10, 3)
m = LSTMModel2Dense(10, 3)

# %%
m.load_weights("../../2a_model/out/models/0_baseline_LSTM/train_weights/")
rep_id = 2

# %%
data = np.load("../../2a_model/out/models/0_baseline_LSTM/prepped.npz", allow_pickle=True)
m.load_weights(f"../../../2a_model/out/models/2_multitask_dense/nstates_10/nep_100/rep_{rep_id}/train_weights/")

# %%
y = m(data['x_val'])
o = m(np.random.randn(4, 5, 30))

# %%
w = m.weights

# %%
ax = plt.imshow(w[3].numpy())
fig = plt.gcf()
cbar = fig.colorbar(ax)
cbar.set_label('weight value')
ax = plt.gca()
ax.set_yticks(list(range(10)))
ax.set_yticklabels(f"h{i}" for i in range(10))
ax.set_ylabel('hidden state')
ax.set_xticks(list(range(3)))
ax.set_xticklabels(["DO_max", "DO_mean", "DO_min"], rotation=90)
ax.set_xlabel('output variable')
plt.tight_layout()
plt.savefig('../out/hidden_states/out_weights.jpg', bbox_inches='tight')
w[5]

# %%
w[6]

# %%
ax = sns.heatmap(w[5].numpy(), annot=True)
ax.set_xticklabels(['DO_min', 'DO_mean', 'DO_max'])
ax.set_yticklabels(['GPP', 'ER', 'K600', 'depth', 'temp.water'])

# %%
np.expand_dims(w[6].numpy(), 1)

# %%
ax = sns.heatmap(np.expand_dims(w[6].numpy(), 1), annot=True)

# %%
37 changes: 20 additions & 17 deletions 3_visualize/src/python_scripts/plot_pred_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import plot_utils

validation_sites = ["01472104", "01473500", "01481500"]
test_sites = ["01475530", "01475548"]


def read_and_filter_df(metric_type, partition):
Expand All @@ -31,44 +30,48 @@ def read_and_filter_df(metric_type, partition):

df_comb_reach = read_and_filter_df("reach", "val")

# +
g = sns.catplot(x='site_id', y='rmse', col='variable', data=df_comb_reach, hue='model_id', kind='bar', legend=False, ci='sd')
######## Barplot by site ######################################################
g = sns.catplot(x='site_id', y='rmse', col='variable', data=df_comb_reach,
hue='model_id', kind='bar', legend=False, ci='sd',
hue_order=plot_utils.model_order)
g.set_xticklabels(rotation=90)
for ax in g.axes.flatten():
ax.grid()
ax.set_axisbelow(True)
plot_utils.mark_val_sites(ax)

plt.legend(bbox_to_anchor=(1.05, .55))
plt.tight_layout()
plt.savefig("../../out/pred_perf/val_results_by_site.png")
plt.clf()
# -

g=sns.catplot(x='site_id', y='rmse', hue='model_id', col='variable', col_wrap=3, data=df_comb_reach, dodge=True, legend=False)
######## Stripplot by site ####################################################
g=sns.catplot(x='site_id', y='rmse', hue='model_id', col='variable',
col_wrap=3, data=df_comb_reach, dodge=True, legend=False,
hue_order=plot_utils.model_order)
g.set_xticklabels(rotation=90)
for ax in g.axes.flatten():
ax.grid()

g.set_titles('{col_name}')
for site_id, ax in g.axes_dict.items():
ax.grid()
if site_id in validation_sites:
ax.text(1, 3.2, "**Validation Site**", ha='center')
plot_utils.mark_val_sites(ax)

plt.legend(bbox_to_anchor=(1.05, .55))
plt.tight_layout()
plt.savefig('../../out/pred_perf/val_results_by_site_strip.png')
plt.clf()

# -

######## Barplot overall ######################################################
df_comb = read_and_filter_df('overall', 'val')

# +
fig, ax = plt.subplots(figsize=(6,4))
ax = sns.barplot(x='variable', y='rmse', data=df_comb, hue='model_id', ax=ax)
ax.bar_label(ax.containers[0], label_type="center", fmt='%.2f')
ax.bar_label(ax.containers[1], label_type="center", fmt='%.2f')
ax = sns.barplot(x='variable', y='rmse', data=df_comb, hue='model_id', ax=ax,
hue_order=plot_utils.model_order)
for c in ax.containers:
ax.bar_label(c, label_type="center", fmt='%.2f')

ax.set_xlabel('')

plt.savefig("../../out/pred_perf/val_results_overall.png")
# -


143 changes: 143 additions & 0 deletions 3_visualize/src/python_scripts/plot_pred_performance_manuscript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.14.4
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---

# %%
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import plot_utils
import numpy as np
import seaborn.objects as so

# %%
df_comb_reach = plot_utils.read_and_filter_df("reach", "val")

# %%
df_comb_reach.model_id.unique()

# %%
models = ['0_baseline_LSTM', '1_metab_multitask', '2_multitask_dense']

# %%
df_comb_month = plot_utils.read_and_filter_df('month', 'val')

# %%
######## Barplot by site ######################################################
df_reach = df_comb_reach[df_comb_reach.model_id != '1a_multitask_do_gpp_er']
reach_groups = df_reach.groupby(['site_id', 'variable', 'model_id'])
reach_means = reach_groups.mean()
reach_stds = reach_groups.std()

######## Overall barplot ######################################################
model_groups = df_reach.groupby(['variable', 'model_id'])

model_means = model_groups.mean()
model_stds = model_groups.std()

######## Barplot by month ######################################################
month_groups = df_comb_month.groupby(['date', 'model_id'])

month_means = month_groups.mean()
month_stds = month_groups.std()


# %%
def plot_bars_overall(means, stds, ax):
variables = ['do_min', 'do_mean', 'do_max']
x = np.arange(len(variables)) # the label locations
width = 0.25 # the width of the bars
multiplier = 0

var_means = means.reset_index().pivot(index='variable', columns='model_id', values='rmse').loc[variables]
var_stds = stds.reset_index().pivot(index='variable', columns='model_id', values='rmse').loc[variables]

for col in var_means.columns:
data = var_means[col]
offset = width * multiplier
rects = ax.bar(x + offset, data, width, label=col, yerr=var_stds[col])
# ax.bar_label(rects, padding=3, fmt='%.2f', label_type='center')
multiplier += 1

ax.set_xticks(x + width, variables)
ax.grid()
ax.set_axisbelow(True)


def plot_bars_reach(means, stds, variable, ax):
sites = means.reset_index()['site_id'].unique()
x = np.arange(len(sites)) # the label locations
width = 0.25 # the width of the bars
multiplier = 0

var_means = means.query(f"variable == '{variable}'").reset_index().pivot(index='site_id', columns='model_id', values='rmse')
var_stds = stds.query(f"variable == '{variable}'").reset_index().pivot(index='site_id', columns='model_id', values='rmse')

for col in var_means.columns:
data = var_means[col]
offset = width * multiplier
rects = ax.bar(x + offset, data, width, label=col, yerr=var_stds[col])
multiplier += 1

ax.set_xticks(x + width, sites)
ax.grid()
ax.set_axisbelow(True)

def plot_bars_months(means, stds, variable, ax):
sites = means.reset_index()['date'].unique()
x = np.arange(len(sites)) # the label locations
width = 0.25 # the width of the bars
multiplier = 0

var_means = means.query(f"variable == '{variable}'").reset_index().pivot(index='date', columns='model_id', values='rmse')
var_stds = stds.query(f"variable == '{variable}'").reset_index().pivot(index='date', columns='model_id', values='rmse')

for col in var_means.columns:
data = var_means[col]
offset = width * multiplier
rects = ax.bar(x + offset, data, width, label=col, yerr=var_stds[col])
multiplier += 1

ax.set_xticks(x + width, sites)
ax.grid()
ax.set_axisbelow(True)


# %%
fig = plt.figure(constrained_layout=True, figsize=(12, 8))
subfigs = fig.subfigures(2, 1, wspace=0.07)
subfigsTop = subfigs[0].subfigures(1, 2, wspace=0.07, width_ratios=[1, 2])
axsTopRight = subfigsTop[1].subplots(1, 3, sharey=True)
axsTopLeft = subfigsTop[0].subplots()
subfigsTop[0].suptitle('A')
subfigsTop[1].suptitle('B')

plot_bars_overall(model_means, model_stds, axsTopLeft)
# axsTopLeft.bar([0, 1, 2, 3], [0, 1, 2, 3])

variables = ['do_min', 'do_mean', 'do_max']


for i, ax in enumerate(axsTopRight):
plot_bars_reach(reach_means, reach_stds, variables[i], ax)
ax.set_title(variables[i])

axsBottom = subfigs[1].subplots(1, 3)
for i, ax in enumerate(axsBottom):
plot_bars_months(month_means, month_stds, variables[i], ax)
ax.set_title(variables[i])



# %%
11 changes: 10 additions & 1 deletion 3_visualize/src/python_scripts/plot_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import matplotlib.pyplot as plt
import pandas as pd
validation_sites = ["01472104", "01473500", "01481500"]
test_sites = ["01475530", "01475548"]

model_order = ["0_baseline_LSTM", "1a_multitask_do_gpp_er",
"1_metab_multitask", "2_multitask_dense"]
"1_metab_multitask", "2_multitask_dense", "2a_lower_lambda_metab"]

def mark_val_sites(ax):
labels = [item.get_text() for item in ax.get_xticklabels()]
Expand All @@ -20,3 +21,11 @@ def mark_val_sites(ax):
fig.text(0.5, 0, "* validation site")

return ax

def read_and_filter_df(metric_type, partition):
f_name = f"../../../2a_model/out/models/combined_{metric_type}_metrics.csv"
df_comb = pd.read_csv(f_name, dtype={"site_id": str})
df_comb = df_comb[df_comb['partition'] == partition]
df_comb = df_comb[df_comb['variable'].str.startswith('do')]
df_comb = df_comb[df_comb['rmse'].notna()]
return df_comb

0 comments on commit 5d38dad

Please sign in to comment.