diff --git a/3_visualize/src/python_scripts/DO_variability.py b/3_visualize/src/python_scripts/archive/DO_variability.py similarity index 100% rename from 3_visualize/src/python_scripts/DO_variability.py rename to 3_visualize/src/python_scripts/archive/DO_variability.py diff --git a/3_visualize/src/python_scripts/catch_area_metab_presence.py b/3_visualize/src/python_scripts/catch_area_metab_presence.py index 23d10582..5a9dd8f8 100644 --- a/3_visualize/src/python_scripts/catch_area_metab_presence.py +++ b/3_visualize/src/python_scripts/catch_area_metab_presence.py @@ -1,38 +1,13 @@ - -# %% import pandas as pd -import matplotlib.pyplot as plt -from matplotlib.patches import Patch -import xarray as xr -import seaborn as sns - -# %% -obs_file = "../../../2a_model/out/well_obs_io.zarr" +from plot_utils import urban_sites, headwater_site, train_sites, obs_file, input_variables -# %% -urban_sites = ['01475530', '01475548'] -headwater_site = ['014721259'] -train_sites = ['01472104', '014721254', '01473500', '01480617', '01480870', '01481000', '01481500'] all_sites = urban_sites + headwater_site + train_sites -# %% -input_variables = ["SLOPE","TOTDASQKM","CAT_BASIN_SLOPE", - "TOT_BASIN_SLOPE","CAT_ELEV_MEAN","CAT_RDX","CAT_BFI","CAT_EWT", - "CAT_TWI","CAT_PPT7100_ANN","TOT_PPT7100_ANN","CAT_RUN7100", - "CAT_CNPY11_BUFF100","CAT_IMPV11","TOT_IMPV11","CAT_NLCD11_wetland", - "TOT_NLCD11_wetland","CAT_SANDAVE","CAT_PERMAVE","TOT_PERMAVE", - "CAT_RFACT","CAT_WTDEP","TOT_WTDEP","CAT_NPDES_MAJ","CAT_NDAMS2010", - "CAT_NORM_STORAGE2010"] - -# %% -ds = xr.open_zarr(obs_file) - -# %% -ds = ds.sel(site_id=all_sites) +df = pd.read_csv(obs_file) +df = df.set_index('site_id') -# %% print("drainage areas [sq km]") -drainage_areas = ds['TOTDASQKM'].mean(dim='date').to_dataframe().sort_values('TOTDASQKM') +drainage_areas = df['TOTDASQKM'].groupby("site_id").mean().sort_values() print(drainage_areas.round()) print("") @@ -41,5 +16,5 @@ print("") print("number of metab observations per site") -print(ds['GPP'].to_dataframe().groupby('site_id').count()) +print(df['GPP'].groupby('site_id').count()) diff --git a/3_visualize/src/python_scripts/differences_in_site_attrs.py b/3_visualize/src/python_scripts/differences_in_site_attrs.py index 49272125..ab7e03d4 100644 --- a/3_visualize/src/python_scripts/differences_in_site_attrs.py +++ b/3_visualize/src/python_scripts/differences_in_site_attrs.py @@ -1,50 +1,17 @@ -# --- -# jupyter: -# jupytext: -# formats: ipynb,py:percent -# text_representation: -# extension: .py -# format_name: percent -# format_version: '1.3' -# jupytext_version: 1.14.4 -# kernelspec: -# display_name: Python 3 (ipykernel) -# language: python -# name: python3 -# --- - -# %% import pandas as pd import matplotlib.pyplot as plt from matplotlib.patches import Patch import xarray as xr import seaborn as sns +from plot_utils import urban_sites, headwater_site, train_sites, input_variables, obs_file -# %% -obs_file = "../../../2a_model/out/well_obs_io.zarr" -# %% -urban_sites = ['01475530', '01475548'] -headwater_site = ['014721259'] -train_sites = ['01472104', '014721254', '01473500', '01480617', '01480870', '01481000', '01481500'] +df = pd.read_csv(obs_file, dtype={"site_id": str}, index_col=['site_id']) +print(df) -# %% -input_variables = ["SLOPE","TOTDASQKM","CAT_BASIN_SLOPE", - "TOT_BASIN_SLOPE","CAT_ELEV_MEAN","CAT_RDX","CAT_BFI","CAT_EWT", - "CAT_TWI","CAT_PPT7100_ANN","TOT_PPT7100_ANN","CAT_RUN7100", - "CAT_CNPY11_BUFF100","CAT_IMPV11","TOT_IMPV11","CAT_NLCD11_wetland", - "TOT_NLCD11_wetland","CAT_SANDAVE","CAT_PERMAVE","TOT_PERMAVE", - "CAT_RFACT","CAT_WTDEP","TOT_WTDEP","CAT_NPDES_MAJ","CAT_NDAMS2010", - "CAT_NORM_STORAGE2010"] +df = df[input_variables].groupby('site_id').mean() +print(df) -# %% -ds = xr.open_zarr(obs_file) - -# %% -df = ds[input_variables].mean(dim='date').to_dataframe() - - -# %% colors = [] for s in df.index: @@ -57,12 +24,9 @@ else: colors.append(sns.color_palette()[0]) -# %% df_long = df.melt(ignore_index=False).reset_index() -# %% -# %% sns.set(font_scale=1.7) g = sns.catplot(x='site_id', y='value', kind='bar', palette=colors, col="variable", col_wrap=7, data=df_long, sharey=False) g.set_xticklabels([]) @@ -75,25 +39,20 @@ g.savefig('../../out/catch_attr_distr_test_sites.png', dpi=300) -# %% variables_in_table = ['SLOPE', 'TOT_IMPV11', 'CAT_RDX'] -# %% df_long_train = df_long[df_long['site_id'].isin(train_sites)] print("Training Sites:") print(df_long_train.groupby('variable').mean().loc[variables_in_table]) -# %% df_long_urban = df_long[df_long['site_id'].isin(urban_sites)] print("Urban Sites:") print(df_long_urban.groupby('variable').mean().loc[variables_in_table]) -# %% df_long_hw = df_long[df_long['site_id'].isin(headwater_site)] print("Headwater Site:") print(df_long_hw.groupby('variable').mean().loc[variables_in_table]) -# %% diff --git a/3_visualize/src/python_scripts/plot_pred_performance.py b/3_visualize/src/python_scripts/plot_pred_performance.py index 199663ff..d8271229 100644 --- a/3_visualize/src/python_scripts/plot_pred_performance.py +++ b/3_visualize/src/python_scripts/plot_pred_performance.py @@ -1,17 +1,3 @@ -# --- -# jupyter: -# jupytext: -# formats: ipynb,py:light -# text_representation: -# extension: .py -# format_name: light -# format_version: '1.5' -# jupytext_version: 1.14.4 -# kernelspec: -# display_name: Python 3 (ipykernel) -# language: python -# name: python3 -# --- import pandas as pd import re @@ -19,27 +5,13 @@ import seaborn as sns import matplotlib.pyplot as plt import plot_utils -from plot_utils import read_and_filter_df, make_holdout_id_col, filter_out_urban_spatial, replacements, model_labels +from plot_utils import read_and_filter_df, make_holdout_id_col, filter_out_urban_spatial, replacements, model_labels, df_site_filt import numpy as np import seaborn.objects as so -# + outdir = f"../../out" -# - - -df_comb_reach = read_and_filter_df("reach", "val") -df_comb_reach = df_comb_reach.replace(replacements) - -test_sites_urban = ["01475530", "01475548"] - -# - - -df_comb_reach = make_holdout_id_col(df_comb_reach) - -df_reach_filt = filter_out_urban_spatial(df_comb_reach) - def format_plot(g): for i, ax in enumerate(g.axes.flatten()): @@ -66,13 +38,13 @@ def format_plot(g): ######## stripplot by site (temporal)########################################### -df_comb_reach_temporal = df_comb_reach[df_comb_reach["holdout_id"] == "temporal"] +df_comb_site_temporal = df_site_filt[df_site_filt["holdout_id"] == "temporal"] plt.rcParams.update({'font.size': 18}) g = sns.catplot( x="site_id", y="rmse", col="variable", - data=df_comb_reach_temporal, + data=df_comb_site_temporal, hue="model_id", kind="strip", dodge=True, @@ -92,7 +64,7 @@ def format_plot(g): x="holdout_id", y="rmse", col="variable", - data=df_reach_filt, + data=df_site_filt, hue="model_id", kind="bar", errorbar="sd", @@ -115,12 +87,11 @@ def format_plot(g): g.set_xlabels("Holdout Experiment") plt.savefig(os.path.join(outdir, "val_results_by_holdout.png"), bbox_inches='tight', dpi=300) -# - ######## prep for lineplot by month ########################################### month_order = [10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9] -df_comb_month = read_and_filter_df("month_reach", "val") +df_comb_month = read_and_filter_df("site_month", "val") df_comb_month = make_holdout_id_col(df_comb_month) df_comb_month = df_comb_month.replace(replacements) @@ -128,7 +99,6 @@ def format_plot(g): df_comb_month = df_comb_month[df_comb_month['holdout_id'] == 'temporal'] -# + ######## Lineplot by month ##################################################### g = sns.relplot( x="date", @@ -148,7 +118,5 @@ def format_plot(g): g = format_plot(g) g.set_xlabels("Month") -# plt.tight_layout() plt.savefig(os.path.join(outdir, "val_results_by_month_line.png"), bbox_inches='tight', dpi=300) -# - diff --git a/3_visualize/src/python_scripts/plot_utils.py b/3_visualize/src/python_scripts/plot_utils.py index 32aefbca..e0a345e8 100644 --- a/3_visualize/src/python_scripts/plot_utils.py +++ b/3_visualize/src/python_scripts/plot_utils.py @@ -2,46 +2,58 @@ import re import pandas as pd +# CHANGE THIS IF NECESSARY +metric_data_directory = '../../../../pgdl-do-data-release/out_data' + +obs_file = f"{metric_data_directory}/model_inputs_outputs.csv" + +input_variables = ["SLOPE","TOTDASQKM","CAT_BASIN_SLOPE", + "TOT_BASIN_SLOPE","CAT_ELEV_MEAN","CAT_RDX","CAT_BFI","CAT_EWT", + "CAT_TWI","CAT_PPT7100_ANN","TOT_PPT7100_ANN","CAT_RUN7100", + "CAT_CNPY11_BUFF100","CAT_IMPV11","TOT_IMPV11","CAT_NLCD11_wetland", + "TOT_NLCD11_wetland","CAT_SANDAVE","CAT_PERMAVE","TOT_PERMAVE", + "CAT_RFACT","CAT_WTDEP","TOT_WTDEP","CAT_NPDES_MAJ","CAT_NDAMS2010", + "CAT_NORM_STORAGE2010"] + + +sites_xwalk = { + "01480617" : "BC_53", + "01480870" : "BC_40", + "01481000" : "BC_24", + "01481500" : "BC_8", + "01472104" : "SR_72", + "01473500" : "SR_40", + "01475530" : "CC_12", + "01475548" : "CC_4", + "014721259" : "BAP", + "014721254" : "FC" + } + model_labels = [ "v0 - Baseline", "v1 - Metab Multitask", - # "1a_multitask_gpp_er", "v2 - Metab Dependent", ] replacements = { "0_baseline_LSTM": model_labels[0], - # "1a_multitask_do_gpp_er": "Metab Multitask - GPP, ER", "1_metab_multitask": model_labels[1], "2_multitask_dense": model_labels[2], } -validation_sites = ["01472104", "01473500", "01481500"] -test_sites = ["01475530", "01475548"] -test_sites_urban = ["01475530", "01475548"] +train_sites = ['01472104', '014721254', '01473500', '01480617', '01480870', '01481000', '01481500'] +urban_sites = ['01475530', '01475548'] +headwater_site = ['014721259'] + +all_sites = train_sites + urban_sites + headwater_site model_order = ["0_baseline_LSTM", "1a_multitask_do_gpp_er", "1_metab_multitask", "2_multitask_dense", "2a_lower_lambda_metab"] -def mark_val_sites(ax): - labels = [item.get_text() for item in ax.get_xticklabels()] - new_labels = [] - for l in labels: - if l in validation_sites: - new_labels.append("*" + l) - else: - new_labels.append(l) - - ax.set_xticklabels(new_labels) - - fig = plt.gcf() - fig.text(0.5, 0, "* validation site") - - return ax def read_and_filter_df(metric_type, partition, var="do"): - f_name = f"../../../2a_model/out/models/combined_{metric_type}_metrics.csv" + f_name = f"{metric_data_directory}/{metric_type}_metrics.csv" df_comb = pd.read_csv(f_name, dtype={"site_id": str}) df_comb = df_comb[df_comb["partition"] == partition] if var == "do": @@ -55,20 +67,21 @@ def read_and_filter_df(metric_type, partition, var="do"): def define_group(row): pattern = "^\d{7,8}$" holdout = str(row['holdout']) - if re.match(pattern, holdout) and holdout != "14721259": + # if it's a site_id but not the headwater site it's considered "spatially + # similar". Other options are "1_urban", "2_urban", and "temporal" + if re.match(pattern, holdout) and int(holdout) != int(headwater_site[0]): return "spatial similar" elif holdout == "temporal": - if row["site_id"] in test_sites_urban: + if row["site_id"] in urban_sites: return "temporal urban" else: return "temporal" elif holdout == "1_urban": return "spatial one-urban" - elif holdout == "2_urban" or holdout == "14721259": + elif holdout == "2_urban" or int(holdout) == int(headwater_site[0]): return "spatial dissimilar" -# + def filter_out_urban_spatial(df): """ Filter out the urban sites from the spatial non-urban holdouts. We don't @@ -105,3 +118,9 @@ def make_holdout_id_col(df): df['holdout_id'] = df.apply(define_group, axis=1) return df + +df_comb_site = read_and_filter_df("site", "val") +df_comb_site = df_comb_site.replace(replacements) +df_comb_site = make_holdout_id_col(df_comb_site) +df_site_filt = filter_out_urban_spatial(df_comb_site).query('model_id != "1a_multitask_do_gpp_er"') +df_site_filt = df_site_filt.replace(sites_xwalk) diff --git a/3_visualize/src/python_scripts/stds_across_sites.py b/3_visualize/src/python_scripts/stds_across_sites.py index 11ac7533..c1d3b6c5 100644 --- a/3_visualize/src/python_scripts/stds_across_sites.py +++ b/3_visualize/src/python_scripts/stds_across_sites.py @@ -1,16 +1,11 @@ -from plot_utils import read_and_filter_df, make_holdout_id_col, replacements, filter_out_urban_spatial +from plot_utils import df_site_filt -df_comb_reach = read_and_filter_df("reach", "val") -df_comb_reach = df_comb_reach.replace(replacements) -df_comb_reach = make_holdout_id_col(df_comb_reach) -df_reach_filt = filter_out_urban_spatial(df_comb_reach) - -df_reach_filt = df_reach_filt[df_reach_filt['holdout_id'] == 'temporal'] +df_site_filt = df_site_filt[df_site_filt['holdout_id'] == 'temporal'] # first get the standard deviations for each site/model/variable -df_reach_std = df_reach_filt.groupby(["model_id", "variable", "site_id"]).std()['rmse'] +df_site_std = df_site_filt.groupby(["model_id", "variable", "site_id"]).std()['rmse'] # take the mean across the sites and variables -mean_std = df_reach_std.groupby(["model_id"]).mean() +mean_std = df_site_std.groupby(["model_id"]).mean() print(mean_std) diff --git a/3_visualize/src/python_scripts/summary_stats.py b/3_visualize/src/python_scripts/summary_stats.py index a54c1ec7..3b84f055 100644 --- a/3_visualize/src/python_scripts/summary_stats.py +++ b/3_visualize/src/python_scripts/summary_stats.py @@ -1,18 +1,22 @@ -from plot_utils import read_and_filter_df, make_holdout_id_col, replacements, filter_out_urban_spatial +from plot_utils import df_site_filt -df_comb_reach = read_and_filter_df("reach", "val") -df_comb_reach = df_comb_reach.replace(replacements) -df_comb_reach = make_holdout_id_col(df_comb_reach) -df_reach_filt = filter_out_urban_spatial(df_comb_reach).query('model_id != "1a_multitask_do_gpp_er"') - - -summary = df_reach_filt.groupby(["model_id", "variable", "holdout_id"]).describe() -print(summary) +summary = df_site_filt.groupby(["model_id", "variable", "holdout_id"]).describe() +print('#'*40 + ' Values for Section 3.1 ' + 40*'#') print(summary['rmse'][['mean']].query('holdout_id == "temporal"')) -print(summary['rmse'][['mean']].query('holdout_id == "spatial similar"')) +print('#' * 100) +print('#' * 100) -summary_by_site = df_reach_filt.groupby(["model_id", "variable", "holdout_id", "site_id"]).describe() -print(summary_by_site) + +print("summary stats by site") +summary_by_site = df_site_filt.groupby(["model_id", "variable", "holdout_id", "site_id"]).describe() +print('#'*40 + ' Values for Section 3.1.1 ' + 40*'#') print(summary_by_site['rmse'][['mean']].query('holdout_id == "temporal"')) +print('#' * 100) +print('#' * 100) + +print('#'*40 + ' Values for Section 3.2.1 ' + 40*'#') +print(summary['rmse'][['mean']].query('holdout_id == "spatial similar"')) +print('#' * 100) +print('#' * 100)