Fixes for geopandas-1.0.0 (#98)
mwtoews authored Jun 26, 2024
1 parent eff9f20 commit 2cb24c6
Showing 7 changed files with 125 additions and 23 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
@@ -48,6 +48,7 @@ jobs:
pytest -v -n2 --cov --cov-append
- name: Run doctest
if: matrix.python-version == '3.12'
run: |
pytest -v swn --doctest-modules
28 changes: 28 additions & 0 deletions swn/compat.py
@@ -3,6 +3,7 @@
import contextlib
import warnings

import geopandas
import numpy as np
import shapely
from packaging.version import Version
@@ -60,3 +61,30 @@ def ignore_shapely_warnings_for_object_array():
@contextlib.contextmanager
def ignore_shapely_warnings_for_object_array():
yield


GEOPANDAS_GE_100 = Version(geopandas.__version__) >= Version("1.0.0")


def sjoin_idx_names(left_df, right_df):
"""Returns left and right index names from geopandas.sjoin methods.
Handles breaking change from geopandas 1.0.0.
"""
left_idx_name = left_df.index.name or "index"
if GEOPANDAS_GE_100:
right_idx_name = right_df.index.name or "index"
# add _left/_right if needed
if left_df.index.name and (
left_idx_name == right_idx_name or left_idx_name in right_df.columns
):
left_idx_name += "_left"
if (
right_df.index.name is None
or right_idx_name in left_df.columns
or right_idx_name == left_df.index.name
):
right_idx_name += "_right"
else:
right_idx_name = "index_right"
return left_idx_name, right_idx_name
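
For reference, a minimal usage sketch (not part of this commit) of how the helper pairs with geopandas.sjoin; the frames and renamed columns are hypothetical:

    import geopandas
    from shapely.geometry import Point

    from swn.compat import sjoin_idx_names

    pts = geopandas.GeoDataFrame(geometry=[Point(0, 0), Point(2, 2)])
    polys = geopandas.GeoDataFrame(geometry=[Point(0, 0).buffer(1.0)])
    left_idx, right_idx = sjoin_idx_names(pts, polys)
    joined = geopandas.sjoin(pts, polys, "inner").reset_index()
    # with unnamed indexes, left_idx is "index" and right_idx is "index_right"
    # on both geopandas <1.0 and >=1.0; named indexes diverge between versions
    joined = joined.rename(columns={left_idx: "pt_idx", right_idx: "poly_idx"})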
35 changes: 24 additions & 11 deletions swn/core.py
@@ -14,7 +14,11 @@
import shapely
from shapely.geometry import LineString, Point

from .compat import SHAPELY_GE_20, ignore_shapely_warnings_for_object_array
from .compat import (
SHAPELY_GE_20,
ignore_shapely_warnings_for_object_array,
sjoin_idx_names,
)
from .spatial import bias_substring
from .util import abbr_str

@@ -234,13 +238,13 @@ def from_lines(cls, lines, polygons=None):
end_pts = obj.segments.interpolate(1.0, normalized=True)
start_df = start_pts.to_frame("start").set_geometry("start")
end_df = end_pts.to_frame("end").set_geometry("end")
segidxname = obj.segments.index.name or "index"
# This is the main component of the algorithm
end_idx_name, start_idx_name = sjoin_idx_names(end_df, start_df)
jxn = pd.DataFrame(
geopandas.sjoin(end_df, start_df, "inner", "intersects")
.drop(columns="end")
.reset_index()
.rename(columns={segidxname: "end", "index_right": "start"})
.rename(columns={end_idx_name: "end", start_idx_name: "start"})
)
# Group end points to start points, list should only have 1 item
to_segnum_l = jxn.groupby("end")["start"].agg(list)
@@ -281,10 +285,11 @@ def from_lines(cls, lines, polygons=None):
# Find outlets that join to a single coordinate
multi_outlets = set()
out_pts = end_pts.loc[outlets].to_frame("out").set_geometry("out")
left_idx_name, right_idx_name = sjoin_idx_names(out_pts, out_pts)
jout = pd.DataFrame(
geopandas.sjoin(out_pts, out_pts, "inner")
.reset_index()
.rename(columns={segidxname: "out1", "index_right": "out2"})
.rename(columns={left_idx_name: "out1", right_idx_name: "out2"})
).query("out1 != out2")
if jout.size > 0:
# Just evaluate 2D tuple to find segnums with same location
@@ -306,11 +311,12 @@ def from_lines(cls, lines, polygons=None):
obj.warnings.append(m[0] % m[1:])
multi_outlets |= v
# Find outlets that join to middle of other segments
left_idx_name, right_idx_name = sjoin_idx_names(out_pts, obj.segments)
joutseg = pd.DataFrame(
geopandas.sjoin(out_pts, obj.segments[["geometry"]], "inner")
.drop(columns="out")
.reset_index()
.rename(columns={segidxname: "out", "index_right": "segnum"})
.rename(columns={left_idx_name: "out", right_idx_name: "segnum"})
)
for r in joutseg.query("out != segnum").itertuples():
if r.out in multi_outlets:
@@ -320,10 +326,11 @@ def from_lines(cls, lines, polygons=None):
obj.errors.append(m[0] % m[1:])
# Find headwaters that join to a single coordinate
hw_pts = start_pts.loc[headwater].to_frame("hw").set_geometry("hw")
hw_idx_name, start_idx_name = sjoin_idx_names(hw_pts, start_df)
jhw = pd.DataFrame(
geopandas.sjoin(hw_pts, start_df, "inner")
.reset_index()
.rename(columns={segidxname: "hw1", "index_right": "start"})
.rename(columns={hw_idx_name: "hw1", start_idx_name: "start"})
).query("hw1 != start")
obj.jhw = jhw
if jhw.size > 0:
@@ -1112,9 +1119,8 @@ def locate_geoms(
catchments_df = self.catchments.to_frame("geometry")
if catchments_df.crs is None and self.segments.crs is not None:
catchments_df.crs = self.segments.crs
match_s = geopandas.sjoin(res[sel], catchments_df, "inner")[
"index_right"
]
_, ridxn = sjoin_idx_names(res, catchments_df)
match_s = geopandas.sjoin(res[sel], catchments_df, "inner")[ridxn]
match_s.name = "segnum"
match_s.index.name = "gidx"
match = match_s.reset_index()
@@ -1200,9 +1206,10 @@ def find_downstream_in_min_stream_order(segnum):
)
try:
# faster method, not widely available
_, right_idx_name = sjoin_idx_names(res, self.segments)
match_s = geopandas.sjoin_nearest(
res[sel], segments_gs.to_frame(), "inner"
)["index_right"]
)[right_idx_name]
has_sjoin_nearest = True
except (AttributeError, NotImplementedError):
has_sjoin_nearest = False
@@ -1305,7 +1312,13 @@ def find_downstream_in_min_stream_order(segnum):
linestring_empty = wkt.loads("LINESTRING EMPTY")
for idx in res[~sel].index:
res.at[idx, "link"] = linestring_empty
res.set_geometry("link", drop=True, inplace=True)
res = (
res.set_geometry("link")
.drop(columns="geometry")
.rename_geometry("geometry")
)
if geom_crs:
res.set_crs(geom_crs, inplace=True)
res["dist_to_seg"] = res[sel].length
return res

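A side note on the last hunk above (my reading, not stated in the commit): the chained set_geometry/drop/rename_geometry form presumably avoids the drop keyword of set_geometry, which geopandas 1.0 deprecates. The same pattern as a hypothetical standalone helper:

    def set_active_geometry(gdf, column):
        # assumes the previous active geometry column is literally "geometry"
        return (
            gdf.set_geometry(column)      # switch the active geometry
            .drop(columns="geometry")     # drop the old geometry column
            .rename_geometry("geometry")  # restore the conventional name
        )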
6 changes: 2 additions & 4 deletions swn/modflow/_base.py
@@ -930,10 +930,8 @@ def do_linemerge(ij, df, drop_reach_ids):
if is_spatial:
try:
match_s = geopandas.sjoin_nearest(
diversions_in_model, obj.grid_cells, "inner"
)[["index_right0", "index_right1"]].rename(
columns={"index_right0": "i", "index_right1": "j"}
)
diversions_in_model, obj.grid_cells.reset_index(), "inner"
)[["i", "j"]]
match_s.index.name = "divid"
match = match_s.reset_index()
has_sjoin_nearest = True
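An assumption on why this change works (not spelled out in the diff): geopandas 1.0 no longer emits the index_right0/index_right1 columns that the old code renamed, so resetting the right frame's index turns "i" and "j" into ordinary columns that pass through the join on any supported version. A rough self-contained sketch with made-up frames:

    import geopandas
    import pandas as pd
    from shapely.geometry import Point, box

    cells = geopandas.GeoDataFrame(
        geometry=[box(0, 0, 1, 1), box(1, 0, 2, 1)],
        index=pd.MultiIndex.from_tuples([(0, 0), (0, 1)], names=["i", "j"]),
    )
    pts = geopandas.GeoDataFrame(geometry=[Point(0.5, 0.5), Point(1.9, 0.2)])
    # "i" and "j" become plain columns before the join, so no rename is needed
    match = geopandas.sjoin_nearest(pts, cells.reset_index(), "inner")[["i", "j"]]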
8 changes: 4 additions & 4 deletions swn/spatial.py
@@ -511,10 +511,10 @@ def location_pair_geoms(pairs, loc_df, n):
>>> pair_gdf.sort_values("length", ascending=False, inplace=True)
>>> pair_gdf
geometry length
11 14 LINESTRING (378.491 404.717, 420.000 330.000, ... 282.359779
12 13 LINESTRING (728.462 227.692, 710.000 160.000, ... 184.465938
15 13 LINESTRING (692.027 172.838, 710.000 160.000, ... 136.388347
14 15 LINESTRING (595.730 241.622, 692.027 172.838) 118.340096
11 14 LINESTRING (378.491 404.717, 420 330, 584 250,... 282.359779
12 13 LINESTRING (728.462 227.692, 710 160, 770 100,... 184.465938
15 13 LINESTRING (692.027 172.838, 710 160, 770 100,... 136.388347
14 15 LINESTRING (595.73 241.622, 692.027 172.838) 118.340096
"""
from shapely.ops import linemerge, substring, unary_union
20 changes: 16 additions & 4 deletions tests/test_basic.py
@@ -9,7 +9,7 @@
from shapely.geometry import LineString, Point

import swn
from swn.compat import ignore_shapely_warnings_for_object_array
from swn.compat import GEOPANDAS_GE_100, ignore_shapely_warnings_for_object_array
from swn.spatial import force_2d, round_coords

from .conftest import matplotlib, plt
@@ -1237,7 +1237,11 @@ def test_locate_geoms_in_basic_swn(caplog):
a = r2.geometry.apply(lambda x: Point(*x.coords[0]))
assert (a.distance(gs.drop(index=16)) == 0.0).all()
b = r2.geometry.apply(lambda x: Point(*x.coords[-1]))
seg_mls = n.segments.geometry[r2.segnum].unary_union
seg_geoms = n.segments.geometry[r2.segnum]
if GEOPANDAS_GE_100:
seg_mls = seg_geoms.union_all()
else:
seg_mls = seg_geoms.unary_union
assert (b.distance(seg_mls) < 1e-10).all()
# now check the empty geometry
for k in e.keys():
@@ -1342,7 +1346,11 @@ def test_locate_geoms_only_lines(coastal_geom, coastal_swn):
assert (r.geometry.apply(lambda g: len(g.coords)) == 2).all()
a = r.geometry.interpolate(0.0)
b = r.geometry.interpolate(1.0, normalized=True)
seg_mls = coastal_swn.segments.geometry[r.segnum].unary_union
seg_geoms = coastal_swn.segments.geometry[r.segnum]
if GEOPANDAS_GE_100:
seg_mls = seg_geoms.union_all()
else:
seg_mls = seg_geoms.unary_union
assert (a.distance(coastal_geom) < 1e-10).all()
assert (a.distance(seg_mls) > 0.0).all()
assert (b.distance(coastal_geom) > 0.0).all()
@@ -1407,7 +1415,11 @@ def test_locate_geoms_with_catchments(coastal_geom, coastal_swn_w_poly):
assert (r.geometry.apply(lambda g: len(g.coords)) == 2).all()
a = r.geometry.interpolate(0.0)
b = r.geometry.interpolate(1.0, normalized=True)
seg_mls = coastal_swn_w_poly.segments.geometry[r.segnum].unary_union
seg_geoms = coastal_swn_w_poly.segments.geometry[r.segnum]
if GEOPANDAS_GE_100:
seg_mls = seg_geoms.union_all()
else:
seg_mls = seg_geoms.unary_union
assert (a.distance(coastal_geom) < 1e-10).all()
assert (a.distance(seg_mls) > 0.0).all()
assert (b.distance(coastal_geom) > 0.0).all()
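The union_all/unary_union branch above is repeated in three tests; a hypothetical helper (not part of this commit) could centralise it in swn.compat:

    def union_all(geoseries):
        # geopandas 1.0 renames the aggregate to GeoSeries.union_all() and
        # deprecates the unary_union property
        if GEOPANDAS_GE_100:
            return geoseries.union_all()
        return geoseries.unary_union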
50 changes: 50 additions & 0 deletions tests/test_compat.py
@@ -0,0 +1,50 @@
"""Test compat module."""

import geopandas
from shapely.geometry import Point

from swn import compat


def test_sjoin_idx_names():
noname_idx_df = geopandas.GeoDataFrame(geometry=[Point(0, 1)])
left_idx_name, right_idx_name = compat.sjoin_idx_names(noname_idx_df, noname_idx_df)
assert left_idx_name == "index"
assert right_idx_name == "index_right"
sj = geopandas.sjoin(noname_idx_df, noname_idx_df)
assert sj.index.name is None
assert left_idx_name not in sj.columns
assert right_idx_name in sj.columns

name_idx_df = noname_idx_df.copy()
name_idx_df.index.name = "idx"
left_idx_name, right_idx_name = compat.sjoin_idx_names(name_idx_df, name_idx_df)
if compat.GEOPANDAS_GE_100:
assert left_idx_name == "idx_left"
assert right_idx_name == "idx_right"
else:
assert left_idx_name == "idx"
assert right_idx_name == "index_right"
sj = geopandas.sjoin(name_idx_df, name_idx_df)
assert sj.index.name == left_idx_name
assert left_idx_name not in sj.columns
assert right_idx_name in sj.columns

left_idx_name, right_idx_name = compat.sjoin_idx_names(noname_idx_df, name_idx_df)
assert left_idx_name == "index"
if compat.GEOPANDAS_GE_100:
assert right_idx_name == "idx"
else:
assert right_idx_name == "index_right"
sj = geopandas.sjoin(noname_idx_df, name_idx_df)
assert sj.index.name is None
assert left_idx_name not in sj.columns
assert right_idx_name in sj.columns

left_idx_name, right_idx_name = compat.sjoin_idx_names(name_idx_df, noname_idx_df)
assert left_idx_name == "idx"
assert right_idx_name == "index_right"
sj = geopandas.sjoin(name_idx_df, noname_idx_df)
assert sj.index.name == left_idx_name
assert left_idx_name not in sj.columns
assert right_idx_name in sj.columns
