Skip to content

Commit

Permalink
ENH: add preserve_index keyword to gdf_to_nx to allow preservation of…
Browse files Browse the repository at this point in the history
… index and order (#641)

* primal

* dual

* test

* expand docs
  • Loading branch information
martinfleis authored Jul 11, 2024
1 parent 92bd918 commit efb11dd
Show file tree
Hide file tree
Showing 3 changed files with 373 additions and 10 deletions.
286 changes: 281 additions & 5 deletions docs/user_guide/graph/convert.ipynb

Large diffs are not rendered by default.

26 changes: 26 additions & 0 deletions momepy/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import networkx
import numpy as np
import pytest
from geopandas.testing import assert_geodataframe_equal
from shapely.geometry import LineString, Point

import momepy as mm
Expand Down Expand Up @@ -198,6 +199,31 @@ def test_nx_to_gdf_osmnx(self):
assert len(pts) == 7
assert len(lines) == 16

@pytest.mark.parametrize("approach", ["primal", "dual"])
def test_nx_roundtrip(self, approach):
nx = mm.gdf_to_nx(self.df_streets, preserve_index=True, approach=approach)
gdf = mm.nx_to_gdf(nx, points=False)
assert_geodataframe_equal(gdf.drop(columns="mm_len"), self.df_streets)

@pytest.mark.parametrize("approach", ["primal", "dual"])
def test_nx_roundtrip_named(self, approach):
df = self.df_streets
df.index.name = "foo"
nx = mm.gdf_to_nx(df, preserve_index=True, approach=approach)
gdf = mm.nx_to_gdf(nx, points=False)
assert_geodataframe_equal(gdf.drop(columns="mm_len"), df)
assert gdf.index.name == "foo"

@pytest.mark.parametrize("approach", ["primal", "dual"])
def test_nx_roundtrip_custom(self, approach):
df = self.df_streets
df.index = (df.index * 10).astype(str)
df.index.name = "foo"
nx = mm.gdf_to_nx(df, preserve_index=True, approach=approach)
gdf = mm.nx_to_gdf(nx, points=False)
assert_geodataframe_equal(gdf.drop(columns="mm_len"), df)
assert gdf.index.name == "foo"

def test_limit_range(self):
assert list(mm.limit_range(np.arange(10), rng=(25, 75))) == [2, 3, 4, 5, 6, 7]
assert list(mm.limit_range(np.arange(10), rng=(10, 90))) == [
Expand Down
71 changes: 66 additions & 5 deletions momepy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import libpysal
import networkx as nx
import numpy as np
import pandas as pd
from numpy.lib import NumpyVersion
from shapely.geometry import Point

Expand Down Expand Up @@ -110,10 +111,15 @@ def _angle(a, b, c):
return abs((a2 - a1 + 180) % 360 - 180)


def _generate_primal(graph, gdf_network, fields, multigraph, oneway_column=None):
def _generate_primal(
graph, gdf_network, fields, multigraph, oneway_column=None, preserve_index=False
):
"""Generate a primal graph. Helper for ``gdf_to_nx``."""
graph.graph["approach"] = "primal"

if gdf_network.index.name is not None:
graph.graph["index_name"] = gdf_network.index.name

msg = (
" This can lead to unexpected behaviour. "
"The intended usage of the conversion function "
Expand All @@ -133,13 +139,18 @@ def _generate_primal(graph, gdf_network, fields, multigraph, oneway_column=None)
category=RuntimeWarning,
stacklevel=3,
)
custom_index = not gdf_network.index.equals(pd.RangeIndex(len(gdf_network)))

for row in gdf_network.itertuples():
for i, row in enumerate(gdf_network.itertuples()):
first = row.geometry.coords[0]
last = row.geometry.coords[-1]

data = list(row)[1:]
attributes = dict(zip(fields, data, strict=True))
if preserve_index:
attributes["index_position"] = i
if custom_index:
attributes["index"] = row.Index
if multigraph:
graph.add_edge(first, last, **attributes)

Expand All @@ -154,9 +165,17 @@ def _generate_primal(graph, gdf_network, fields, multigraph, oneway_column=None)
nx.set_node_attributes(graph, node_attrs)


def _generate_dual(graph, gdf_network, fields, angles, multigraph, angle):
def _generate_dual(
graph, gdf_network, fields, angles, multigraph, angle, preserve_index
):
"""Generate a dual graph. Helper for ``gdf_to_nx``."""
graph.graph["approach"] = "dual"

if gdf_network.index.name is not None:
graph.graph["index_name"] = gdf_network.index.name

custom_index = not gdf_network.index.equals(pd.RangeIndex(len(gdf_network)))

key = 0

sw = libpysal.weights.Queen.from_dataframe(
Expand All @@ -170,6 +189,10 @@ def _generate_dual(graph, gdf_network, fields, angles, multigraph, angle):
centroid = (row.temp_x_coords, row.temp_y_coords)
data = list(row)[1:-2]
attributes = dict(zip(fields, data, strict=True))
if preserve_index:
attributes["index_position"] = i
if custom_index:
attributes["index"] = row.Index
graph.add_node(centroid, **attributes)

if sw.cardinalities[i] > 0:
Expand Down Expand Up @@ -216,6 +239,7 @@ def gdf_to_nx(
angle="angle",
oneway_column=None,
integer_labels=False,
preserve_index=False,
):
"""
Convert a LineString GeoDataFrame to a ``networkx.MultiGraph`` or other
Expand Down Expand Up @@ -258,6 +282,12 @@ def gdf_to_nx(
Convert node labels to integers. By default, node labels are tuples with (x, y)
coordinates. Set to True to encode them as integers. Note that the x, and y
coordinates are always preserved as node attributes.
preserve_index : bool, default False
Preserve information about the index of ``gdf_network``. If
``gdf_network.index`` is the default ``RangeIndex``, ``"index_position"``
attribute is added to each edge. If it is a custom index, ``"index_position"``
and ``"index"`` attributes are added. These attributes are then used by
:func:`nx_to_gdf` to faithfully roundtrip the data in the same order.
Returns
-------
Expand Down Expand Up @@ -328,14 +358,27 @@ def gdf_to_nx(
"Bidirectional lines are only supported for directed graphs."
)

_generate_primal(net, gdf_network, fields, multigraph, oneway_column)
_generate_primal(
net,
gdf_network,
fields,
multigraph,
oneway_column,
preserve_index=preserve_index,
)

elif approach == "dual":
if directed:
raise ValueError("Directed graphs are not supported in dual approach.")

_generate_dual(
net, gdf_network, fields, angles=angles, multigraph=multigraph, angle=angle
net,
gdf_network,
fields,
angles=angles,
multigraph=multigraph,
angle=angle,
preserve_index=preserve_index,
)

else:
Expand Down Expand Up @@ -373,6 +416,15 @@ def _lines_to_gdf(net, points, node_id):

if "crs" in net.graph:
gdf_edges.crs = net.graph["crs"]
if "index_position" in gdf_edges.columns:
gdf_edges = gdf_edges.sort_values("index_position").drop(
columns="index_position"
)
if "index" in gdf_edges.columns:
gdf_edges = gdf_edges.set_index("index")
else:
gdf_edges = gdf_edges.reset_index(drop=True)
gdf_edges.index.name = net.graph.get("index_name", None)

return gdf_edges

Expand Down Expand Up @@ -404,6 +456,15 @@ def _dual_to_gdf(net):
"""Generate a linestring gdf from a dual network. Helper for ``nx_to_gdf``."""
starts, edge_data = zip(*net.nodes(data=True), strict=True)
gdf_edges = gpd.GeoDataFrame(list(edge_data))
if "index_position" in gdf_edges.columns:
gdf_edges = gdf_edges.sort_values("index_position").drop(
columns="index_position"
)
if "index" in gdf_edges.columns:
gdf_edges = gdf_edges.set_index("index")
else:
gdf_edges = gdf_edges.reset_index(drop=True)
gdf_edges.index.name = net.graph.get("index_name", None)
gdf_edges.crs = net.graph["crs"]
return gdf_edges

Expand Down

0 comments on commit efb11dd

Please sign in to comment.