Skip to content

Commit

Permalink
Fixed inconsistencies in geodataframe index names across io functions…
Browse files Browse the repository at this point in the history
… and gis.

Added new tests to check for consistent index naming.

Fix #433 - Gis files written with column title 'name' (#435)

---------

Co-authored-by: Angus <[email protected]>
Co-authored-by: kbonney <[email protected]>
  • Loading branch information
3 people authored Aug 12, 2024
1 parent 0a750e7 commit 31004a9
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 31 deletions.
41 changes: 22 additions & 19 deletions documentation/gis.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,13 @@ For example, the junctions GeoDataFrame contains the following information:
:skipif: gpd is None

>>> print(wn_gis.junctions.head())
node_type elevation initial_quality geometry
10 Junction 216.408 5.000e-04 POINT (20.00000 70.00000)
11 Junction 216.408 5.000e-04 POINT (30.00000 70.00000)
12 Junction 213.360 5.000e-04 POINT (50.00000 70.00000)
13 Junction 211.836 5.000e-04 POINT (70.00000 70.00000)
21 Junction 213.360 5.000e-04 POINT (30.00000 40.00000)
node_type elevation initial_quality geometry
name
10 Junction 216.408 5.000e-04 POINT (20.00000 70.00000)
11 Junction 216.408 5.000e-04 POINT (30.00000 70.00000)
12 Junction 213.360 5.000e-04 POINT (50.00000 70.00000)
13 Junction 211.836 5.000e-04 POINT (70.00000 70.00000)
21 Junction 213.360 5.000e-04 POINT (30.00000 40.00000)

Each GeoDataFrame contains attributes and geometry:

Expand Down Expand Up @@ -333,21 +334,23 @@ and then translates the GeoDataFrames coordinates to EPSG:3857.

>>> wn_gis = wntr.network.to_gis(wn, crs='EPSG:4326')
>>> print(wn_gis.junctions.head())
node_type elevation initial_quality geometry
10 Junction 216.408 5.000e-04 POINT (20.00000 70.00000)
11 Junction 216.408 5.000e-04 POINT (30.00000 70.00000)
12 Junction 213.360 5.000e-04 POINT (50.00000 70.00000)
13 Junction 211.836 5.000e-04 POINT (70.00000 70.00000)
21 Junction 213.360 5.000e-04 POINT (30.00000 40.00000)

node_type elevation initial_quality geometry
name
10 Junction 216.408 5.000e-04 POINT (20.00000 70.00000)
11 Junction 216.408 5.000e-04 POINT (30.00000 70.00000)
12 Junction 213.360 5.000e-04 POINT (50.00000 70.00000)
13 Junction 211.836 5.000e-04 POINT (70.00000 70.00000)
21 Junction 213.360 5.000e-04 POINT (30.00000 40.00000)

>>> wn_gis.to_crs('EPSG:3857')
>>> print(wn_gis.junctions.head())
node_type elevation initial_quality geometry
10 Junction 216.408 5.000e-04 POINT (2226389.816 11068715.659)
11 Junction 216.408 5.000e-04 POINT (3339584.724 11068715.659)
12 Junction 213.360 5.000e-04 POINT (5565974.540 11068715.659)
13 Junction 211.836 5.000e-04 POINT (7792364.356 11068715.659)
21 Junction 213.360 5.000e-04 POINT (3339584.724 4865942.280)
node_type elevation initial_quality geometry
name
10 Junction 216.408 5.000e-04 POINT (2226389.816 11068715.659)
11 Junction 216.408 5.000e-04 POINT (3339584.724 11068715.659)
12 Junction 213.360 5.000e-04 POINT (5565974.540 11068715.659)
13 Junction 211.836 5.000e-04 POINT (7792364.356 11068715.659)
21 Junction 213.360 5.000e-04 POINT (3339584.724 4865942.280)

Snap point geometries to the nearest point or line
----------------------------------------------------
Expand Down
5 changes: 2 additions & 3 deletions wntr/gis/geospatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ def snap(A, B, tolerance):
assert A.crs == B.crs

# Modify B to include "indexB" as a separate column
B = B.reset_index()
B.rename(columns={'index':'indexB'}, inplace=True)
B = B.reset_index(names='indexB')

# Define the coordinate reference system, based on B
crs = B.crs
Expand Down Expand Up @@ -228,7 +227,7 @@ def intersect(A, B, B_value=None, include_background=False, background_value=0):

