From 2a3074b59bba7df85f6b87946920d8dcdc53cc2e Mon Sep 17 00:00:00 2001 From: "Leaf, Andrew T" Date: Wed, 14 Feb 2024 14:05:01 -0600 Subject: [PATCH] fix(sourcedata.py::TransientArraySourceData): refactor TransientSourceDataMixin to have a general stress_period_mapping property to handle cases with and without a parent model (previously expected a parent model) --- mfsetup/sourcedata.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/mfsetup/sourcedata.py b/mfsetup/sourcedata.py index 76fd9e77..f30cd679 100644 --- a/mfsetup/sourcedata.py +++ b/mfsetup/sourcedata.py @@ -184,6 +184,7 @@ def __init__(self, period_stats, dest_model): self._period_stats = None # attributes + self.dest_model = dest_model self.perioddata = dest_model.perioddata.sort_values(by='per').reset_index(drop=True) @property @@ -192,6 +193,21 @@ def period_stats(self): self._period_stats = self.get_period_stats() return self._period_stats + @property + def stress_period_mapping(self): + # if there is a parent/source model, + # get the mapping between the parent model and + # inset/destination model stress periods {inset_kper: parent_kper} + if self.dest_model.parent is not None: + # for now, just assume one-to-one correspondance + # between source and dest model stress periods + return self.dest_model.parent_stress_periods + # otherwise, just return a dictionary of the same + # key, value pairs for consistency + # with logic of subclass get_data() methods + else: + return dict(zip(self.perioddata['per'], self.perioddata['per'])) + def get_period_stats(self): """Populate each stress period with period_stat information for temporal resampling (tdis.aggregate_dataframe_to_stress_period and @@ -576,7 +592,6 @@ def __init__(self, filenames, variable, period_stats=None, self.variable = variable self.resample_method = resample_method - self.dest_model = dest_model def get_data(self): @@ -599,11 +614,9 @@ def get_data(self): # would follow logic of netcdf files, but trickier because steady-state periods need to be handled #da = transient2d_to_xarray(data, time) - # for now, just assume one-to-one correspondance - # between source and dest model stress periods results = {} - for inset_kper, parent_kper in self.dest_model.parent_stress_periods.items(): - data = source_data[parent_kper].copy() + for dest_kper, source_kper in self.stress_period_mapping.items(): + data = source_data[source_kper].copy() if regrid: # sample the data onto the model grid resampled = self.regrid_from_source_model(data, method=self.resample_method) @@ -612,7 +625,7 @@ def get_data(self): # reshape results to model grid period_mean2d = resampled.reshape(self.dest_model.nrow, self.dest_model.ncol) - results[inset_kper] = period_mean2d + results[dest_kper] = period_mean2d self.data = results return results