diff --git a/tests/data/nccm/13090442.zip b/tests/data/nccm/13090442.zip
deleted file mode 100644
index 19d0792078a..00000000000
Binary files a/tests/data/nccm/13090442.zip and /dev/null differ
diff --git a/tests/data/nccm/13090442/CDL2017_clip.tif b/tests/data/nccm/13090442/CDL2017_clip.tif
deleted file mode 100644
index 8dce2bb82e9..00000000000
Binary files a/tests/data/nccm/13090442/CDL2017_clip.tif and /dev/null differ
diff --git a/tests/data/nccm/13090442/CDL2018_clip1.tif b/tests/data/nccm/13090442/CDL2018_clip1.tif
deleted file mode 100644
index 531cd5f4f1f..00000000000
Binary files a/tests/data/nccm/13090442/CDL2018_clip1.tif and /dev/null differ
diff --git a/tests/data/nccm/13090442/CDL2019_clip.tif b/tests/data/nccm/13090442/CDL2019_clip.tif
deleted file mode 100644
index 67be3087ed7..00000000000
Binary files a/tests/data/nccm/13090442/CDL2019_clip.tif and /dev/null differ
diff --git a/tests/data/nccm/CDL2017_clip.tif b/tests/data/nccm/CDL2017_clip.tif
new file mode 100644
index 00000000000..1040f7936c6
Binary files /dev/null and b/tests/data/nccm/CDL2017_clip.tif differ
diff --git a/tests/data/nccm/CDL2018_clip1.tif b/tests/data/nccm/CDL2018_clip1.tif
new file mode 100644
index 00000000000..3313fef10d1
Binary files /dev/null and b/tests/data/nccm/CDL2018_clip1.tif differ
diff --git a/tests/data/nccm/CDL2019_clip.tif b/tests/data/nccm/CDL2019_clip.tif
new file mode 100644
index 00000000000..9c4d1dcae44
Binary files /dev/null and b/tests/data/nccm/CDL2019_clip.tif differ
diff --git a/tests/data/nccm/data.py b/tests/data/nccm/data.py
index 6a98ca3a2d0..2956f147033 100644
--- a/tests/data/nccm/data.py
+++ b/tests/data/nccm/data.py
@@ -5,7 +5,6 @@
 
 import hashlib
 import os
-import shutil
 
 import numpy as np
 import rasterio
@@ -48,20 +47,14 @@ def create_file(path: str, dtype: str):
 
 
 if __name__ == "__main__":
-    dir = os.path.join(os.getcwd(), "13090442")
-
-    if os.path.exists(dir) and os.path.isdir(dir):
-        shutil.rmtree(dir)
-
+    dir = os.path.join(os.getcwd())
     os.makedirs(dir, exist_ok=True)
 
     for file in files:
         create_file(os.path.join(dir, file), dtype="int8")
 
-    # Compress data
-    shutil.make_archive("13090442", "zip", ".", dir)
-
     # Compute checksums
-    with open("13090442.zip", "rb") as f:
-        md5 = hashlib.md5(f.read()).hexdigest()
-        print(f"13090442.zip: {md5}")
+    for file in files:
+        with open(file, "rb") as f:
+            md5 = hashlib.md5(f.read()).hexdigest()
+            print(f"{file}: {md5}")
diff --git a/tests/datasets/test_nccm.py b/tests/datasets/test_nccm.py
index 6637da3e840..0d922d9d3d5 100644
--- a/tests/datasets/test_nccm.py
+++ b/tests/datasets/test_nccm.py
@@ -25,9 +25,19 @@ class TestNCCM:
     @pytest.fixture
     def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NCCM:
         monkeypatch.setattr(torchgeo.datasets.nccm, "download_url", download_url)
-        url = os.path.join("tests", "data", "nccm", "13090442.zip")
+        md5s = {
+            2017: "ae5c390d0ffb8970d544b8a09142759f",
+            2018: "0d453bdb8ea5b7318c33e62513760580",
+            2019: "d4ab7ab00bb57623eafb6b27747e5639",
+        }
+        monkeypatch.setattr(NCCM, "md5s", md5s)
+        urls = {
+            2017: os.path.join("tests", "data", "nccm", "CDL2017_clip.tif"),
+            2018: os.path.join("tests", "data", "nccm", "CDL2018_clip1.tif"),
+            2019: os.path.join("tests", "data", "nccm", "CDL2019_clip.tif"),
+        }
+        monkeypatch.setattr(NCCM, "urls", urls)
         transforms = nn.Identity()
-        monkeypatch.setattr(NCCM, "url", url)
         root = str(tmp_path)
         return NCCM(root, transforms=transforms, download=True, checksum=True)
 
@@ -48,11 +58,8 @@ def test_or(self, dataset: NCCM) -> None:
     def test_already_extracted(self, dataset: NCCM) -> None:
         NCCM(dataset.paths, download=True)
 
-    def test_already_downloaded(self, tmp_path: Path) -> None:
-        pathname = os.path.join("tests", "data", "nccm", "13090442.zip")
-        root = str(tmp_path)
-        shutil.copy(pathname, root)
-        NCCM(root)
+    def test_already_downloaded(self, dataset: NCCM) -> None:
+        NCCM(dataset.paths, download=True)
 
     def test_plot(self, dataset: NCCM) -> None:
         query = dataset.bounds
diff --git a/torchgeo/datasets/nccm.py b/torchgeo/datasets/nccm.py
index 3a43ddddcc5..38a0d3eee91 100644
--- a/torchgeo/datasets/nccm.py
+++ b/torchgeo/datasets/nccm.py
@@ -3,8 +3,6 @@
 
 """Northeastern China Crop Map Dataset."""
 
-import glob
-import os
 from collections.abc import Iterable
 from typing import Any, Callable, Optional, Union
 
@@ -14,7 +12,7 @@
 from rasterio.crs import CRS
 
 from .geo import RasterDataset
-from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive
+from .utils import BoundingBox, DatasetNotFoundError, download_url
 
 
 class NCCM(RasterDataset):
@@ -55,12 +53,24 @@ class NCCM(RasterDataset):
 
     filename_regex = r"CDL(?P<date>\d{4})_clip"
     filename_glob = "CDL*.*"
-    zipfile_glob = "13090442.zip"
 
     date_format = "%Y"
     is_image = False
-    url = "https://figshare.com/ndownloader/articles/13090442/versions/1"
-    md5 = "eae952f1b346d7e649d027e8139a76f5"
+    urls = {
+        2019: "https://figshare.com/ndownloader/files/25070540",
+        2018: "https://figshare.com/ndownloader/files/25070624",
+        2017: "https://figshare.com/ndownloader/files/25070582",
+    }
+    md5s = {
+        2019: "0d062bbd42e483fdc8239d22dba7020f",
+        2018: "b3bb4894478d10786aa798fb11693ec1",
+        2017: "d047fbe4a85341fa6248fd7e0badab6c",
+    }
+    fnames = {
+        2019: "CDL2019_clip.tif",
+        2018: "CDL2018_clip1.tif",
+        2017: "CDL2017_clip.tif",
+    }
 
     cmap = {
         0: (0, 255, 0, 255),
@@ -75,6 +85,7 @@ def __init__(
         paths: Union[str, Iterable[str]] = "data",
         crs: Optional[CRS] = None,
         res: Optional[float] = None,
+        years: list[int] = [2019],
         transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
         cache: bool = True,
         download: bool = False,
@@ -88,6 +99,7 @@ def __init__(
                 (defaults to the CRS of the first file found)
             res: resolution of the dataset in units of CRS
                 (defaults to the resolution of the first file found)
+            years: list of years for which to use nccm layers
             transforms: a function/transform that takes an input sample
                 and returns a transformed version
             cache: if True, cache file handle to speed up repeated sampling
@@ -97,7 +109,12 @@ def __init__(
         Raises:
             DatasetNotFoundError: If dataset is not found and *download* is False.
         """
+        assert set(years) <= self.md5s.keys(), (
+            "NCCM data product only exists for the following years: "
+            f"{list(self.md5s.keys())}."
+        )
         self.paths = paths
+        self.years = years
         self.download = download
         self.checksum = checksum
         self.ordinal_map = torch.full((max(self.cmap.keys()) + 1,), 4, dtype=self.dtype)
@@ -128,37 +145,26 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
 
     def _verify(self) -> None:
         """Verify the integrity of the dataset."""
-        # Check if the extracted files already exist
+        # Check if the files already exist
        if self.files:
            return
 
-        # Check if the zip file has already been downloaded
-        assert isinstance(self.paths, str)
-        pathname = os.path.join(self.paths, "**", self.zipfile_glob)
-        if glob.glob(pathname, recursive=True):
-            self._extract()
-            return
-
         # Check if the user requested to download the dataset
         if not self.download:
             raise DatasetNotFoundError(self)
 
         # Download the dataset
         self._download()
-        self._extract()
 
     def _download(self) -> None:
         """Download the dataset."""
-        filename = "13090442.zip"
-        download_url(
-            self.url, self.paths, filename, md5=self.md5 if self.checksum else None
-        )
-
-    def _extract(self) -> None:
-        """Extract the dataset."""
-        assert isinstance(self.paths, str)
-        pathname = os.path.join(self.paths, "**", self.zipfile_glob)
-        extract_archive(glob.glob(pathname, recursive=True)[0], self.paths)
+        for year in self.years:
+            download_url(
+                self.urls[year],
+                self.paths,
+                filename=self.fnames[year],
+                md5=self.md5s[year] if self.checksum else None,
+            )
 
     def plot(
         self,
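A minimal usage sketch of the dataset after this change, assuming the patch above is applied; the root path and the selected years below are illustrative values, not part of the patch:

# Illustrative only: each requested year is now downloaded as a standalone GeoTIFF
# (no zip archive is fetched or extracted).
from torchgeo.datasets import NCCM

ds = NCCM("data/nccm", years=[2017, 2019], download=True, checksum=True)
sample = ds[ds.bounds]  # query the full extent, as in the existing plot test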