Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve NCCM checksum error and add years args #1870

Merged
merged 6 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file removed tests/data/nccm/13090442.zip
Binary file not shown.
Binary file removed tests/data/nccm/13090442/CDL2017_clip.tif
Binary file not shown.
Binary file removed tests/data/nccm/13090442/CDL2018_clip1.tif
Binary file not shown.
Binary file removed tests/data/nccm/13090442/CDL2019_clip.tif
Binary file not shown.
Binary file added tests/data/nccm/CDL2017_clip.tif
Binary file not shown.
Binary file added tests/data/nccm/CDL2018_clip1.tif
Binary file not shown.
Binary file added tests/data/nccm/CDL2019_clip.tif
Binary file not shown.
17 changes: 5 additions & 12 deletions tests/data/nccm/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import hashlib
import os
import shutil

import numpy as np
import rasterio
Expand Down Expand Up @@ -48,20 +47,14 @@ def create_file(path: str, dtype: str):


if __name__ == "__main__":
dir = os.path.join(os.getcwd(), "13090442")

if os.path.exists(dir) and os.path.isdir(dir):
shutil.rmtree(dir)

dir = os.path.join(os.getcwd())
os.makedirs(dir, exist_ok=True)

for file in files:
create_file(os.path.join(dir, file), dtype="int8")

# Compress data
shutil.make_archive("13090442", "zip", ".", dir)

# Compute checksums
with open("13090442.zip", "rb") as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f"13090442.zip: {md5}")
for file in files:
with open(file, "rb") as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f"{file}: {md5}")
21 changes: 14 additions & 7 deletions tests/datasets/test_nccm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,19 @@ class TestNCCM:
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NCCM:
monkeypatch.setattr(torchgeo.datasets.nccm, "download_url", download_url)
url = os.path.join("tests", "data", "nccm", "13090442.zip")
md5s = {
2017: "ae5c390d0ffb8970d544b8a09142759f",
2018: "0d453bdb8ea5b7318c33e62513760580",
2019: "d4ab7ab00bb57623eafb6b27747e5639",
}
monkeypatch.setattr(NCCM, "md5s", md5s)
urls = {
2017: os.path.join("tests", "data", "nccm", "CDL2017_clip.tif"),
2018: os.path.join("tests", "data", "nccm", "CDL2018_clip1.tif"),
2019: os.path.join("tests", "data", "nccm", "CDL2019_clip.tif"),
}
monkeypatch.setattr(NCCM, "urls", urls)
transforms = nn.Identity()
monkeypatch.setattr(NCCM, "url", url)
root = str(tmp_path)
return NCCM(root, transforms=transforms, download=True, checksum=True)

Expand All @@ -48,11 +58,8 @@ def test_or(self, dataset: NCCM) -> None:
def test_already_extracted(self, dataset: NCCM) -> None:
NCCM(dataset.paths, download=True)

def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "nccm", "13090442.zip")
root = str(tmp_path)
shutil.copy(pathname, root)
NCCM(root)
def test_already_downloaded(self, dataset: NCCM) -> None:
NCCM(dataset.paths, download=True)

def test_plot(self, dataset: NCCM) -> None:
query = dataset.bounds
Expand Down
50 changes: 32 additions & 18 deletions torchgeo/datasets/nccm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from rasterio.crs import CRS

from .geo import RasterDataset
from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive
from .utils import BoundingBox, DatasetNotFoundError, download_url


class NCCM(RasterDataset):
Expand Down Expand Up @@ -55,12 +55,24 @@ class NCCM(RasterDataset):

filename_regex = r"CDL(?P<year>\d{4})_clip"
filename_glob = "CDL*.*"
zipfile_glob = "13090442.zip"

date_format = "%Y"
is_image = False
url = "https://figshare.com/ndownloader/articles/13090442/versions/1"
md5 = "eae952f1b346d7e649d027e8139a76f5"
urls = {
2017: "https://figshare.com/ndownloader/files/25070582",
yichiac marked this conversation as resolved.
Show resolved Hide resolved
2018: "https://figshare.com/ndownloader/files/25070624",
2019: "https://figshare.com/ndownloader/files/25070540",
}
md5s = {
2017: "d047fbe4a85341fa6248fd7e0badab6c",
2018: "b3bb4894478d10786aa798fb11693ec1",
2019: "0d062bbd42e483fdc8239d22dba7020f",
}
fnames = {
2017: "CDL2017_clip.tif",
2018: "CDL2018_clip1.tif",
2019: "CDL2019_clip.tif",
}

cmap = {
0: (0, 255, 0, 255),
Expand All @@ -75,6 +87,7 @@ def __init__(
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
years: list[int] = [2017, 2018, 2019],
yichiac marked this conversation as resolved.
Show resolved Hide resolved
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
cache: bool = True,
download: bool = False,
Expand All @@ -88,6 +101,7 @@ def __init__(
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
(defaults to the resolution of the first file found)
years: list of years for which to use NCCM layers (2017, 2018, and/or 2019)
transforms: a function/transform that takes an input sample
and returns a transformed version
cache: if True, cache file handle to speed up repeated sampling
Expand All @@ -97,7 +111,12 @@ def __init__(
Raises:
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert set(years) <= self.md5s.keys(), (
"NCCM data product only exists for the following years: "
f"{list(self.md5s.keys())}."
)
self.paths = paths
self.years = years
self.download = download
self.checksum = checksum
self.ordinal_map = torch.full((max(self.cmap.keys()) + 1,), 4, dtype=self.dtype)
Expand Down Expand Up @@ -132,11 +151,10 @@ def _verify(self) -> None:
if self.files:
return

# Check if the zip file has already been downloaded
# Check if the mask files have already been downloaded
yichiac marked this conversation as resolved.
Show resolved Hide resolved
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, "**", self.zipfile_glob)
pathname = os.path.join(self.paths, self.filename_glob)
if glob.glob(pathname, recursive=True):
self._extract()
return

# Check if the user requested to download the dataset
Expand All @@ -145,20 +163,16 @@ def _verify(self) -> None:

# Download the dataset
self._download()
self._extract()

def _download(self) -> None:
"""Download the dataset."""
filename = "13090442.zip"
download_url(
self.url, self.paths, filename, md5=self.md5 if self.checksum else None
)

def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, "**", self.zipfile_glob)
extract_archive(glob.glob(pathname, recursive=True)[0], self.paths)
for year in self.years:
download_url(
self.urls[year],
self.paths,
filename=self.fnames[year],
md5=self.md5s[year] if self.checksum else None,
)

def plot(
self,
Expand Down
Loading