This repository was archived by the owner on Jun 2, 2023. It is now read-only.

Merge pull request #208 from janetrbarclay/drop_exclude
Drop exclude_file parameters and functions
janetrbarclay authored Nov 10, 2022
2 parents 23cf6a3 + 825c3e0 commit c184cbd
Showing 7 changed files with 0 additions and 150 deletions.
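The user-visible change is the call signature: exclude_file is gone from prep_y_data and prep_all_data, and prep_y_data no longer returns a "y_obs_wgts" entry. A hedged sketch of the migration (paths hypothetical):

# before this commit:
#     prep_all_data(..., exclude_file="exclude_segs.yml")
# after it, the keyword no longer exists, so even the no-op form fails:
#     prep_all_data(..., exclude_file=None)
#     TypeError: prep_all_data() got an unexpected keyword argument 'exclude_file'
# callers (including the six workflow examples below) simply drop the argument.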
144 changes: 0 additions & 144 deletions river_dl/preproc_utils.py
@@ -276,102 +276,6 @@ def reshape_for_training(data):
return np.reshape(data, [n_batch * n_seg, seq_len, n_feat])


def get_exclude_start_end(exclude_grp):
"""
get the start and end dates for the exclude group
:param exclude_grp: [dict] dictionary representing the exclude group from
the exclude yml file
:return: [tuple of datetime objects] start date, end date
"""
start = exclude_grp.get("start_date")
if start:
start = datetime.datetime.strptime(start, "%Y-%m-%d")

end = exclude_grp.get("end_date")
if end:
end = datetime.datetime.strptime(end, "%Y-%m-%d")
return start, end
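In practice this removed helper parsed the optional date bounds as follows; a minimal sketch (the dict literal is hypothetical; its values echo the example exclude file further down):

import datetime

grp = {"start_date": "2017-10-01", "end_date": "2018-10-01"}
start, end = get_exclude_start_end(grp)
# start == datetime.datetime(2017, 10, 1, 0, 0)
# end   == datetime.datetime(2018, 10, 1, 0, 0)
assert get_exclude_start_end({}) == (None, None)  # absent keys pass through as None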


def get_exclude_vars(exclude_grp):
"""
get the variables to exclude for the exclude group
:param exclude_grp: [dict] dictionary representing the exclude group from
the exclude yml file
:return: [list] variables to exclude
"""
variable = exclude_grp.get("variable")
if not variable or variable == "both":
return ["seg_tave_water", "seg_outflow"]
elif variable == "temp":
return ["seg_tave_water"]
elif variable == "flow":
return ["seg_outflow"]
else:
raise ValueError("exclude variable must be flow, temp, or both")
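Its companion mapped the optional "variable" key onto the y variables whose weights get zeroed; all the branches in one sketch:

get_exclude_vars({"variable": "temp"})   # -> ["seg_tave_water"]
get_exclude_vars({"variable": "flow"})   # -> ["seg_outflow"]
get_exclude_vars({})                     # -> ["seg_tave_water", "seg_outflow"] ("both" is the default)
get_exclude_vars({"variable": "stage"})  # -> raises ValueError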


def get_exclude_seg_ids(exclude_grp, all_segs):
"""
get the segments to exclude
:param exclude_grp: [dict] dictionary representing the exclude group from
the exclude yml file
:param all_segs: [array] all of the segments. this is needed if we are doing
a reverse exclusion
:return: [list like] the segments to exclude
"""
# ex_segs are the sites to exclude
if "seg_id_nats_ex" in exclude_grp.keys():
ex_segs = exclude_grp["seg_id_nats_ex"]
# exclude all *but* the "seg_id_nats_in"
elif "seg_id_nats_in" in exclude_grp.keys():
ex_mask = ~all_segs.isin(exclude_grp["seg_id_nats_in"])
ex_segs = all_segs[ex_mask]
else:
ex_segs = all_segs
return ex_segs
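The direct/reverse duality is easiest to see on a toy segment list; a sketch using a pandas Series as a stand-in for the xarray coordinate the real caller passes:

import pandas as pd

all_segs = pd.Series([1556, 1569, 1653, 1806])
get_exclude_seg_ids({"seg_id_nats_ex": [1556]}, all_segs)  # drop exactly [1556]
get_exclude_seg_ids({"seg_id_nats_in": [1806]}, all_segs)  # reverse: drops 1556, 1569, 1653
get_exclude_seg_ids({}, all_segs)                          # no key at all: drops all four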


def exclude_segments(y_data, exclude_segs):
"""
exclude segments from being trained on by setting their weights as zero
:param y_data: [xr dataset] y data. this is used to get the dimensions
:param exclude_segs: [list] list of segments to exclude in the loss
calculation
:return: [xr dataset] weights with the excluded segments set to zero
"""
weights = initialize_weights(y_data, 1)
for seg_grp in exclude_segs:
# get the start and end dates if present
start, end = get_exclude_start_end(seg_grp)
exclude_vars = get_exclude_vars(seg_grp)
segs_to_exclude = get_exclude_seg_ids(seg_grp, weights.seg_id_nat)

# loop through the data_vars
for v in exclude_vars:
# set those weights to zero
weights[v].load()
weights[v].loc[
dict(date=slice(start, end), seg_id_nat=segs_to_exclude)
] = 0
return weights


def initialize_weights(y_data, initial_val=1):
"""
initialize all weights with a value.
:param y_data: [xr dataset] y data. this is used to get the dimensions
:param initial_val: [num] a number to initialize the weights with. should
be between 0 and 1 (inclusive)
:return: [xr dataset] dataset weights initialized with a uniform value
"""
weights = y_data.copy(deep=True)
for v in y_data.data_vars:
weights[v].load()
weights[v].loc[:, :] = initial_val
return weights
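Together these two helpers formed the removed masking pipeline; a minimal pre-commit sketch on a toy xarray dataset (all names and values hypothetical):

import numpy as np
import pandas as pd
import xarray as xr

dates = pd.date_range("2017-09-29", periods=4)
y = xr.Dataset(
    {"seg_tave_water": (("date", "seg_id_nat"), np.random.rand(4, 2))},
    coords={"date": dates, "seg_id_nat": [1556, 1653]},
)
weights = initialize_weights(y)  # every weight starts at 1
# zero the temperature weights for segment 1556 from 2017-10-01 onward
grps = [{"variable": "temp", "start_date": "2017-10-01", "seg_id_nats_ex": [1556]}]
weights = exclude_segments(y, grps)
# weights["seg_tave_water"].sel(seg_id_nat=1556) is 0 on and after 2017-10-01
# and still 1 before it; segment 1653 keeps weight 1 throughout.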


def reduce_training_data_random(
data_file,
train_start_date="1980-10-01",
@@ -600,7 +504,6 @@ def prep_y_data(
time_idx_name="date",
seq_len=365,
log_vars=None,
exclude_file=None,
normalize_y=True,
y_type="obs",
y_std=None,
@@ -637,7 +540,6 @@ def prep_y_data(
sites will be withheld from training and validation
:param seq_len: [int] length of sequences (e.g., 365)
:param log_vars: [list-like] which variables (if any) to take the log of
:param exclude_file: [str] path to exclude file
:param normalize_y: [bool] whether or not to normalize the y_dataset values
:param y_type: [str] "obs" if observations or "pre" if pretraining
:param y_std: [array-like] standard deviations of the y_dataset variables
@@ -683,12 +585,6 @@ def prep_y_data(
if log_vars:
y_trn = log_variables(y_trn, log_vars)

# filter pretrain/finetune y_dataset
if exclude_file:
exclude_segs = read_exclude_segs_file(exclude_file)
y_wgts = exclude_segments(y_trn, exclude_segs=exclude_segs)
else:
y_wgts = initialize_weights(y_trn)
# scale y_dataset training data and get the mean and std
# scale the validation partition to benchmark epoch performance
if normalize_y:
@@ -713,9 +609,6 @@
"y_obs_trn": convert_batch_reshape(
y_trn, spatial_idx_name, time_idx_name, offset=trn_offset, seq_len=seq_len
),
"y_obs_wgts": convert_batch_reshape(
y_wgts, spatial_idx_name, time_idx_name, offset=trn_offset, seq_len=seq_len
),
"y_obs_val": convert_batch_reshape(
y_val, spatial_idx_name, time_idx_name, offset=tst_val_offset, seq_len=seq_len
),
@@ -768,7 +661,6 @@ def prep_all_data(
dist_type="updown",
catch_prop_file=None,
catch_prop_vars=None,
exclude_file=None,
log_y_vars=False,
out_file=None,
segs=None,
@@ -823,7 +715,6 @@ def prep_all_data(
left unfilled, the catchment properties will not be included as predictors
:param catch_prop_vars: [list of str] list of catchment properties to use. If
left unfilled and a catchment property file is supplied all variables will be used.
:param exclude_file: [str] path to exclude file
:param log_y_vars: [bool] whether or not to take the log of discharge in
training
:param segs: [list-like] which segments to prepare the data for
@@ -1005,7 +896,6 @@ def prep_all_data(
time_idx_name=time_idx_name,
seq_len=seq_len,
log_vars=log_y_vars,
exclude_file=exclude_file,
normalize_y=normalize_y,
y_type="obs",
trn_offset = trn_offset,
@@ -1028,7 +918,6 @@
time_idx_name=time_idx_name,
seq_len=seq_len,
log_vars=log_y_vars,
exclude_file=exclude_file,
normalize_y=normalize_y,
y_type="pre",
y_std=y_obs_data["y_std"],
@@ -1053,7 +942,6 @@
time_idx_name=time_idx_name,
seq_len=seq_len,
log_vars=log_y_vars,
exclude_file=exclude_file,
normalize_y=normalize_y,
y_type="pre",
trn_offset = trn_offset,
@@ -1118,35 +1006,3 @@ def prep_adj_matrix(infile, dist_type, dist_idx_name, segs=None, out_file=None):
np.savez_compressed(out_file, dist_matrix=A_hat)
return A_hat


def read_exclude_segs_file(exclude_file):
"""
read the exclude segs file. should be a yml file of exclude groups, each with
a list of segments to exclude and, optionally, date bounds and a variable
--
example exclude file:
group_after_2017:
start_date: "2017-10-01"
variable: "temp"
seg_id_nats_ex:
- 1556
- 1569
group_2018_water_year:
start_date: "2017-10-01"
end_date: "2018-10-01"
seg_id_nats_ex:
- 1653
group_all_time:
seg_id_nats_in:
- 1806
- 2030
--
:param exclude_file: [str] exclude segs file
:return: [list] list of dictionaries of segments to exclude. dict keys must
include 'seg_id_nats_ex' or 'seg_id_nats_in' and may also include 'start_date' and 'end_date'
"""
with open(exclude_file, "r") as s:
d = yaml.safe_load(s)
return [val for key, val in d.items()]
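For the example file shown in the docstring, the loader returned the group dicts in file order and discarded the group names; a sketch of the result:

read_exclude_segs_file("exclude_segs.yml")  # hypothetical path
# -> [{"start_date": "2017-10-01", "variable": "temp", "seg_id_nats_ex": [1556, 1569]},
#     {"start_date": "2017-10-01", "end_date": "2018-10-01", "seg_id_nats_ex": [1653]},
#     {"seg_id_nats_in": [1806, 2030]}]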
1 change: 0 additions & 1 deletion workflow_examples/Snakefile_basic.smk
@@ -49,7 +49,6 @@ rule prep_io_data:
spatial_idx_name='segs_test',
time_idx_name='times_test',
catch_prop_file=None,
exclude_file=None,
train_start_date=config['train_start_date'],
train_end_date=config['train_end_date'],
val_start_date=config['val_start_date'],
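The same one-line deletion repeats across all six workflow examples. Before rebasing a fork onto this commit, a quick scan for stale call sites helps; a minimal sketch:

import pathlib

# look for leftover exclude_file arguments in package code and workflows
for pattern in ("*.py", "*.smk"):
    for path in pathlib.Path(".").rglob(pattern):
        for lineno, line in enumerate(path.read_text().splitlines(), start=1):
            if "exclude_file" in line:
                print(f"{path}:{lineno}: {line.strip()}")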
1 change: 0 additions & 1 deletion workflow_examples/Snakefile_gwn.smk
@@ -61,7 +61,6 @@ rule prep_io_data:
y_vars_pretrain=config['y_vars_pretrain'],
y_vars_finetune=config['y_vars_finetune'],
catch_prop_file=None,
exclude_file=None,
train_start_date=config['train_start_date'],
train_end_date=config['train_end_date'],
val_start_date=config['val_start_date'],
1 change: 0 additions & 1 deletion workflow_examples/Snakefile_pretrain_LSTM.smk
@@ -54,7 +54,6 @@ rule prep_io_data:
spatial_idx_name='segs_test',
time_idx_name='times_test',
catch_prop_file=None,
exclude_file=None,
train_start_date=config['train_start_date'],
train_end_date=config['train_end_date'],
val_start_date=config['val_start_date'],
1 change: 0 additions & 1 deletion workflow_examples/Snakefile_rgcn.smk
@@ -63,7 +63,6 @@ rule prep_io_data:
y_vars_pretrain=config['y_vars_pretrain'],
y_vars_finetune=config['y_vars_finetune'],
catch_prop_file=None,
exclude_file=None,
train_start_date=config['train_start_date'],
train_end_date=config['train_end_date'],
val_start_date=config['val_start_date'],
1 change: 0 additions & 1 deletion workflow_examples/Snakefile_rgcn_hypertune.smk
@@ -67,7 +67,6 @@ rule prep_io_data:
y_vars_pretrain=config['y_vars_pretrain'],
y_vars_finetune=config['y_vars_finetune'],
catch_prop_file=None,
exclude_file=None,
train_start_date=config['train_start_date'],
train_end_date=config['train_end_date'],
val_start_date=config['val_start_date'],
1 change: 0 additions & 1 deletion workflow_examples/Snakefile_rgcn_pytorch.smk
@@ -61,7 +61,6 @@ rule prep_io_data:
y_vars_pretrain=config['y_vars_pretrain'],
y_vars_finetune=config['y_vars_finetune'],
catch_prop_file=None,
exclude_file=None,
train_start_date=config['train_start_date'],
train_end_date=config['train_end_date'],
val_start_date=config['val_start_date'],
