diff --git a/3_visualize/src/python_scripts/plot_utils.py b/3_visualize/src/python_scripts/plot_utils.py index 909786d0..c564ad92 100644 --- a/3_visualize/src/python_scripts/plot_utils.py +++ b/3_visualize/src/python_scripts/plot_utils.py @@ -22,10 +22,15 @@ def mark_val_sites(ax): return ax -def read_and_filter_df(metric_type, partition): +def read_and_filter_df(metric_type, partition, var="do"): f_name = f"../../../2a_model/out/models/combined_{metric_type}_metrics.csv" df_comb = pd.read_csv(f_name, dtype={"site_id": str}) - df_comb = df_comb[df_comb['partition'] == partition] - df_comb = df_comb[df_comb['variable'].str.startswith('do')] - df_comb = df_comb[df_comb['rmse'].notna()] - return df_comb \ No newline at end of file + df_comb = df_comb[df_comb["partition"] == partition] + if var == "do": + df_comb = df_comb[df_comb["variable"].str.startswith("do")] + else: + df_comb = df_comb[df_comb["variable"] == var] + df_comb = df_comb[df_comb["rmse"].notna()] + return df_comb + +