Skip to content

Commit

Permalink
Accommodate different reach index names for future flopy versions (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
mwtoews authored Sep 30, 2023
1 parent ec9d851 commit b774d8a
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 189 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ jobs:
run: |
pytest -v swn --doctest-modules
- name: Run tests with older packages
- name: Run tests with develop flopy and other older packages
run: |
pip install https://github.com/modflowpy/flopy/archive/refs/heads/develop.zip
pip install "shapely<2.0" "pandas<2.0"
pytest -v -n2 --cov --cov-append
Expand Down
67 changes: 35 additions & 32 deletions swn/modflow/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,34 @@ def __init__(self, logger=None):
logger : logging.Logger, optional
Logger to show messages.
"""
from importlib.util import find_spec

if not find_spec("flopy"):
raise ImportError(f"{self.__class__.__name__} requires flopy")
from ..logger import get_logger, logging

class_name = self.__class__.__name__
try:
import flopy
except ImportError:
raise ImportError(f"{class_name} requires flopy")

if logger is None:
self.logger = get_logger(self.__class__.__name__)
self.logger = get_logger(class_name)
elif isinstance(logger, logging.Logger):
self.logger = logger
else:
raise ValueError(f"expected 'logger' to be Logger; found {type(logger)!r}")
self.logger.info("creating new %s object", self.__class__.__name__)
self.logger.info("creating new %s object", class_name)
if class_name == "SwnModflow":
self.domain_label = "ibound"
self.reach_index_name = "reachID"
elif class_name == "SwnMf6":
self.domain_label = "idomain"
for block in flopy.mf6.ModflowGwfsfr.dfn:
if block[0] == "block packagedata" and block[1] != "name packagedata":
self.reach_index_name = block[1][5:]
break
else:
self.logger.error("cannot determine reach index name for GWF-SFR")
else:
self.logger.error("unsupported subclass %r", class_name)

def __iter__(self):
"""Return object datasets with an iterator."""
Expand Down Expand Up @@ -213,7 +228,7 @@ def reaches(self):
Attributes
----------
reachID (SwnModflow) or rno (SwnMf6) : int, index
reachID (SwnModflow), rno or ifno (SwnMf6) : int, index
Reach index number, starting from 1.
geometry : geometry
LineString segments, one per model cell.
Expand Down Expand Up @@ -312,11 +327,9 @@ def model(self, model):
modelgrid = model.modelgrid
modeltime = model.modeltime
if this_class == "SwnModflow":
domain_label = "ibound"
domain = model.bas6.ibound[0].array.copy()
perlen = pd.Series(model.dis.perlen.array)
elif this_class == "SwnMf6":
domain_label = "idomain"
domain = dis.idomain.array[0].copy()
nper = sim.tdis.nper.data
perlen = pd.Series(sim.tdis.perioddata.array.perlen)
Expand Down Expand Up @@ -392,7 +405,7 @@ def model(self, model):
cols, rows = np.meshgrid(np.arange(ncol), np.arange(nrow))
grid_df = pd.DataFrame({"i": rows.flatten(), "j": cols.flatten()})
grid_df.set_index(["i", "j"], inplace=True)
grid_df[domain_label] = domain.flatten()
grid_df[self.domain_label] = domain.flatten()
# Note: modelgrid.get_cell_vertices(i, j) is slow!
xv = modelgrid.xvertices
yv = modelgrid.yvertices
Expand Down Expand Up @@ -439,13 +452,9 @@ def from_swn_flopy(
"""
this_class = cls.__name__
if this_class == "SwnModflow":
domain_label = "ibound"
reach_index_name = "reachID"
reach_length_name = "rchlen"
uses_segments = True
elif this_class == "SwnMf6":
domain_label = "idomain"
reach_index_name = "rno"
reach_length_name = "rlen"
uses_segments = False
else:
Expand All @@ -466,17 +475,15 @@ def from_swn_flopy(
dis = model.dis
grid_cells = obj.grid_cells.copy()
if domain_action == "freeze":
sel = grid_cells[domain_label] != 0
sel = grid_cells[obj.domain_label] != 0
if sel.any():
# Remove any inactive grid cells from analysis
grid_cells = grid_cells.loc[sel]
_ = grid_cells.sindex # create spatial index
num_domain_modified = 0
if this_class == "SwnModflow":
domain_label = "ibound"
domain = model.bas6.ibound[0].array.copy()
elif this_class == "SwnMf6":
domain_label = "idomain"
domain = dis.idomain.array[0].copy()
else:
raise TypeError(f"unsupported subclass {cls!r}")
Expand Down Expand Up @@ -814,27 +821,27 @@ def do_linemerge(ij, df, drop_reach_ids):
if domain_action == "modify" and domain[i, j] == 0:
num_domain_modified += 1
domain[i, j] = 1
obj.grid_cells[domain_label].at[i, j] = 1
obj.grid_cells[obj.domain_label].at[i, j] = 1

if domain_action == "modify":
if num_domain_modified:
obj.logger.debug(
"updating %d cells from %s array for top layer",
num_domain_modified,
domain_label.upper(),
obj.domain_label.upper(),
)
if domain_label == "ibound":
if obj.domain_label == "ibound":
obj.model.bas6.ibound[0] = domain
elif domain_label == "idomain":
elif obj.domain_label == "idomain":
obj.model.dis.idomain.set_data(domain, layer=0)
reaches = reaches.merge(
grid_cells[[domain_label]], left_on=["i", "j"], right_index=True
grid_cells[[obj.domain_label]], left_on=["i", "j"], right_index=True
)
reaches.rename(
columns={domain_label: f"prev_{domain_label}"}, inplace=True
columns={obj.domain_label: f"prev_{obj.domain_label}"}, inplace=True
)
else:
reaches[f"prev_{domain_label}"] = 1
reaches[f"prev_{obj.domain_label}"] = 1

# Mark segments that are not used
segments["in_model"] = True
Expand Down Expand Up @@ -1012,7 +1019,7 @@ def do_linemerge(ij, df, drop_reach_ids):
reaches.loc[sel, reach_length_name] = 1.0

reaches.reset_index(inplace=True, drop=True)
reaches.index.name = reach_index_name
reaches.index.name = obj.reach_index_name
reaches.index += 1 # flopy series starts at one

if not hasattr(reaches.geometry, "geom_type"):
Expand Down Expand Up @@ -1352,7 +1359,7 @@ def get_location_frame_reach_info(
Returns
-------
pandas.DataFrame
- reachID (SwnModflow) or rno (SwnMf6)
- reachID (SwnModflow), rno or ifno (SwnMf6)
- zero-based cell index: k, i, j,
- one-based reach index: iseg, ireach
- dist_to_reach
Expand Down Expand Up @@ -1434,7 +1441,7 @@ def get_location_frame_reach_info(
has_orig_geom = True

reach_index_name = self.reaches.index.name
if reach_index_name in loc_df.columns:
if self.reach_index_name in loc_df.columns:
self.logger.info("resetting %s from location frame", reach_index_name)
del loc_df[reach_index_name]
loc_df[reach_index_name] = -1
Expand Down Expand Up @@ -1526,11 +1533,7 @@ def plot(self, column=None, cmap="viridis_r", colorbar=False, ax=None):

grid_cells = getattr(self, "grid_cells", None)
if grid_cells is not None:
domain_label = {
"SwnModflow": "ibound",
"SwnMf6": "idomain",
}[self.__class__.__name__]
sel = self.grid_cells[domain_label] != 0
sel = self.grid_cells[self.domain_label] != 0
if sel.any():
self.grid_cells.loc[sel].plot(
ax=ax, color="whitesmoke", edgecolor="gainsboro"
Expand Down
Loading

0 comments on commit b774d8a

Please sign in to comment.