Skip to content

Commit

Permalink
Resolve NCCM checksum error and add years args (#1870)
Browse files — browse the repository at this point in the history
* add new download links, years Args, and new test data

* remove download test file

* include all years by default

* sort year and verify
Loading branch information
yichiac authored Feb 12, 2024
1 parent 8af188c commit f3270ca
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 44 deletions.
Binary file removed tests/data/nccm/13090442.zip
Binary file not shown.
Binary file removed tests/data/nccm/13090442/CDL2017_clip.tif
Binary file not shown.
Binary file removed tests/data/nccm/13090442/CDL2018_clip1.tif
Binary file not shown.
Binary file removed tests/data/nccm/13090442/CDL2019_clip.tif
Binary file not shown.
Binary file added tests/data/nccm/CDL2017_clip.tif
Binary file not shown.
Binary file added tests/data/nccm/CDL2018_clip1.tif
Binary file not shown.
Binary file added tests/data/nccm/CDL2019_clip.tif
Binary file not shown.
17 changes: 5 additions & 12 deletions tests/data/nccm/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import hashlib
import os
import shutil

import numpy as np
import rasterio
Expand Down Expand Up @@ -48,20 +47,14 @@ def create_file(path: str, dtype: str):


if __name__ == "__main__":
dir = os.path.join(os.getcwd(), "13090442")

if os.path.exists(dir) and os.path.isdir(dir):
shutil.rmtree(dir)

dir = os.path.join(os.getcwd())
os.makedirs(dir, exist_ok=True)

for file in files:
create_file(os.path.join(dir, file), dtype="int8")

# Compress data
shutil.make_archive("13090442", "zip", ".", dir)

# Compute checksums
with open("13090442.zip", "rb") as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f"13090442.zip: {md5}")
for file in files:
with open(file, "rb") as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f"{file}: {md5}")
21 changes: 14 additions & 7 deletions tests/datasets/test_nccm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,19 @@ class TestNCCM:
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NCCM:
monkeypatch.setattr(torchgeo.datasets.nccm, "download_url", download_url)
url = os.path.join("tests", "data", "nccm", "13090442.zip")
md5s = {
2017: "ae5c390d0ffb8970d544b8a09142759f",
2018: "0d453bdb8ea5b7318c33e62513760580",
2019: "d4ab7ab00bb57623eafb6b27747e5639",
}
monkeypatch.setattr(NCCM, "md5s", md5s)
urls = {
2017: os.path.join("tests", "data", "nccm", "CDL2017_clip.tif"),
2018: os.path.join("tests", "data", "nccm", "CDL2018_clip1.tif"),
2019: os.path.join("tests", "data", "nccm", "CDL2019_clip.tif"),
}
monkeypatch.setattr(NCCM, "urls", urls)
transforms = nn.Identity()
monkeypatch.setattr(NCCM, "url", url)
root = str(tmp_path)
return NCCM(root, transforms=transforms, download=True, checksum=True)

Expand All @@ -48,11 +58,8 @@ def test_or(self, dataset: NCCM) -> None:
def test_already_extracted(self, dataset: NCCM) -> None:
NCCM(dataset.paths, download=True)

def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "nccm", "13090442.zip")
root = str(tmp_path)
shutil.copy(pathname, root)
NCCM(root)
def test_already_downloaded(self, dataset: NCCM) -> None:
NCCM(dataset.paths, download=True)

def test_plot(self, dataset: NCCM) -> None:
query = dataset.bounds
Expand Down
56 changes: 31 additions & 25 deletions torchgeo/datasets/nccm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

"""Northeastern China Crop Map Dataset."""

import glob
import os
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union

Expand All @@ -14,7 +12,7 @@
from rasterio.crs import CRS

from .geo import RasterDataset
from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive
from .utils import BoundingBox, DatasetNotFoundError, download_url


class NCCM(RasterDataset):
Expand Down Expand Up @@ -55,12 +53,24 @@ class NCCM(RasterDataset):

filename_regex = r"CDL(?P<year>\d{4})_clip"
filename_glob = "CDL*.*"
zipfile_glob = "13090442.zip"

date_format = "%Y"
is_image = False
url = "https://figshare.com/ndownloader/articles/13090442/versions/1"
md5 = "eae952f1b346d7e649d027e8139a76f5"
urls = {
2019: "https://figshare.com/ndownloader/files/25070540",
2018: "https://figshare.com/ndownloader/files/25070624",
2017: "https://figshare.com/ndownloader/files/25070582",
}
md5s = {
2019: "0d062bbd42e483fdc8239d22dba7020f",
2018: "b3bb4894478d10786aa798fb11693ec1",
2017: "d047fbe4a85341fa6248fd7e0badab6c",
}
fnames = {
2019: "CDL2019_clip.tif",
2018: "CDL2018_clip1.tif",
2017: "CDL2017_clip.tif",
}

cmap = {
0: (0, 255, 0, 255),
Expand All @@ -75,6 +85,7 @@ def __init__(
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
years: list[int] = [2019],
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
cache: bool = True,
download: bool = False,
Expand All @@ -88,6 +99,7 @@ def __init__(
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
(defaults to the resolution of the first file found)
years: list of years for which to use nccm layers
transforms: a function/transform that takes an input sample
and returns a transformed version
cache: if True, cache file handle to speed up repeated sampling
Expand All @@ -97,7 +109,12 @@ def __init__(
Raises:
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert set(years) <= self.md5s.keys(), (
"NCCM data product only exists for the following years: "
f"{list(self.md5s.keys())}."
)
self.paths = paths
self.years = years
self.download = download
self.checksum = checksum
self.ordinal_map = torch.full((max(self.cmap.keys()) + 1,), 4, dtype=self.dtype)
Expand Down Expand Up @@ -128,37 +145,26 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:

def _verify(self) -> None:
"""Verify the integrity of the dataset."""
# Check if the extracted files already exist
# Check if the files already exist
if self.files:
return

# Check if the zip file has already been downloaded
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, "**", self.zipfile_glob)
if glob.glob(pathname, recursive=True):
self._extract()
return

# Check if the user requested to download the dataset
if not self.download:
raise DatasetNotFoundError(self)

# Download the dataset
self._download()
self._extract()

def _download(self) -> None:
"""Download the dataset."""
filename = "13090442.zip"
download_url(
self.url, self.paths, filename, md5=self.md5 if self.checksum else None
)

def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, "**", self.zipfile_glob)
extract_archive(glob.glob(pathname, recursive=True)[0], self.paths)
for year in self.years:
download_url(
self.urls[year],
self.paths,
filename=self.fnames[year],
md5=self.md5s[year] if self.checksum else None,
)

def plot(
self,
Expand Down

0 comments on commit f3270ca

Please sign in to comment.