Skip to content

Commit

Permalink
Datasets: support os.PathLike (#2273)
Browse files: browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Oct 10, 2024
1 parent: 184e58e; commit: 7d51f23
Show file tree
Hide file tree
Showing 20 changed files with 42 additions and 59 deletions.
2 changes: 1 addition & 1 deletion tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
crs: CRS = CRS.from_epsg(4087),
res: float = 1,
paths: str | Path | Iterable[str | Path] | None = None,
paths: str | os.PathLike[str] | Iterable[str | os.PathLike[str]] | None = None,
) -> None:
super().__init__()
self.index.insert(0, tuple(bounds))
Expand Down
3 changes: 1 addition & 2 deletions torchgeo/datasets/agb_live_woody_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import json
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any

Expand Down Expand Up @@ -106,7 +105,7 @@ def _verify(self) -> None:

def _download(self) -> None:
"""Download the dataset."""
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
download_url(self.url, self.paths, self.base_filename)

with open(os.path.join(self.paths, self.base_filename)) as f:
Expand Down
7 changes: 3 additions & 4 deletions torchgeo/datasets/agrifieldnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""AgriFieldNet India Challenge dataset."""

import os
import pathlib
import re
from collections.abc import Callable, Iterable, Sequence
from typing import Any, ClassVar, cast
Expand Down Expand Up @@ -181,10 +180,10 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
Returns:
data, label, and field ids at that index
"""
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)

hits = self.index.intersection(tuple(query), objects=True)
filepaths = cast(list[Path], [hit.object for hit in hits])
filepaths = cast(list[str], [hit.object for hit in hits])