n = intersects.groupby('_tmp_index_name')['geometry'].count()
B_indices = intersects.groupby('_tmp_index_name')['index_right'].apply(list)
stats = pd.DataFrame(index=A.index, data={'intersections': B_indices,
stats = pd.DataFrame(index=A.index.copy(), data={'intersections': B_indices,
'n': n,})
stats['n'] = stats['n'].fillna(0)
stats['n'] = stats['n'].apply(int)
Expand Down
7 changes: 3 additions & 4 deletions wntr/gis/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ def _extract_geodataframe(df, crs=None, links_as_points=False):
# Set index
if len(df) > 0:
df.set_index('name', inplace=True)
df.index.name = None

df = gpd.GeoDataFrame(df, crs=crs, geometry=geom)
else:
Expand Down Expand Up @@ -300,7 +299,7 @@ def add_link_attributes(self, values, name):
self.pumps[name] = np.nan
self.pumps.loc[link_name, name] = value

def _read(self, files, index_col='index'):
def _read(self, files, index_col='name'):

if 'junctions' in files.keys():
data = gpd.read_file(files['junctions']).set_index(index_col)
Expand All @@ -321,7 +320,7 @@ def _read(self, files, index_col='index'):
data = gpd.read_file(files['valves']).set_index(index_col)
self.valves = pd.concat([self.valves, data])

def read_geojson(self, files, index_col='index'):
def read_geojson(self, files, index_col='name'):
"""
Append information from GeoJSON files to a WaterNetworkGIS object
Expand All @@ -336,7 +335,7 @@ def read_geojson(self, files, index_col='index'):
"""
self._read(files, index_col)

def read_shapefile(self, files, index_col='index'):
def read_shapefile(self, files, index_col='name'):
"""
Append information from Esri Shapefiles to a WaterNetworkGIS object
Expand Down
6 changes: 3 additions & 3 deletions wntr/network/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ def write_geojson(wn, prefix: str, crs=None, pumps_as_points=True,
wn_gis.write_geojson(prefix=prefix)


def read_geojson(files, index_col='index', append=None):
def read_geojson(files, index_col='name', append=None):
"""
Create or append a WaterNetworkModel from GeoJSON files
Expand Down Expand Up @@ -611,7 +611,7 @@ def write_shapefile(wn, prefix: str, crs=None, pumps_as_points=True,
valves_as_points=valves_as_points)
wn_gis.write_shapefile(prefix=prefix)

def read_shapefile(files, index_col='index', append=None):
def read_shapefile(files, index_col='name', append=None):
"""
Create or append a WaterNetworkModel from Esri Shapefiles
Expand All @@ -634,7 +634,7 @@ def read_shapefile(files, index_col='index', append=None):
"""
gis_data = WaterNetworkGIS()
gis_data.read_shapefile(files, index_col='index')
gis_data.read_shapefile(files,index_col=index_col)
wn = gis_data._create_wn(append=append)

return wn
Expand Down
34 changes: 32 additions & 2 deletions wntr/tests/test_gis.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,36 @@ def tearDownClass(self):
def test_gis_index(self):
# Tests that WN can be made using dataframes with customized index names
wn_gis = self.wn.to_gis()

# check that index name of geodataframes is "name"
assert wn_gis.junctions.index.name == "name"
assert wn_gis.tanks.index.name == "name"
assert wn_gis.reservoirs.index.name == "name"
assert wn_gis.pipes.index.name == "name"
assert wn_gis.pumps.index.name == "name"

# check that index names can be changed and still be read back into a wn
wn_gis.junctions.index.name = "my_index"
wn_gis.pipes.index.name = "my_index"
wn2 = wntr.network.from_gis(wn_gis)
self.wn == wn2


assert self.wn.pipe_name_list == wn2.pipe_name_list
assert self.wn.junction_name_list == wn2.junction_name_list

# test snap and intersect functionality with alternate index names
result = wntr.gis.snap(self.points, wn_gis.junctions, tolerance=5.0)
assert len(result) > 0
result = wntr.gis.snap(wn_gis.junctions, self.points, tolerance=5.0)
assert len(result) > 0
result = wntr.gis.intersect(wn_gis.junctions, self.polygons)
assert len(result) > 0
result = wntr.gis.intersect(self.polygons, wn_gis.pipes)
assert len(result) > 0

# check that custom index name persists after running snap/intersect
assert wn_gis.junctions.index.name == "my_index"
assert wn_gis.pipes.index.name == "my_index"

def test_wn_to_gis(self):
# Check type
isinstance(self.gis_data.junctions, gpd.GeoDataFrame)
Expand Down Expand Up @@ -256,8 +281,13 @@ def test_write_geojson(self):
for component in components:
if component == 'valves':
continue # Net1 has no valves
# check file exists
filename = abspath(join(testdir, prefix+'_'+component+'.geojson'))
self.assertTrue(isfile(filename))

# check for "name" column
gdf = gpd.read_file(filename)
assert "name" in gdf.columns

def test_snap_points_to_points(self):

Expand Down

0 comments on commit 31004a9

Please sign in to comment.