From 7d51f23057b67a4e137863e97dc6649889d7a48a Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 10 Sep 2024 17:27:04 +0200 Subject: [PATCH] Datasets: support os.PathLike (#2273) --- tests/datasets/test_geo.py | 2 +- torchgeo/datasets/agb_live_woody_density.py | 3 +-- torchgeo/datasets/agrifieldnet.py | 7 +++---- torchgeo/datasets/cbf.py | 5 ++--- torchgeo/datasets/cdl.py | 5 ++--- torchgeo/datasets/chesapeake.py | 5 ++--- torchgeo/datasets/cms_mangrove_canopy.py | 5 ++--- torchgeo/datasets/esri2020.py | 5 ++--- torchgeo/datasets/eudem.py | 3 +-- torchgeo/datasets/eurocrops.py | 7 +++---- torchgeo/datasets/geo.py | 7 +++---- torchgeo/datasets/globbiomass.py | 5 ++--- torchgeo/datasets/l7irish.py | 7 +++---- torchgeo/datasets/l8biome.py | 5 ++--- torchgeo/datasets/landcoverai.py | 2 +- torchgeo/datasets/nlcd.py | 5 ++--- torchgeo/datasets/openbuildings.py | 9 ++++----- torchgeo/datasets/south_africa_crop_type.py | 7 +++---- torchgeo/datasets/south_america_soybean.py | 4 ++-- torchgeo/datasets/utils.py | 3 +-- 20 files changed, 42 insertions(+), 59 deletions(-) diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index 06202438372..71e07f6928b 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -38,7 +38,7 @@ def __init__( bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5), crs: CRS = CRS.from_epsg(4087), res: float = 1, - paths: str | Path | Iterable[str | Path] | None = None, + paths: str | os.PathLike[str] | Iterable[str | os.PathLike[str]] | None = None, ) -> None: super().__init__() self.index.insert(0, tuple(bounds)) diff --git a/torchgeo/datasets/agb_live_woody_density.py b/torchgeo/datasets/agb_live_woody_density.py index 1ceaa9c9d3d..aaef8db9751 100644 --- a/torchgeo/datasets/agb_live_woody_density.py +++ b/torchgeo/datasets/agb_live_woody_density.py @@ -5,7 +5,6 @@ import json import os -import pathlib from collections.abc import Callable, Iterable from typing import Any @@ -106,7 +105,7 @@ def _verify(self) -> None: def _download(self) -> None: """Download the dataset.""" - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) download_url(self.url, self.paths, self.base_filename) with open(os.path.join(self.paths, self.base_filename)) as f: diff --git a/torchgeo/datasets/agrifieldnet.py b/torchgeo/datasets/agrifieldnet.py index e40fca5eecf..3624c1e193e 100644 --- a/torchgeo/datasets/agrifieldnet.py +++ b/torchgeo/datasets/agrifieldnet.py @@ -4,7 +4,6 @@ """AgriFieldNet India Challenge dataset.""" import os -import pathlib import re from collections.abc import Callable, Iterable, Sequence from typing import Any, ClassVar, cast @@ -181,10 +180,10 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: Returns: data, label, and field ids at that index """ - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) hits = self.index.intersection(tuple(query), objects=True) - filepaths = cast(list[Path], [hit.object for hit in hits]) + filepaths = cast(list[str], [hit.object for hit in hits]) if not filepaths: raise IndexError( @@ -246,7 +245,7 @@ def _verify(self) -> None: def _download(self) -> None: """Download the dataset.""" - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) os.makedirs(self.paths, exist_ok=True) azcopy = which('azcopy') azcopy('sync', f'{self.url}', self.paths, '--recursive=true') diff --git a/torchgeo/datasets/cbf.py b/torchgeo/datasets/cbf.py index ec6afdb34b7..3c986eb44c1 100644 --- a/torchgeo/datasets/cbf.py +++ b/torchgeo/datasets/cbf.py @@ -4,7 +4,6 @@ """Canadian Building Footprints dataset.""" import os -import pathlib from collections.abc import Callable, Iterable from typing import Any @@ -105,7 +104,7 @@ def _check_integrity(self) -> bool: Returns: True if dataset files are found and/or MD5s match, else False """ - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) for prov_terr, md5 in zip(self.provinces_territories, self.md5s): filepath = os.path.join(self.paths, prov_terr + '.zip') if not check_integrity(filepath, md5 if self.checksum else None): @@ -117,7 +116,7 @@ def _download(self) -> None: if self._check_integrity(): print('Files already downloaded and verified') return - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) for prov_terr, md5 in zip(self.provinces_territories, self.md5s): download_and_extract_archive( self.url + prov_terr + '.zip', diff --git a/torchgeo/datasets/cdl.py b/torchgeo/datasets/cdl.py index 6cd2b483062..0b0f6ac5b3d 100644 --- a/torchgeo/datasets/cdl.py +++ b/torchgeo/datasets/cdl.py @@ -4,7 +4,6 @@ """CDL dataset.""" import os -import pathlib from collections.abc import Callable, Iterable from typing import Any, ClassVar @@ -295,7 +294,7 @@ def _verify(self) -> None: # Check if the zip files have already been downloaded exists = [] - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) for year in self.years: pathname = os.path.join( self.paths, self.zipfile_glob.replace('*', str(year)) @@ -328,7 +327,7 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) for year in self.years: zipfile_name = self.zipfile_glob.replace('*', str(year)) pathname = os.path.join(self.paths, zipfile_name) diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index 33a347576f9..7e521092aaf 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -5,7 +5,6 @@ import glob import os -import pathlib import sys from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Sequence @@ -173,7 +172,7 @@ def _verify(self) -> None: return # Check if the zip file has already been downloaded - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) if glob.glob(os.path.join(self.paths, '**', '*.zip'), recursive=True): self._extract() return @@ -195,7 +194,7 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) for file in glob.iglob(os.path.join(self.paths, '**', '*.zip'), recursive=True): extract_archive(file) diff --git a/torchgeo/datasets/cms_mangrove_canopy.py b/torchgeo/datasets/cms_mangrove_canopy.py index 61d5c4acafd..f479d6989d9 100644 --- a/torchgeo/datasets/cms_mangrove_canopy.py +++ b/torchgeo/datasets/cms_mangrove_canopy.py @@ -4,7 +4,6 @@ """CMS Global Mangrove Canopy dataset.""" import os -import pathlib from collections.abc import Callable from typing import Any @@ -229,7 +228,7 @@ def _verify(self) -> None: return # Check if the zip file has already been downloaded - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) pathname = os.path.join(self.paths, self.zipfile) if os.path.exists(pathname): if self.checksum and not check_integrity(pathname, self.md5): @@ -241,7 +240,7 @@ def _verify(self) -> None: def _extract(self) -> None: """Extract the dataset.""" - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) pathname = os.path.join(self.paths, self.zipfile) extract_archive(pathname) diff --git a/torchgeo/datasets/esri2020.py b/torchgeo/datasets/esri2020.py index ed06309f91a..9753c0584c6 100644 --- a/torchgeo/datasets/esri2020.py +++ b/torchgeo/datasets/esri2020.py @@ -5,7 +5,6 @@ import glob import os -import pathlib from collections.abc import Callable, Iterable from typing import Any @@ -113,7 +112,7 @@ def _verify(self) -> None: return # Check if the zip files have already been downloaded - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) pathname = os.path.join(self.paths, self.zipfile) if glob.glob(pathname): self._extract() @@ -133,7 +132,7 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) extract_archive(os.path.join(self.paths, self.zipfile)) def plot( diff --git a/torchgeo/datasets/eudem.py b/torchgeo/datasets/eudem.py index 765433bfc05..5a9af7f6fa3 100644 --- a/torchgeo/datasets/eudem.py +++ b/torchgeo/datasets/eudem.py @@ -5,7 +5,6 @@ import glob import os -import pathlib from collections.abc import Callable, Iterable from typing import Any, ClassVar @@ -117,7 +116,7 @@ def _verify(self) -> None: return # Check if the zip files have already been downloaded - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) pathname = os.path.join(self.paths, self.zipfile_glob) if glob.glob(pathname): for zipfile in glob.iglob(pathname): diff --git a/torchgeo/datasets/eurocrops.py b/torchgeo/datasets/eurocrops.py index d2087e4b38d..5f438143c87 100644 --- a/torchgeo/datasets/eurocrops.py +++ b/torchgeo/datasets/eurocrops.py @@ -5,7 +5,6 @@ import csv import os -import pathlib from collections.abc import Callable, Iterable from typing import Any @@ -140,7 +139,7 @@ def _check_integrity(self) -> bool: if self.files and not self.checksum: return True - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) filepath = os.path.join(self.paths, self.hcat_fname) if not check_integrity(filepath, self.hcat_md5 if self.checksum else None): @@ -157,7 +156,7 @@ def _download(self) -> None: if self._check_integrity(): print('Files already downloaded and verified') return - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) download_url( self.base_url + self.hcat_fname, self.paths, @@ -179,7 +178,7 @@ def _load_class_map(self, classes: list[str] | None) -> None: (defaults to all classes) """ if not classes: - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) classes = [] filepath = os.path.join(self.paths, self.hcat_fname) with open(filepath) as f: diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 198065a708a..8233480443a 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -8,7 +8,6 @@ import functools import glob import os -import pathlib import re import sys import warnings @@ -300,7 +299,7 @@ def files(self) -> list[str]: .. versionadded:: 0.5 """ # Make iterable - if isinstance(self.paths, str | pathlib.Path): + if isinstance(self.paths, str | os.PathLike): paths: Iterable[Path] = [self.paths] else: paths = self.paths @@ -521,7 +520,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - filepaths = cast(list[Path], [hit.object for hit in hits]) + filepaths = cast(list[str], [hit.object for hit in hits]) if not filepaths: raise IndexError( @@ -564,7 +563,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: def _merge_files( self, - filepaths: Sequence[Path], + filepaths: Sequence[str], query: BoundingBox, band_indexes: Sequence[int] | None = None, ) -> Tensor: diff --git a/torchgeo/datasets/globbiomass.py b/torchgeo/datasets/globbiomass.py index e117bf361c6..c214fbba205 100644 --- a/torchgeo/datasets/globbiomass.py +++ b/torchgeo/datasets/globbiomass.py @@ -5,7 +5,6 @@ import glob import os -import pathlib from collections.abc import Callable, Iterable from typing import Any, ClassVar, cast @@ -193,7 +192,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - filepaths = cast(list[Path], [hit.object for hit in hits]) + filepaths = cast(list[str], [hit.object for hit in hits]) if not filepaths: raise IndexError( @@ -221,7 +220,7 @@ def _verify(self) -> None: return # Check if the zip files have already been downloaded - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) pathname = os.path.join(self.paths, f'*_{self.measurement}.zip') if glob.glob(pathname): for zipfile in glob.iglob(pathname): diff --git a/torchgeo/datasets/l7irish.py b/torchgeo/datasets/l7irish.py index 25830eaf3dc..d39f225ed75 100644 --- a/torchgeo/datasets/l7irish.py +++ b/torchgeo/datasets/l7irish.py @@ -5,7 +5,6 @@ import glob import os -import pathlib import re from collections.abc import Callable, Iterable, Sequence from typing import Any, ClassVar, cast @@ -94,7 +93,7 @@ def __init__( filename_regex = re.compile(L7IrishImage.filename_regex, re.VERBOSE) index = Index(interleaved=False, properties=Property(dimension=3)) for hit in self.index.intersection(self.index.bounds, objects=True): - dirname = os.path.dirname(cast(Path, hit.object)) + dirname = os.path.dirname(cast(str, hit.object)) image = glob.glob(os.path.join(dirname, L7IrishImage.filename_glob))[0] minx, maxx, miny, maxy, mint, maxt = hit.bounds if match := re.match(filename_regex, os.path.basename(image)): @@ -229,7 +228,7 @@ def _merge_dataset_indices(self) -> None: def _verify(self) -> None: """Verify the integrity of the dataset.""" # Check if the extracted files already exist - if not isinstance(self.paths, str | pathlib.Path): + if not isinstance(self.paths, str | os.PathLike): return for classname in [L7IrishImage, L7IrishMask]: @@ -262,7 +261,7 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) pathname = os.path.join(self.paths, '*.tar.gz') for tarfile in glob.iglob(pathname): extract_archive(tarfile) diff --git a/torchgeo/datasets/l8biome.py b/torchgeo/datasets/l8biome.py index 318efa2476a..e53c403b713 100644 --- a/torchgeo/datasets/l8biome.py +++ b/torchgeo/datasets/l8biome.py @@ -5,7 +5,6 @@ import glob import os -import pathlib from collections.abc import Callable, Iterable, Sequence from typing import Any, ClassVar @@ -174,7 +173,7 @@ def __init__( def _verify(self) -> None: """Verify the integrity of the dataset.""" # Check if the extracted files already exist - if not isinstance(self.paths, str | pathlib.Path): + if not isinstance(self.paths, str | os.PathLike): return for classname in [L8BiomeImage, L8BiomeMask]: @@ -207,7 +206,7 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) pathname = os.path.join(self.paths, '*.tar.gz') for tarfile in glob.iglob(pathname): extract_archive(tarfile) diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index d9a9643b524..a290e404843 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -254,7 +254,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - img_filepaths = cast(list[Path], [hit.object for hit in hits]) + img_filepaths = cast(list[str], [hit.object for hit in hits]) mask_filepaths = [ str(path).replace('images', 'masks') for path in img_filepaths ] diff --git a/torchgeo/datasets/nlcd.py b/torchgeo/datasets/nlcd.py index 7f4498c76da..ab0e83c96b1 100644 --- a/torchgeo/datasets/nlcd.py +++ b/torchgeo/datasets/nlcd.py @@ -5,7 +5,6 @@ import glob import os -import pathlib from collections.abc import Callable, Iterable from typing import Any, ClassVar @@ -192,7 +191,7 @@ def _verify(self) -> None: exists = [] for year in self.years: zipfile_year = self.zipfile_glob.replace('*', str(year), 1) - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) pathname = os.path.join(self.paths, '**', zipfile_year) if glob.glob(pathname, recursive=True): exists.append(True) @@ -224,7 +223,7 @@ def _extract(self) -> None: """Extract the dataset.""" for year in self.years: zipfile_name = self.zipfile_glob.replace('*', str(year), 1) - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) pathname = os.path.join(self.paths, '**', zipfile_name) extract_archive(glob.glob(pathname, recursive=True)[0], self.paths) diff --git a/torchgeo/datasets/openbuildings.py b/torchgeo/datasets/openbuildings.py index 50b35ee39db..292dc274c32 100644 --- a/torchgeo/datasets/openbuildings.py +++ b/torchgeo/datasets/openbuildings.py @@ -6,7 +6,6 @@ import glob import json import os -import pathlib import sys from collections.abc import Callable, Iterable from typing import Any, ClassVar, cast @@ -242,7 +241,7 @@ def __init__( # Create an R-tree to index the dataset using the polygon centroid as bounds self.index = Index(interleaved=False, properties=Property(dimension=3)) - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) with open(os.path.join(self.paths, 'tiles.geojson')) as f: data = json.load(f) @@ -305,7 +304,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - filepaths = cast(list[Path], [hit.object for hit in hits]) + filepaths = cast(list[str], [hit.object for hit in hits]) if not filepaths: raise IndexError( @@ -336,7 +335,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: return sample def _filter_geometries( - self, query: BoundingBox, filepaths: list[Path] + self, query: BoundingBox, filepaths: list[str] ) -> list[dict[str, Any]]: """Filters a df read from the polygon csv file based on query and conf thresh. @@ -398,7 +397,7 @@ def _wkt_fiona_geom_transform(self, x: str) -> dict[str, Any]: def _verify(self) -> None: """Verify the integrity of the dataset.""" # Check if the zip files have already been downloaded and checksum - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) pathname = os.path.join(self.paths, self.zipfile_glob) i = 0 for zipfile in glob.iglob(pathname): diff --git a/torchgeo/datasets/south_africa_crop_type.py b/torchgeo/datasets/south_africa_crop_type.py index 7b79f44c8d2..3c4f7f895ec 100644 --- a/torchgeo/datasets/south_africa_crop_type.py +++ b/torchgeo/datasets/south_africa_crop_type.py @@ -4,7 +4,6 @@ """South Africa Crop Type Competition Dataset.""" import os -import pathlib import re from collections.abc import Callable, Iterable, Sequence from typing import Any, ClassVar, cast @@ -161,11 +160,11 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: Returns: data and labels at that index """ - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) # Get all files matching the given query hits = self.index.intersection(tuple(query), objects=True) - filepaths = cast(list[Path], [hit.object for hit in hits]) + filepaths = cast(list[str], [hit.object for hit in hits]) if not filepaths: raise IndexError( @@ -253,7 +252,7 @@ def _verify(self) -> None: def _download(self) -> None: """Download the dataset.""" - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) os.makedirs(self.paths, exist_ok=True) azcopy = which('azcopy') azcopy('sync', f'{self.url}', self.paths, '--recursive=true') diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 3ec4b559472..adbde74d6cb 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -3,7 +3,7 @@ """South America Soybean Dataset.""" -import pathlib +import os from collections.abc import Callable, Iterable from typing import Any, ClassVar @@ -113,7 +113,7 @@ def _verify(self) -> None: # Check if the extracted files already exist if self.files: return - assert isinstance(self.paths, str | pathlib.Path) + assert isinstance(self.paths, str | os.PathLike) # Check if the user requested to download the dataset if not self.download: diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 237b479326e..a3a39a0f844 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -10,7 +10,6 @@ import contextlib import importlib import os -import pathlib import shutil import subprocess import sys @@ -42,7 +41,7 @@ ) -Path: TypeAlias = str | pathlib.Path +Path: TypeAlias = str | os.PathLike[str] @dataclass(frozen=True)