if not filepaths:
raise IndexError(
Expand Down Expand Up @@ -246,7 +245,7 @@ def _verify(self) -> None:

def _download(self) -> None:
"""Download the dataset."""
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
os.makedirs(self.paths, exist_ok=True)
azcopy = which('azcopy')
azcopy('sync', f'{self.url}', self.paths, '--recursive=true')
Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/cbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""Canadian Building Footprints dataset."""

import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any

Expand Down Expand Up @@ -105,7 +104,7 @@ def _check_integrity(self) -> bool:
Returns:
True if dataset files are found and/or MD5s match, else False
"""
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
for prov_terr, md5 in zip(self.provinces_territories, self.md5s):
filepath = os.path.join(self.paths, prov_terr + '.zip')
if not check_integrity(filepath, md5 if self.checksum else None):
Expand All @@ -117,7 +116,7 @@ def _download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
return
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
for prov_terr, md5 in zip(self.provinces_territories, self.md5s):
download_and_extract_archive(
self.url + prov_terr + '.zip',
Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/cdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""CDL dataset."""

import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any, ClassVar

Expand Down Expand Up @@ -295,7 +294,7 @@ def _verify(self) -> None:

# Check if the zip files have already been downloaded
exists = []
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
for year in self.years:
pathname = os.path.join(
self.paths, self.zipfile_glob.replace('*', str(year))
Expand Down Expand Up @@ -328,7 +327,7 @@ def _download(self) -> None:

def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
for year in self.years:
zipfile_name = self.zipfile_glob.replace('*', str(year))
pathname = os.path.join(self.paths, zipfile_name)
Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import glob
import os
import pathlib
import sys
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
Expand Down Expand Up @@ -173,7 +172,7 @@ def _verify(self) -> None:
return

# Check if the zip file has already been downloaded
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
if glob.glob(os.path.join(self.paths, '**', '*.zip'), recursive=True):
self._extract()
return
Expand All @@ -195,7 +194,7 @@ def _download(self) -> None:

def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
for file in glob.iglob(os.path.join(self.paths, '**', '*.zip'), recursive=True):
extract_archive(file)

Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/cms_mangrove_canopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""CMS Global Mangrove Canopy dataset."""

import os
import pathlib
from collections.abc import Callable
from typing import Any

Expand Down Expand Up @@ -229,7 +228,7 @@ def _verify(self) -> None:
return

# Check if the zip file has already been downloaded
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
pathname = os.path.join(self.paths, self.zipfile)
if os.path.exists(pathname):
if self.checksum and not check_integrity(pathname, self.md5):
Expand All @@ -241,7 +240,7 @@ def _verify(self) -> None:

def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
pathname = os.path.join(self.paths, self.zipfile)
extract_archive(pathname)

Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/esri2020.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import glob
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any

Expand Down Expand Up @@ -113,7 +112,7 @@ def _verify(self) -> None:
return

# Check if the zip files have already been downloaded
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
pathname = os.path.join(self.paths, self.zipfile)
if glob.glob(pathname):
self._extract()
Expand All @@ -133,7 +132,7 @@ def _download(self) -> None:

def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
extract_archive(os.path.join(self.paths, self.zipfile))

def plot(
Expand Down
3 changes: 1 addition & 2 deletions torchgeo/datasets/eudem.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import glob
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any, ClassVar

Expand Down Expand Up @@ -117,7 +116,7 @@ def _verify(self) -> None:
return

# Check if the zip files have already been downloaded
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
pathname = os.path.join(self.paths, self.zipfile_glob)
if glob.glob(pathname):
for zipfile in glob.iglob(pathname):
Expand Down
7 changes: 3 additions & 4 deletions torchgeo/datasets/eurocrops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import csv
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any

Expand Down Expand Up @@ -140,7 +139,7 @@ def _check_integrity(self) -> bool:
if self.files and not self.checksum:
return True

assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)

filepath = os.path.join(self.paths, self.hcat_fname)
if not check_integrity(filepath, self.hcat_md5 if self.checksum else None):
Expand All @@ -157,7 +156,7 @@ def _download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
return
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
download_url(
self.base_url + self.hcat_fname,
self.paths,
Expand All @@ -179,7 +178,7 @@ def _load_class_map(self, classes: list[str] | None) -> None:
(defaults to all classes)
"""
if not classes:
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
classes = []
filepath = os.path.join(self.paths, self.hcat_fname)
with open(filepath) as f:
Expand Down
7 changes: 3 additions & 4 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import functools
import glob
import os
import pathlib
import re
import sys
import warnings
Expand Down Expand Up @@ -300,7 +299,7 @@ def files(self) -> list[str]:
.. versionadded:: 0.5
"""
# Make iterable
if isinstance(self.paths, str | pathlib.Path):
if isinstance(self.paths, str | os.PathLike):
paths: Iterable[Path] = [self.paths]
else:
paths = self.paths
Expand Down Expand Up @@ -521,7 +520,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
filepaths = cast(list[Path], [hit.object for hit in hits])
filepaths = cast(list[str], [hit.object for hit in hits])

if not filepaths:
raise IndexError(
Expand Down Expand Up @@ -564,7 +563,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:

def _merge_files(
self,
filepaths: Sequence[Path],
filepaths: Sequence[str],
query: BoundingBox,
band_indexes: Sequence[int] | None = None,
) -> Tensor:
Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/globbiomass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import glob
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any, ClassVar, cast

Expand Down Expand Up @@ -193,7 +192,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
filepaths = cast(list[Path], [hit.object for hit in hits])
filepaths = cast(list[str], [hit.object for hit in hits])

if not filepaths:
raise IndexError(
Expand Down Expand Up @@ -221,7 +220,7 @@ def _verify(self) -> None:
return

# Check if the zip files have already been downloaded
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
pathname = os.path.join(self.paths, f'*_{self.measurement}.zip')
if glob.glob(pathname):
for zipfile in glob.iglob(pathname):
Expand Down
7 changes: 3 additions & 4 deletions torchgeo/datasets/l7irish.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import glob
import os
import pathlib
import re
from collections.abc import Callable, Iterable, Sequence
from typing import Any, ClassVar, cast
Expand Down Expand Up @@ -94,7 +93,7 @@ def __init__(
filename_regex = re.compile(L7IrishImage.filename_regex, re.VERBOSE)
index = Index(interleaved=False, properties=Property(dimension=3))
for hit in self.index.intersection(self.index.bounds, objects=True):
dirname = os.path.dirname(cast(Path, hit.object))
dirname = os.path.dirname(cast(str, hit.object))
image = glob.glob(os.path.join(dirname, L7IrishImage.filename_glob))[0]
minx, maxx, miny, maxy, mint, maxt = hit.bounds
if match := re.match(filename_regex, os.path.basename(image)):
Expand Down Expand Up @@ -229,7 +228,7 @@ def _merge_dataset_indices(self) -> None:
def _verify(self) -> None:
"""Verify the integrity of the dataset."""
# Check if the extracted files already exist
if not isinstance(self.paths, str | pathlib.Path):
if not isinstance(self.paths, str | os.PathLike):
return

for classname in [L7IrishImage, L7IrishMask]:
Expand Down Expand Up @@ -262,7 +261,7 @@ def _download(self) -> None:

def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
pathname = os.path.join(self.paths, '*.tar.gz')
for tarfile in glob.iglob(pathname):
extract_archive(tarfile)
Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/l8biome.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import glob
import os
import pathlib
from collections.abc import Callable, Iterable, Sequence
from typing import Any, ClassVar

Expand Down Expand Up @@ -174,7 +173,7 @@ def __init__(
def _verify(self) -> None:
"""Verify the integrity of the dataset."""
# Check if the extracted files already exist
if not isinstance(self.paths, str | pathlib.Path):
if not isinstance(self.paths, str | os.PathLike):
return

for classname in [L8BiomeImage, L8BiomeMask]:
Expand Down Expand Up @@ -207,7 +206,7 @@ def _download(self) -> None:

def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
pathname = os.path.join(self.paths, '*.tar.gz')
for tarfile in glob.iglob(pathname):
extract_archive(tarfile)
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/landcoverai.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
img_filepaths = cast(list[Path], [hit.object for hit in hits])
img_filepaths = cast(list[str], [hit.object for hit in hits])
mask_filepaths = [
str(path).replace('images', 'masks') for path in img_filepaths
]
Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/nlcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import glob
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any, ClassVar

Expand Down Expand Up @@ -192,7 +191,7 @@ def _verify(self) -> None:
exists = []
for year in self.years:
zipfile_year = self.zipfile_glob.replace('*', str(year), 1)
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
pathname = os.path.join(self.paths, '**', zipfile_year)
if glob.glob(pathname, recursive=True):
exists.append(True)
Expand Down Expand Up @@ -224,7 +223,7 @@ def _extract(self) -> None:
"""Extract the dataset."""
for year in self.years:
zipfile_name = self.zipfile_glob.replace('*', str(year), 1)
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
pathname = os.path.join(self.paths, '**', zipfile_name)
extract_archive(glob.glob(pathname, recursive=True)[0], self.paths)

Expand Down
Loading

0 comments on commit 7d51f23

Please sign in to comment.