Skip to content

Commit

Permalink
[USGS-R#184] referencing data release results
Browse files Browse the repository at this point in the history
  • Loading branch information
jsadler2 committed Sep 7, 2023
1 parent db0a7a6 commit cda71ff
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 159 deletions.
35 changes: 5 additions & 30 deletions 3_visualize/src/python_scripts/catch_area_metab_presence.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,13 @@

# %%
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import xarray as xr
import seaborn as sns

# %%
obs_file = "../../../2a_model/out/well_obs_io.zarr"
from plot_utils import urban_sites, headwater_site, train_sites, obs_file, input_variables

# %%
urban_sites = ['01475530', '01475548']
headwater_site = ['014721259']
train_sites = ['01472104', '014721254', '01473500', '01480617', '01480870', '01481000', '01481500']
all_sites = urban_sites + headwater_site + train_sites

# %%
input_variables = ["SLOPE","TOTDASQKM","CAT_BASIN_SLOPE",
"TOT_BASIN_SLOPE","CAT_ELEV_MEAN","CAT_RDX","CAT_BFI","CAT_EWT",
"CAT_TWI","CAT_PPT7100_ANN","TOT_PPT7100_ANN","CAT_RUN7100",
"CAT_CNPY11_BUFF100","CAT_IMPV11","TOT_IMPV11","CAT_NLCD11_wetland",
"TOT_NLCD11_wetland","CAT_SANDAVE","CAT_PERMAVE","TOT_PERMAVE",
"CAT_RFACT","CAT_WTDEP","TOT_WTDEP","CAT_NPDES_MAJ","CAT_NDAMS2010",
"CAT_NORM_STORAGE2010"]

# %%
ds = xr.open_zarr(obs_file)

# %%
ds = ds.sel(site_id=all_sites)
df = pd.read_csv(obs_file)
df = df.set_index('site_id')

# %%
print("drainage areas [sq km]")
drainage_areas = ds['TOTDASQKM'].mean(dim='date').to_dataframe().sort_values('TOTDASQKM')
drainage_areas = df['TOTDASQKM'].groupby("site_id").mean().sort_values()
print(drainage_areas.round())

print("")
Expand All @@ -41,5 +16,5 @@

print("")
print("number of metab observations per site")
print(ds['GPP'].to_dataframe().groupby('site_id').count())
print(df['GPP'].groupby('site_id').count())

51 changes: 5 additions & 46 deletions 3_visualize/src/python_scripts/differences_in_site_attrs.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,17 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.14.4
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---

# %%
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import xarray as xr
import seaborn as sns
from plot_utils import urban_sites, headwater_site, train_sites, input_variables, obs_file

# %%
obs_file = "../../../2a_model/out/well_obs_io.zarr"

# %%
urban_sites = ['01475530', '01475548']
headwater_site = ['014721259']
train_sites = ['01472104', '014721254', '01473500', '01480617', '01480870', '01481000', '01481500']
df = pd.read_csv(obs_file, dtype={"site_id": str}, index_col=['site_id'])
print(df)

# %%
input_variables = ["SLOPE","TOTDASQKM","CAT_BASIN_SLOPE",
"TOT_BASIN_SLOPE","CAT_ELEV_MEAN","CAT_RDX","CAT_BFI","CAT_EWT",
"CAT_TWI","CAT_PPT7100_ANN","TOT_PPT7100_ANN","CAT_RUN7100",
"CAT_CNPY11_BUFF100","CAT_IMPV11","TOT_IMPV11","CAT_NLCD11_wetland",
"TOT_NLCD11_wetland","CAT_SANDAVE","CAT_PERMAVE","TOT_PERMAVE",
"CAT_RFACT","CAT_WTDEP","TOT_WTDEP","CAT_NPDES_MAJ","CAT_NDAMS2010",
"CAT_NORM_STORAGE2010"]
df = df[input_variables].groupby('site_id').mean()
print(df)

# %%
ds = xr.open_zarr(obs_file)

# %%
df = ds[input_variables].mean(dim='date').to_dataframe()


# %%
colors = []

for s in df.index:
Expand All @@ -57,12 +24,9 @@
else:
colors.append(sns.color_palette()[0])

# %%
df_long = df.melt(ignore_index=False).reset_index()

# %%

# %%
sns.set(font_scale=1.7)
g = sns.catplot(x='site_id', y='value', kind='bar', palette=colors, col="variable", col_wrap=7, data=df_long, sharey=False)
g.set_xticklabels([])
Expand All @@ -75,25 +39,20 @@
g.savefig('../../out/catch_attr_distr_test_sites.png', dpi=300)


# %%
variables_in_table = ['SLOPE', 'TOT_IMPV11', 'CAT_RDX']

# %%
df_long_train = df_long[df_long['site_id'].isin(train_sites)]

print("Training Sites:")
print(df_long_train.groupby('variable').mean().loc[variables_in_table])

# %%
df_long_urban = df_long[df_long['site_id'].isin(urban_sites)]

print("Urban Sites:")
print(df_long_urban.groupby('variable').mean().loc[variables_in_table])

# %%
df_long_hw = df_long[df_long['site_id'].isin(headwater_site)]

print("Headwater Site:")
print(df_long_hw.groupby('variable').mean().loc[variables_in_table])

# %%
42 changes: 5 additions & 37 deletions 3_visualize/src/python_scripts/plot_pred_performance.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,17 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:light
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.14.4
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---

import pandas as pd
import re
import os
import seaborn as sns
import matplotlib.pyplot as plt
import plot_utils
from plot_utils import read_and_filter_df, make_holdout_id_col, filter_out_urban_spatial, replacements, model_labels
from plot_utils import read_and_filter_df, make_holdout_id_col, filter_out_urban_spatial, replacements, model_labels, df_site_filt
import numpy as np
import seaborn.objects as so


# +
outdir = f"../../out"

# -

df_comb_reach = read_and_filter_df("reach", "val")
df_comb_reach = df_comb_reach.replace(replacements)

test_sites_urban = ["01475530", "01475548"]

# -

df_comb_reach = make_holdout_id_col(df_comb_reach)

df_reach_filt = filter_out_urban_spatial(df_comb_reach)


def format_plot(g):
for i, ax in enumerate(g.axes.flatten()):
Expand All @@ -66,13 +38,13 @@ def format_plot(g):


######## stripplot by site (temporal)###########################################
df_comb_reach_temporal = df_comb_reach[df_comb_reach["holdout_id"] == "temporal"]
df_comb_site_temporal = df_site_filt[df_site_filt["holdout_id"] == "temporal"]
plt.rcParams.update({'font.size': 18})
g = sns.catplot(
x="site_id",
y="rmse",
col="variable",
data=df_comb_reach_temporal,
data=df_comb_site_temporal,
hue="model_id",
kind="strip",
dodge=True,
Expand All @@ -92,7 +64,7 @@ def format_plot(g):
x="holdout_id",
y="rmse",
col="variable",
data=df_reach_filt,
data=df_site_filt,
hue="model_id",
kind="bar",
errorbar="sd",
Expand All @@ -115,20 +87,18 @@ def format_plot(g):
g.set_xlabels("Holdout Experiment")
plt.savefig(os.path.join(outdir, "val_results_by_holdout.png"), bbox_inches='tight', dpi=300)

# -

######## prep for lineplot by month ###########################################
month_order = [10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9]

df_comb_month = read_and_filter_df("month_reach", "val")
df_comb_month = read_and_filter_df("site_month", "val")

df_comb_month = make_holdout_id_col(df_comb_month)
df_comb_month = df_comb_month.replace(replacements)

df_comb_month = df_comb_month[df_comb_month['holdout_id'] == 'temporal']


# +
######## Lineplot by month #####################################################
g = sns.relplot(
x="date",
Expand All @@ -148,7 +118,5 @@ def format_plot(g):
g = format_plot(g)

g.set_xlabels("Month")
# plt.tight_layout()
plt.savefig(os.path.join(outdir, "val_results_by_month_line.png"), bbox_inches='tight', dpi=300)
# -

69 changes: 44 additions & 25 deletions 3_visualize/src/python_scripts/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,58 @@
import re
import pandas as pd

# CHANGE THIS IF NECESSARY
metric_data_directory = '../../../../pgdl-do-data-release/out_data'

obs_file = f"{metric_data_directory}/model_inputs_outputs.csv"

input_variables = ["SLOPE","TOTDASQKM","CAT_BASIN_SLOPE",
"TOT_BASIN_SLOPE","CAT_ELEV_MEAN","CAT_RDX","CAT_BFI","CAT_EWT",
"CAT_TWI","CAT_PPT7100_ANN","TOT_PPT7100_ANN","CAT_RUN7100",
"CAT_CNPY11_BUFF100","CAT_IMPV11","TOT_IMPV11","CAT_NLCD11_wetland",
"TOT_NLCD11_wetland","CAT_SANDAVE","CAT_PERMAVE","TOT_PERMAVE",
"CAT_RFACT","CAT_WTDEP","TOT_WTDEP","CAT_NPDES_MAJ","CAT_NDAMS2010",
"CAT_NORM_STORAGE2010"]


sites_xwalk = {
"01480617" : "BC_53",
"01480870" : "BC_40",
"01481000" : "BC_24",
"01481500" : "BC_8",
"01472104" : "SR_72",
"01473500" : "SR_40",
"01475530" : "CC_12",
"01475548" : "CC_4",
"014721259" : "BAP",
"014721254" : "FC"
}


model_labels = [
"v0 - Baseline",
"v1 - Metab Multitask",
# "1a_multitask_gpp_er",
"v2 - Metab Dependent",
]

replacements = {
"0_baseline_LSTM": model_labels[0],
# "1a_multitask_do_gpp_er": "Metab Multitask - GPP, ER",
"1_metab_multitask": model_labels[1],
"2_multitask_dense": model_labels[2],
}

validation_sites = ["01472104", "01473500", "01481500"]
test_sites = ["01475530", "01475548"]
test_sites_urban = ["01475530", "01475548"]
train_sites = ['01472104', '014721254', '01473500', '01480617', '01480870', '01481000', '01481500']
urban_sites = ['01475530', '01475548']
headwater_site = ['014721259']

all_sites = train_sites + urban_sites + headwater_site

model_order = ["0_baseline_LSTM", "1a_multitask_do_gpp_er",
"1_metab_multitask", "2_multitask_dense", "2a_lower_lambda_metab"]

def mark_val_sites(ax):
labels = [item.get_text() for item in ax.get_xticklabels()]
new_labels = []
for l in labels:
if l in validation_sites:
new_labels.append("*" + l)
else:
new_labels.append(l)

ax.set_xticklabels(new_labels)

fig = plt.gcf()
fig.text(0.5, 0, "* validation site")

return ax

def read_and_filter_df(metric_type, partition, var="do"):
f_name = f"../../../2a_model/out/models/combined_{metric_type}_metrics.csv"
f_name = f"{metric_data_directory}/{metric_type}_metrics.csv"
df_comb = pd.read_csv(f_name, dtype={"site_id": str})
df_comb = df_comb[df_comb["partition"] == partition]
if var == "do":
Expand All @@ -55,20 +67,21 @@ def read_and_filter_df(metric_type, partition, var="do"):
def define_group(row):
pattern = "^\d{7,8}$"
holdout = str(row['holdout'])
if re.match(pattern, holdout) and holdout != "14721259":
# if it's a site_id but not the headwater site it's considered "spatially
# similar". Other options are "1_urban", "2_urban", and "temporal"
if re.match(pattern, holdout) and int(holdout) != int(headwater_site[0]):
return "spatial similar"
elif holdout == "temporal":
if row["site_id"] in test_sites_urban:
if row["site_id"] in urban_sites:
return "temporal urban"
else:
return "temporal"
elif holdout == "1_urban":
return "spatial one-urban"
elif holdout == "2_urban" or holdout == "14721259":
elif holdout == "2_urban" or int(holdout) == int(headwater_site[0]):
return "spatial dissimilar"


# +
def filter_out_urban_spatial(df):
"""
Filter out the urban sites from the spatial non-urban holdouts. We don't
Expand Down Expand Up @@ -105,3 +118,9 @@ def make_holdout_id_col(df):
df['holdout_id'] = df.apply(define_group, axis=1)
return df


df_comb_site = read_and_filter_df("site", "val")
df_comb_site = df_comb_site.replace(replacements)
df_comb_site = make_holdout_id_col(df_comb_site)
df_site_filt = filter_out_urban_spatial(df_comb_site).query('model_id != "1a_multitask_do_gpp_er"')
df_site_filt = df_site_filt.replace(sites_xwalk)
13 changes: 4 additions & 9 deletions 3_visualize/src/python_scripts/stds_across_sites.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
from plot_utils import read_and_filter_df, make_holdout_id_col, replacements, filter_out_urban_spatial
from plot_utils import df_site_filt

df_comb_reach = read_and_filter_df("reach", "val")
df_comb_reach = df_comb_reach.replace(replacements)
df_comb_reach = make_holdout_id_col(df_comb_reach)
df_reach_filt = filter_out_urban_spatial(df_comb_reach)

df_reach_filt = df_reach_filt[df_reach_filt['holdout_id'] == 'temporal']
df_site_filt = df_site_filt[df_site_filt['holdout_id'] == 'temporal']

# first get the standard deviations for each site/model/variable
df_reach_std = df_reach_filt.groupby(["model_id", "variable", "site_id"]).std()['rmse']
df_site_std = df_site_filt.groupby(["model_id", "variable", "site_id"]).std()['rmse']

# take the mean across the sites and variables
mean_std = df_reach_std.groupby(["model_id"]).mean()
mean_std = df_site_std.groupby(["model_id"]).mean()

print(mean_std)
Loading

0 comments on commit cda71ff

Please sign in to comment.