
Merge pull request USGS-R#182 from jsadler2/179-integrate-functional-performance

integrate functional performance
jsadler2 committed Dec 29, 2022
2 parents 7dcf000 + b02cec1 commit 080508c
Showing 8 changed files with 252 additions and 311 deletions.
14 changes: 7 additions & 7 deletions 2a_model.R
@@ -191,12 +191,12 @@ p2a_targets_list <- list(
config_path = stringr::str_remove(p2a_config_baseline_LSTM_yml, "2a_model/src/models/")),
#the 1_ models use the same model and therefore the same Snakefile
#as the 0_baseline_LSTM run
#list(model_id = "1_metab_multitask",
#snakefile_dir = "0_baseline_LSTM",
#config_path = stringr::str_remove(p2a_config_metab_multitask_yml, "2a_model/src/models/")),
#list(model_id = "1a_multitask_do_gpp_er",
#snakefile_dir = "0_baseline_LSTM",
#config_path = stringr::str_remove(p2a_config_1a_metab_multitask_yml, "2a_model/src/models/")),
list(model_id = "1_metab_multitask",
snakefile_dir = "0_baseline_LSTM",
config_path = stringr::str_remove(p2a_config_metab_multitask_yml, "2a_model/src/models/")),
list(model_id = "1a_multitask_do_gpp_er",
snakefile_dir = "0_baseline_LSTM",
config_path = stringr::str_remove(p2a_config_1a_metab_multitask_yml, "2a_model/src/models/")),
list(model_id = "2_multitask_dense",
snakefile_dir = "2_multitask_dense",
config_path = stringr::str_remove(p2a_config_multitask_dense_yml, "2a_model/src/models/"))
@@ -230,7 +230,7 @@ p2a_targets_list <- list(
system(stringr::str_glue("snakemake -s {snakefile_path} --configfile {config_path} -j --touch --rerun-incomplete"))

# then run the snakemake pipeline to produce the predictions and metric files
system(stringr::str_glue("snakemake -s {snakefile_path} --configfile {config_path} -j --rerun-incomplete "))
system(stringr::str_glue("snakemake -s {snakefile_path} --configfile {config_path} -j --rerun-incomplete --rerun-trigger mtime"))

# print out the metrics file name for the target
c(
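The two-step invocation above — a --touch pass followed by the real run — first marks existing outputs as up to date so only genuinely missing or stale files get rebuilt, and --rerun-trigger mtime restricts snakemake's rerun decisions to file modification times. A minimal Python sketch of the same pattern (hypothetical helper, not part of this commit):

import subprocess

def run_snakemake(snakefile_path, config_path):
    # first pass: mark existing outputs as up to date without rebuilding them
    subprocess.run(
        ["snakemake", "-s", snakefile_path, "--configfile", config_path,
         "-j", "--touch", "--rerun-incomplete"],
        check=True)
    # second pass: the real run; rerun decisions consider file mtimes only
    subprocess.run(
        ["snakemake", "-s", snakefile_path, "--configfile", config_path,
         "-j", "--rerun-incomplete", "--rerun-trigger", "mtime"],
        check=True)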
182 changes: 182 additions & 0 deletions 2a_model/src/do_it_functions.py
@@ -0,0 +1,182 @@
# -*- coding: utf-8 -*-
"""
Created on Fri May 27 10:00:43 2022
@author: ggorski
"""
import pandas as pd
import it_functions
import numpy as np
import math
import xarray as xr

def calc_it_metrics_site(inputs_zarr,
predictions_file,
source,
sink,
site,
log_transform,
model,
replicate,
outfile=None):
'''
Calculate the transfer entropy (TE) and Mutual Information (MI) between
one input (source) and one output (sink) at one site and one replicate
Parameters
----------
inputs_zarr : str
path to io zarr file
predictions_file : str
path to preds.feather file
source : str
source for calculations (e.g., srad, tmmx, tmmn)
sink : str
sink for calculations (e.g., 'do_min', 'do_mean', 'do_max')
site : str
site number
    log_transform : bool
        whether to log10-transform the source variable; only discharge should be log10 transformed
    model : str
        the model for which the calculations are done (e.g., '0_baseline_LSTM', 'observed')
    replicate : int
        the replicate for which the calculations are done
    outfile : str, optional
        filepath to store the output (if desired)
    Returns
    -------
    dict
        information theory metrics (TE and MI), RMSE, and run metadata as a flat dictionary
    '''
inputs = xr.open_zarr(inputs_zarr, consolidated=False)
inputs_df = inputs.to_dataframe()

# TODO: it'd be nice to read this in dynamically at some point
inputs_site = inputs_df.loc[site][['CAT_BASIN_SLOPE', 'CAT_CNPY11_BUFF100',
'CAT_ELEV_MEAN', 'CAT_IMPV11', 'CAT_TWI', 'SLOPE', 'day.length', 'depth',
'discharge', 'light_ratio',
'model_confidence', 'pr', 'resolution', 'rmax', 'rmin', 'shortwave',
'site_min_confidence', 'site_name', 'sph', 'srad', 'temp.water', 'tmmn',
'tmmx', 'velocity', 'vs']]
targets_site = inputs_df.loc[site][['do_min','do_mean','do_max']]

if sink == 'do_range':
targets_site['do_range'] = targets_site['do_max']-targets_site['do_min']

tar_dict = {}

model_preds = pd.read_feather(predictions_file)
model_preds = model_preds[model_preds['site_id'] == site].set_index('date')[['do_min','do_mean','do_max']]
model_preds['do_range'] = model_preds['do_max']-model_preds['do_min']
#create targets dictionary
if model != 'observed':
tar_dict[model] = model_preds
tar_dict['observed'] = targets_site



#create dictionary to store calculations in
max_it = {}
    #create a nested dictionary for each DO variable to store the it calcs
    #TE0 = transfer entropy at a time lag of 0, MI = mutual information,
    #TEmax is the maximum TE, TEmaxt is the time lag of the maximum TE,
    #TEmaxcrit is True/False indicating whether TEmax is significant;
    #the same convention applies for MI
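    #illustrative example (hypothetical values, not from this dataset):
    #TEmax = 0.20 at TEmaxt = 2 with TEmaxcrit = True would mean information
    #transfer from source to sink peaks at a lag of 2 time steps and exceeds
    #the significance threshold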


#join input and target to make sure they are aligned
site_inptar = inputs_site.join(tar_dict[model], rsuffix = '_pred')

#assign x = source, y = sink
x = site_inptar[source]
y = site_inptar[sink]

#for calculating rmse
obs_pred = tar_dict['observed'][[sink]].join(tar_dict[model][[sink]], rsuffix = '_pred')


    #load the preprocessing functions from it_functions.py
ppf = it_functions.pre_proc_func()

if log_transform:
xl10 = ppf.log10(x)
else:
xl10 = x.copy()
x_rss = ppf.remove_seasonal_signal(xl10)
x_ss = ppf.standardize(x_rss)

y_rss = ppf.remove_seasonal_signal(y)
y_prepped = ppf.standardize(y_rss)


print('Calculating it metrics '+model+' '+site)
n_lags = 9
nbins = 11

#create an array of the prepped x and y variables
M = np.stack((x_ss,y_prepped), axis = 1)
    #Mswap is for calculating the TE from Y -> X; we don't really need that
    #because DO doesn't affect solar radiation, but the function requires it
Mswap = np.stack((y_prepped, x_ss), axis = 1)
#x_bounds and y_bounds are for removing outliers
x_bounds = it_functions.find_bounds(M[:,0], 0.1, 99.9)
y_bounds = it_functions.find_bounds(M[:,1], 0.1, 99.9)
M_x_bound = np.delete(M, np.where((M[:,0] < x_bounds[0]*1.1) | (M[:,0] > x_bounds[1]*1.1)), axis = 0)
M_xy_bound = np.delete(M_x_bound, np.where((M_x_bound[:,1] < y_bounds[0]*1.1) | (M_x_bound[:,1] > y_bounds[1]*1.1)), axis = 0)

#calc it metrics and store in the dictionary it_dict
it_dict = it_functions.calc_it_metrics(M_xy_bound, Mswap, n_lags, nbins, calc_swap = False, alpha = 0.05, ncores = 7)


print('Storing it metrics '+model+' '+site)
#find the max TE and MI and the time lag at which the max occurs
#and store that in a dictionary as well

    TEmax = max(it_dict['TE'])
    TEmaxt = int(np.argmax(it_dict['TE']))

    #flag whether the max TE clears its significance threshold
    TEmaxcrit = bool(TEmax > it_dict['TEcrit'][TEmaxt])

    MImax = max(it_dict['MI'])
    MImaxt = int(np.argmax(it_dict['MI']))

    MImaxcrit = bool(MImax > it_dict['MIcrit'][MImaxt])
    max_it['model'] = model
    #root mean squared error between observed and predicted sink values
    mse = np.square(np.subtract(obs_pred[sink + '_pred'], obs_pred[sink])).mean()
    max_it['rmse'] = math.sqrt(mse)

max_it['TEmax'] = TEmax
max_it['TEmaxt'] = TEmaxt
max_it['TEmaxcrit'] = TEmaxcrit

max_it['MImax'] = MImax
    for variable in ['TE', 'TEcrit', 'MI', 'MIcrit', 'corr']:
        for i in range(n_lags):
            max_it[f'{variable}{i}'] = it_dict[variable][i]
max_it['MImaxt'] = MImaxt
max_it['MImaxcrit'] = MImaxcrit
max_it['replicate'] = replicate
max_it['sink'] = sink
max_it['source'] = source
max_it['site'] = site

if outfile:
df = pd.DataFrame(max_it, index=[0])
df.to_csv(outfile, index=False)

return max_it
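A minimal usage sketch for the new function (hypothetical paths, not part of the commit; the site id and source/sink values mirror those used in the Snakemake rules below):

from do_it_functions import calc_it_metrics_site

metrics = calc_it_metrics_site(
    inputs_zarr="out/well_obs_io.zarr",        # io zarr produced upstream
    predictions_file="rep_0/preds.feather",    # one replicate's predictions
    source="tmmx",
    sink="do_min",
    site="01480870",
    log_transform=False,                       # only log10-transform discharge
    model="0_baseline_LSTM",
    replicate=0,
    outfile="func_perf/01480870-tmmx-do_min-0_baseline_LSTM.csv")
print(metrics["TEmax"], metrics["TEmaxt"], metrics["TEmaxcrit"])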


File renamed without changes.
5 changes: 4 additions & 1 deletion 2a_model/src/models/0_baseline_LSTM/Snakefile
@@ -16,7 +16,10 @@ rule all:
epochs=config['epochs'],
rep=list(range(config['num_replicates'])),
site_id=['01480870'],
year=[2012])
year=[2012]),
f"{out_dir}/{config['exp_name']}_func_perf.csv",
f"{out_dir}/observed_func_perf.csv"



module base_workflow:
3 changes: 2 additions & 1 deletion 2a_model/src/models/2_multitask_dense/Snakefile
@@ -16,7 +16,8 @@ rule all:
epochs=config['epochs'],
rep=list(range(config['num_replicates'])),
site_id=['01480870'],
year=[2012])
year=[2012]),
f"{out_dir}/{config['exp_name']}_func_perf.csv"


module base_workflow:
62 changes: 57 additions & 5 deletions 2a_model/src/models/Snakefile_base.smk
@@ -1,11 +1,15 @@
import os
import xarray as xr
import tensorflow as tf
import numpy as np
import pandas as pd
import sys

code_dir = "../river-dl"
sys.path.append(code_dir)
river_dl_dir = "../river-dl"
sys.path.append(river_dl_dir)

src_dir = "../.."
sys.path.append(src_dir)

from river_dl.preproc_utils import asRunConfig
from river_dl.preproc_utils import prep_all_data
@@ -14,6 +18,7 @@ from river_dl.postproc_utils import plot_obs, plot_ts, prepped_array_to_df
from river_dl.predict import predict_from_arbitrary_data
from river_dl.train import train_model
from river_dl import loss_functions as lf
from do_it_functions import calc_it_metrics_site

out_dir = os.path.join(config['out_dir'], config['exp_name'])
loss_function = lf.multitask_rmse(config['lambdas'])
@@ -29,7 +34,7 @@ rule as_run_config:

rule prep_io_data:
input:
f"../../../out/well_obs_io.zarr",
"../../../out/well_obs_io.zarr",
output:
"{outdir}/prepped.npz"
run:
@@ -100,7 +105,7 @@ rule make_predictions:
input:
"{outdir}/prepped.npz",
"{outdir}/nstates_{nstates}/nep_{epochs}/rep_{rep}/train_weights/",
f"../../../out/well_obs_io.zarr",
"../../../out/well_obs_io.zarr",
output:
"{outdir}/nstates_{nstates}/nep_{epochs}/rep_{rep}/preds.feather",
run:
@@ -174,7 +179,7 @@ def get_grp_arg(wildcards):

rule combine_metrics:
input:
f"../../../out/well_obs_io.zarr",
"../../../out/well_obs_io.zarr",
"{outdir}/nstates_{nstates}/nep_{epochs}/rep_{rep}/trn_preds.feather",
"{outdir}/nstates_{nstates}/nep_{epochs}/rep_{rep}/val_preds.feather",
"{outdir}/nstates_{nstates}/nep_{epochs}/rep_{rep}/val_times_preds.feather"
@@ -226,3 +231,50 @@ rule plot_prepped_data:
partition=wildcards.partition)


rule calc_functional_performance_one:
input:
"../../../out/well_obs_io.zarr",
"{outdir}/nstates_{nstates}/nep_{epochs}/rep_{rep}/preds.feather"
output:
"{outdir}/nstates_{nstates}/nep_{epochs}/rep_{rep}/func_perf/{site}-{src}-{snk}-{model}.csv"
run:
calc_it_metrics_site(input[0],
input[1],
wildcards.src,
wildcards.snk,
wildcards.site,
log_transform=False,
model=wildcards.model,
replicate=wildcards.rep,
outfile=output[0])


def get_func_perf_sites():
input_file = "../../../out/well_obs_io.zarr"
inputs = xr.open_zarr(input_file, consolidated=False)
inputs_df = inputs.to_dataframe()

sites = inputs_df.index.unique('site_id')
sites = sites.drop(['014721254', '014721259'])
return sites


rule gather_func_performances:
input:
expand("{outdir}/nstates_{nstates}/nep_{epochs}/rep_{rep}/func_perf/{site}-{src}-{snk}-{{model}}.csv",
outdir=out_dir,
nstates=config['hidden_size'],
epochs=config['epochs'],
rep=list(range(config['num_replicates'])),
site=get_func_perf_sites(),
src=['tmmx'],
snk=['do_min', 'do_mean', 'do_max'])
output:
"{outdir}/{model}_func_perf.csv"
run:
df_list = []
for in_file in input:
df = pd.read_csv(in_file, dtype={"site": str})
df_list.append(df)
df_comb = pd.concat(df_list)
df_comb.to_csv(output[0], index=False)
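As a sketch of how the expand() in gather_func_performances fans out (hypothetical config values; the real sites come from get_func_perf_sites() and the replicate count from the config):

from itertools import product

out_dir = "out/exp"                    # hypothetical
sites = ["01480870", "01481000"]       # from get_func_perf_sites() in practice
srcs = ["tmmx"]
snks = ["do_min", "do_mean", "do_max"]

inputs = [
    f"{out_dir}/nstates_10/nep_100/rep_{rep}/func_perf/{site}-{src}-{snk}-{{model}}.csv"
    for rep, site, src, snk in product(range(2), sites, srcs, snks)]
# {model} stays a wildcard (hence the {{model}} escape in the Snakefile) and is
# filled in from the rule's output pattern "{outdir}/{model}_func_perf.csv".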

This file was deleted.
