Skip to content

Commit

Permalink
force engine to pyogrio
Browse files Browse the repository at this point in the history
  • Loading branch information
rwood-97 committed Sep 3, 2024
1 parent 2e8d917 commit baded89
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 7 deletions.
4 changes: 4 additions & 0 deletions mapreader/load/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -2715,6 +2715,10 @@ def save_patches_to_geojson(
if not crs:
crs = patch_df.crs

if "image_id" in patch_df.columns:
patch_df.drop(columns=["image_id"], inplace=True)
patch_df.reset_index(names="image_id", inplace=True)

# drop pixel stats columns
patch_df.drop(columns=patch_df.filter(like="pixel", axis=1), inplace=True)

Expand Down
2 changes: 1 addition & 1 deletion mapreader/spot_text/runner_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def save_to_geojson(
"""

geo_df = self._dict_to_dataframe(self.geo_predictions, geo=True, parent=True)
geo_df.to_file(save_path, driver="GeoJSON")
geo_df.to_file(save_path, driver="GeoJSON", engine="pyogrio")

Check warning on line 400 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L400

Added line #L400 was not covered by tests

def show(
self,
Expand Down
3 changes: 2 additions & 1 deletion mapreader/utils/load_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def load_from_geojson(
**kwargs,
):
check_exists(fpath)
df = gpd.read_file(fpath, **kwargs)
engine = kwargs.pop("engine", "pyogrio")
df = gpd.read_file(fpath, engine=engine, **kwargs)
if "image_id" in df.columns:
df.set_index("image_id", drop=True, inplace=True)
elif "name" in df.columns:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_load/test_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ def test_save_to_geojson(init_maps, tmp_path, capfd):
maps, _, _ = init_maps
maps.save_patches_to_geojson(geojson_fname=f"{tmp_path}/patches.geojson")
assert os.path.exists(f"{tmp_path}/patches.geojson")
geo_df = gpd.read_file(f"{tmp_path}/patches.geojson")
geo_df = gpd.read_file(f"{tmp_path}/patches.geojson", engine="pyogrio")
assert "geometry" in geo_df.columns
assert str(geo_df.crs.to_string()) == "EPSG:4326"
assert isinstance(geo_df["geometry"][0], Polygon)
Expand Down Expand Up @@ -798,7 +798,7 @@ def test_save_to_geojson_missing_data(sample_dir, image_id, tmp_path):
)
maps.save_patches_to_geojson(geojson_fname=f"{tmp_path}/patches.geojson")
assert os.path.exists(f"{tmp_path}/patches.geojson")
geo_df = gpd.read_file(f"{tmp_path}/patches.geojson")
geo_df = gpd.read_file(f"{tmp_path}/patches.geojson", engine="pyogrio")
assert "geometry" in geo_df.columns
assert str(geo_df.crs.to_string()) == "EPSG:4326"
assert isinstance(geo_df["geometry"][0], Polygon)
Expand All @@ -816,7 +816,7 @@ def test_save_to_geojson_polygon_strings(
assert isinstance(maps.patches[patch_id]["geometry"], str)
maps.save_patches_to_geojson(geojson_fname=f"{tmp_path}/patches.geojson")
assert os.path.exists(f"{tmp_path}/patches.geojson")
geo_df = gpd.read_file(f"{tmp_path}/patches.geojson")
geo_df = gpd.read_file(f"{tmp_path}/patches.geojson", engine="pyogrio")
assert "geometry" in geo_df.columns
assert str(geo_df.crs.to_string()) == "EPSG:4326"
assert isinstance(geo_df["geometry"][0], Polygon)
Expand Down
6 changes: 4 additions & 2 deletions tests/test_utils_load_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,10 @@ def test_load_from_excel(init_dataframes, tmp_path):

def test_load_from_geojson(init_dataframes, tmp_path):
parent_df, _ = init_dataframes
parent_df.to_file(f"{tmp_path}/parent_df.geojson", driver="GeoJSON")
parent_df_geojson = gpd.read_file(f"{tmp_path}/parent_df.geojson")
parent_df.to_file(
f"{tmp_path}/parent_df.geojson", driver="GeoJSON", engine="pyogrio"
)
parent_df_geojson = gpd.read_file(f"{tmp_path}/parent_df.geojson", engine="pyogrio")
print(parent_df_geojson.columns)
print(parent_df_geojson.head())
assert parent_df_geojson.index == range(len(parent_df)) # should be numeric index
Expand Down

0 comments on commit baded89

Please sign in to comment.