From baded89e52926dc8ec2e1c6ec8ef07631c798ee3 Mon Sep 17 00:00:00 2001 From: Rosie Wood Date: Tue, 3 Sep 2024 16:36:24 +0100 Subject: [PATCH] force engine to pyogrio --- mapreader/load/images.py | 4 ++++ mapreader/spot_text/runner_base.py | 2 +- mapreader/utils/load_frames.py | 3 ++- tests/test_load/test_images.py | 6 +++--- tests/test_utils_load_frames.py | 6 ++++-- 5 files changed, 14 insertions(+), 7 deletions(-) diff --git a/mapreader/load/images.py b/mapreader/load/images.py index edba4796..8164912d 100644 --- a/mapreader/load/images.py +++ b/mapreader/load/images.py @@ -2715,6 +2715,10 @@ def save_patches_to_geojson( if not crs: crs = patch_df.crs + if "image_id" in patch_df.columns: + patch_df.drop(columns=["image_id"], inplace=True) + patch_df.reset_index(names="image_id", inplace=True) + # drop pixel stats columns patch_df.drop(columns=patch_df.filter(like="pixel", axis=1), inplace=True) diff --git a/mapreader/spot_text/runner_base.py b/mapreader/spot_text/runner_base.py index 4b549bcd..9942a153 100644 --- a/mapreader/spot_text/runner_base.py +++ b/mapreader/spot_text/runner_base.py @@ -397,7 +397,7 @@ def save_to_geojson( """ geo_df = self._dict_to_dataframe(self.geo_predictions, geo=True, parent=True) - geo_df.to_file(save_path, driver="GeoJSON") + geo_df.to_file(save_path, driver="GeoJSON", engine="pyogrio") def show( self, diff --git a/mapreader/utils/load_frames.py b/mapreader/utils/load_frames.py index 7a04d5d9..a2bdbaeb 100644 --- a/mapreader/utils/load_frames.py +++ b/mapreader/utils/load_frames.py @@ -48,7 +48,8 @@ def load_from_geojson( **kwargs, ): check_exists(fpath) - df = gpd.read_file(fpath, **kwargs) + engine = kwargs.pop("engine", "pyogrio") + df = gpd.read_file(fpath, engine=engine, **kwargs) if "image_id" in df.columns: df.set_index("image_id", drop=True, inplace=True) elif "name" in df.columns: diff --git a/tests/test_load/test_images.py b/tests/test_load/test_images.py index 12d6adf5..66247ea7 100644 --- a/tests/test_load/test_images.py +++ b/tests/test_load/test_images.py @@ -764,7 +764,7 @@ def test_save_to_geojson(init_maps, tmp_path, capfd): maps, _, _ = init_maps maps.save_patches_to_geojson(geojson_fname=f"{tmp_path}/patches.geojson") assert os.path.exists(f"{tmp_path}/patches.geojson") - geo_df = gpd.read_file(f"{tmp_path}/patches.geojson") + geo_df = gpd.read_file(f"{tmp_path}/patches.geojson", engine="pyogrio") assert "geometry" in geo_df.columns assert str(geo_df.crs.to_string()) == "EPSG:4326" assert isinstance(geo_df["geometry"][0], Polygon) @@ -798,7 +798,7 @@ def test_save_to_geojson_missing_data(sample_dir, image_id, tmp_path): ) maps.save_patches_to_geojson(geojson_fname=f"{tmp_path}/patches.geojson") assert os.path.exists(f"{tmp_path}/patches.geojson") - geo_df = gpd.read_file(f"{tmp_path}/patches.geojson") + geo_df = gpd.read_file(f"{tmp_path}/patches.geojson", engine="pyogrio") assert "geometry" in geo_df.columns assert str(geo_df.crs.to_string()) == "EPSG:4326" assert isinstance(geo_df["geometry"][0], Polygon) @@ -816,7 +816,7 @@ def test_save_to_geojson_polygon_strings( assert isinstance(maps.patches[patch_id]["geometry"], str) maps.save_patches_to_geojson(geojson_fname=f"{tmp_path}/patches.geojson") assert os.path.exists(f"{tmp_path}/patches.geojson") - geo_df = gpd.read_file(f"{tmp_path}/patches.geojson") + geo_df = gpd.read_file(f"{tmp_path}/patches.geojson", engine="pyogrio") assert "geometry" in geo_df.columns assert str(geo_df.crs.to_string()) == "EPSG:4326" assert isinstance(geo_df["geometry"][0], Polygon) diff --git a/tests/test_utils_load_frames.py b/tests/test_utils_load_frames.py index dbc6f21f..ff35eb93 100644 --- a/tests/test_utils_load_frames.py +++ b/tests/test_utils_load_frames.py @@ -100,8 +100,10 @@ def test_load_from_excel(init_dataframes, tmp_path): def test_load_from_geojson(init_dataframes, tmp_path): parent_df, _ = init_dataframes - parent_df.to_file(f"{tmp_path}/parent_df.geojson", driver="GeoJSON") - parent_df_geojson = gpd.read_file(f"{tmp_path}/parent_df.geojson") + parent_df.to_file( + f"{tmp_path}/parent_df.geojson", driver="GeoJSON", engine="pyogrio" + ) + parent_df_geojson = gpd.read_file(f"{tmp_path}/parent_df.geojson", engine="pyogrio") print(parent_df_geojson.columns) print(parent_df_geojson.head()) assert parent_df_geojson.index == range(len(parent_df)) # should be numeric index