Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve NCCM checksum error and add years args #1870

Merged
merged 6 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file removed tests/data/nccm/13090442.zip
Binary file not shown.
Binary file removed tests/data/nccm/13090442/CDL2017_clip.tif
Binary file not shown.
Binary file removed tests/data/nccm/13090442/CDL2018_clip1.tif
Binary file not shown.
Binary file removed tests/data/nccm/13090442/CDL2019_clip.tif
Binary file not shown.
Binary file added tests/data/nccm/CDL2017_clip.tif
Binary file not shown.
Binary file added tests/data/nccm/CDL2018_clip1.tif
Binary file not shown.
Binary file added tests/data/nccm/CDL2019_clip.tif
Binary file not shown.
17 changes: 5 additions & 12 deletions tests/data/nccm/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import hashlib
import os
import shutil

import numpy as np
import rasterio
Expand Down Expand Up @@ -48,20 +47,14 @@ def create_file(path: str, dtype: str):


if __name__ == "__main__":
dir = os.path.join(os.getcwd(), "13090442")

if os.path.exists(dir) and os.path.isdir(dir):
shutil.rmtree(dir)

dir = os.path.join(os.getcwd())
os.makedirs(dir, exist_ok=True)

for file in files:
create_file(os.path.join(dir, file), dtype="int8")

# Compress data
shutil.make_archive("13090442", "zip", ".", dir)

# Compute checksums
with open("13090442.zip", "rb") as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f"13090442.zip: {md5}")
for file in files:
with open(file, "rb") as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f"{file}: {md5}")
21 changes: 14 additions & 7 deletions tests/datasets/test_nccm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,19 @@ class TestNCCM:
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NCCM:
monkeypatch.setattr(torchgeo.datasets.nccm, "download_url", download_url)
url = os.path.join("tests", "data", "nccm", "13090442.zip")
md5s = {
2017: "ae5c390d0ffb8970d544b8a09142759f",
2018: "0d453bdb8ea5b7318c33e62513760580",
2019: "d4ab7ab00bb57623eafb6b27747e5639",
}
monkeypatch.setattr(NCCM, "md5s", md5s)
urls = {
2017: os.path.join("tests", "data", "nccm", "CDL2017_clip.tif"),
2018: os.path.join("tests", "data", "nccm", "CDL2018_clip1.tif"),
2019: os.path.join("tests", "data", "nccm", "CDL2019_clip.tif"),
}
monkeypatch.setattr(NCCM, "urls", urls)
transforms = nn.Identity()
monkeypatch.setattr(NCCM, "url", url)
root = str(tmp_path)
return NCCM(root, transforms=transforms, download=True, checksum=True)

Expand All @@ -48,11 +58,8 @@ def test_or(self, dataset: NCCM) -> None:
def test_already_extracted(self, dataset: NCCM) -> None:
NCCM(dataset.paths, download=True)

def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "nccm", "13090442.zip")
root = str(tmp_path)
shutil.copy(pathname, root)
NCCM(root)
def test_already_downloaded(self, dataset: NCCM) -> None:
NCCM(dataset.paths, download=True)

def test_plot(self, dataset: NCCM) -> None:
query = dataset.bounds
Expand Down
56 changes: 31 additions & 25 deletions torchgeo/datasets/nccm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

"""Northeastern China Crop Map Dataset."""

import glob
import os
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union

Expand All @@ -14,7 +12,7 @@
from rasterio.crs import CRS

from .geo import RasterDataset
from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive
from .utils import BoundingBox, DatasetNotFoundError, download_url


class NCCM(RasterDataset):
Expand Down Expand Up @@ -55,12 +53,24 @@ class NCCM(RasterDataset):

filename_regex = r"CDL(?P<year>\d{4})_clip"
filename_glob = "CDL*.*"
zipfile_glob = "13090442.zip"

date_format = "%Y"
is_image = False
url = "https://figshare.com/ndownloader/articles/13090442/versions/1"
md5 = "eae952f1b346d7e649d027e8139a76f5"
urls = {
2019: "https://figshare.com/ndownloader/files/25070540",
2018: "https://figshare.com/ndownloader/files/25070624",
2017: "https://figshare.com/ndownloader/files/25070582",
}
md5s = {
2019: "0d062bbd42e483fdc8239d22dba7020f",
2018: "b3bb4894478d10786aa798fb11693ec1",
2017: "d047fbe4a85341fa6248fd7e0badab6c",
}
fnames = {
2019: "CDL2019_clip.tif",
2018: "CDL2018_clip1.tif",
2017: "CDL2017_clip.tif",
}

cmap = {
0: (0, 255, 0, 255),
Expand All @@ -75,6 +85,7 @@ def __init__(
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
years: list[int] = [2019],
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
cache: bool = True,
download: bool = False,
Expand All @@ -88,6 +99,7 @@ def __init__(
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
(defaults to the resolution of the first file found)
years: list of years for which to use nccm layers
transforms: a function/transform that takes an input sample
and returns a transformed version
cache: if True, cache file handle to speed up repeated sampling
Expand All @@ -97,7 +109,12 @@ def __init__(
Raises:
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert set(years) <= self.md5s.keys(), (
"NCCM data product only exists for the following years: "
f"{list(self.md5s.keys())}."
)
self.paths = paths
self.years = years
self.download = download
self.checksum = checksum
self.ordinal_map = torch.full((max(self.cmap.keys()) + 1,), 4, dtype=self.dtype)
Expand Down Expand Up @@ -128,37 +145,26 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:

def _verify(self) -> None:
"""Verify the integrity of the dataset."""
# Check if the extracted files already exist
# Check if the files already exist
if self.files:
return

# Check if the zip file has already been downloaded
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, "**", self.zipfile_glob)
if glob.glob(pathname, recursive=True):
self._extract()
return

# Check if the user requested to download the dataset
if not self.download:
raise DatasetNotFoundError(self)

# Download the dataset
self._download()
self._extract()

def _download(self) -> None:
"""Download the dataset."""
filename = "13090442.zip"
download_url(
self.url, self.paths, filename, md5=self.md5 if self.checksum else None
)

def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, "**", self.zipfile_glob)
extract_archive(glob.glob(pathname, recursive=True)[0], self.paths)
for year in self.years:
download_url(
self.urls[year],
self.paths,
filename=self.fnames[year],
md5=self.md5s[year] if self.checksum else None,
)

def plot(
self,
Expand Down
Loading