diff --git a/tests/datasets/test_l7irish.py b/tests/datasets/test_l7irish.py
index 5ae714b7645..f0600d74168 100644
--- a/tests/datasets/test_l7irish.py
+++ b/tests/datasets/test_l7irish.py
@@ -5,6 +5,7 @@
 import os
 import shutil
 from pathlib import Path
+from typing import cast
 
 import matplotlib.pyplot as plt
 import pytest
@@ -65,7 +66,9 @@ def test_plot(self, dataset: L7Irish) -> None:
         plt.close()
 
     def test_already_extracted(self, dataset: L7Irish) -> None:
-        L7Irish(dataset.paths, download=True)
+        paths = cast(str, dataset.paths)
+        L7Irish(paths, download=True)
+        L7Irish([paths], download=True)
 
     def test_already_downloaded(self, tmp_path: Path) -> None:
         pathname = os.path.join('tests', 'data', 'l7irish', '*.tar.gz')
diff --git a/torchgeo/datasets/l7irish.py b/torchgeo/datasets/l7irish.py
index e42a20dc388..7153738b391 100644
--- a/torchgeo/datasets/l7irish.py
+++ b/torchgeo/datasets/l7irish.py
@@ -5,20 +5,115 @@
 
 import glob
 import os
+import re
 from collections.abc import Callable, Iterable, Sequence
 from typing import Any, cast
 
 import matplotlib.pyplot as plt
+import torch
 from matplotlib.figure import Figure
 from rasterio.crs import CRS
+from rtree.index import Index, Property
 from torch import Tensor
 
 from .errors import DatasetNotFoundError, RGBBandsMissingError
-from .geo import RasterDataset
-from .utils import BoundingBox, download_url, extract_archive
+from .geo import IntersectionDataset, RasterDataset
+from .utils import BoundingBox, disambiguate_timestamp, download_url, extract_archive
 
 
-class L7Irish(RasterDataset):
+class L7IrishImage(RasterDataset):
+    """Images from the L7 Irish dataset."""
+
+    # https://landsat.usgs.gov/cloud-validation/cca_irish_2015/L7_Irish_Cloud_Validation_Masks.xml
+    filename_glob = 'L71*.TIF'
+    filename_regex = r"""
+        ^L71
+        (?P<wrs_path>\d{3})
+        (?P<wrs_row>\d{3})
+        _(?P=wrs_row)
+        (?P<date>\d{8})
+        \.TIF$
+    """
+    date_format = '%Y%m%d'
+    is_image = True
+    rgb_bands = ['B30', 'B20', 'B10']
+    all_bands = ['B10', 'B20', 'B30', 'B40', 'B50', 'B61', 'B62', 'B70', 'B80']
+
+
+class L7IrishMask(RasterDataset):
+    """Masks from the L7 Irish dataset."""
+
+    # https://landsat.usgs.gov/cloud-validation/cca_irish_2015/L7_Irish_Cloud_Validation_Masks.xml
+    filename_glob = 'L7_p*_r*_newmask2015.TIF'
+    filename_regex = r"""
+        ^L7
+        _p(?P<wrs_path>\d+)
+        _r(?P<wrs_row>\d+)
+        _newmask2015\.TIF$
+    """
+    is_image = False
+    classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud']
+    ordinal_map = torch.zeros(256, dtype=torch.long)
+    ordinal_map[64] = 1
+    ordinal_map[128] = 2
+    ordinal_map[192] = 3
+    ordinal_map[255] = 4
+
+    def __init__(
+        self,
+        paths: str | Iterable[str] = 'data',
+        crs: CRS | None = None,
+        res: float | None = None,
+        bands: Sequence[str] | None = None,
+        transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+        cache: bool = True,
+    ) -> None:
+        """Initialize a new L7IrishMask instance.
+
+        Args:
+            paths: one or more root directories to search or files to load
+            crs: :term:`coordinate reference system (CRS)` to warp to
+                (defaults to the CRS of the first file found)
+            res: resolution of the dataset in units of CRS
+                (defaults to the resolution of the first file found)
+            bands: bands to return (defaults to all bands)
+            transforms: a function/transform that takes an input sample
+                and returns a transformed version
+            cache: if True, cache file handle to speed up repeated sampling
+        """
+        super().__init__(paths, crs, res, bands, transforms, cache)
+
+        # Mask filename does not include the date, grab it from the image filename
+        filename_regex = re.compile(L7IrishImage.filename_regex, re.VERBOSE)
+        index = Index(interleaved=False, properties=Property(dimension=3))
+        for hit in self.index.intersection(self.index.bounds, objects=True):
+            dirname = os.path.dirname(cast(str, hit.object))
+            image = glob.glob(os.path.join(dirname, L7IrishImage.filename_glob))[0]
+            minx, maxx, miny, maxy, mint, maxt = hit.bounds
+            if match := re.match(filename_regex, os.path.basename(image)):
+                date = match.group('date')
+                mint, maxt = disambiguate_timestamp(date, L7IrishImage.date_format)
+            index.insert(hit.id, (minx, maxx, miny, maxy, mint, maxt), hit.object)
+        self.index = index
+
+    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
+        """Retrieve image/mask and metadata indexed by query.
+
+        Args:
+            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
+
+        Returns:
+            sample of image, mask and metadata at that index
+
+        Raises:
+            IndexError: if query is not found in the index
+        """
+        sample = super().__getitem__(query)
+        sample['mask'] = self.ordinal_map[sample['mask']]
+        return sample
+
+
+class L7Irish(IntersectionDataset):
     """L7 Irish dataset.
 
     The `L7 Irish `__
@@ -72,30 +167,12 @@ class L7Irish(RasterDataset):
         'tropical': 'd7931419c70f3520a17361d96f1a4810',
     }
 
-    classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud']
-
-    # https://landsat.usgs.gov/cloud-validation/cca_irish_2015/L7_Irish_Cloud_Validation_Masks.xml
-    filename_glob = 'L71*.TIF'
-    filename_regex = r"""
-        ^L71
-        (?P<wrs_path>\d{3})
-        (?P<wrs_row>\d{3})
-        _(?P=wrs_row)
-        (?P<date>\d{8})
-        \.TIF$
-    """
-    date_format = '%Y%m%d'
-
-    separate_files = False
-    rgb_bands = ['B30', 'B20', 'B10']
-    all_bands = ['B10', 'B20', 'B30', 'B40', 'B50', 'B61', 'B62', 'B70', 'B80']
-
     def __init__(
         self,
         paths: str | Iterable[str] = 'data',
         crs: CRS | None = CRS.from_epsg(3857),
         res: float | None = None,
-        bands: Sequence[str] = all_bands,
+        bands: Sequence[str] = L7IrishImage.all_bands,
         transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
         cache: bool = True,
         download: bool = False,
@@ -125,18 +202,37 @@ def __init__(
 
         self._verify()
 
-        super().__init__(
-            paths, crs=crs, res=res, bands=bands, transforms=transforms, cache=cache
-        )
+        self.image = L7IrishImage(paths, crs, res, bands, transforms, cache)
+        self.mask = L7IrishMask(paths, crs, res, None, transforms, cache)
+
+        super().__init__(self.image, self.mask)
+
+    def _merge_dataset_indices(self) -> None:
+        """Create a new R-tree out of the individual indices from two datasets."""
+        i = 0
+        ds1, ds2 = self.datasets
+        for hit1 in ds1.index.intersection(ds1.index.bounds, objects=True):
+            for hit2 in ds2.index.intersection(hit1.bounds, objects=True):
+                box1 = BoundingBox(*hit1.bounds)
+                box2 = BoundingBox(*hit2.bounds)
+                if box1 == box2:
+                    self.index.insert(i, tuple(box1 & box2))
+                    i += 1
 
     def _verify(self) -> None:
         """Verify the integrity of the dataset."""
         # Check if the extracted files already exist
-        if self.files:
+        if not isinstance(self.paths, str):
+            return
+
+        for classname in [L7IrishImage, L7IrishMask]:
+            pathname = os.path.join(self.paths, '**', classname.filename_glob)
+            if not glob.glob(pathname, recursive=True):
+                break
+        else:
             return
 
         # Check if the tar.gz files have already been downloaded
-        assert isinstance(self.paths, str)
         pathname = os.path.join(self.paths, '*.tar.gz')
         if glob.glob(pathname):
             self._extract()
@@ -164,54 +260,6 @@ def _extract(self) -> None:
         for tarfile in glob.iglob(pathname):
             extract_archive(tarfile)
 
-    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
-        """Retrieve image/mask and metadata indexed by query.
-
-        Args:
-            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
-
-        Returns:
-            sample of image, mask and metadata at that index
-
-        Raises:
-            IndexError: if query is not found in the index
-        """
-        hits = self.index.intersection(tuple(query), objects=True)
-        filepaths = cast(list[str], [hit.object for hit in hits])
-
-        if not filepaths:
-            raise IndexError(
-                f'query: {query} not found in index with bounds: {self.bounds}'
-            )
-
-        image = self._merge_files(filepaths, query, self.band_indexes)
-
-        mask_filepaths = []
-        for filepath in filepaths:
-            path, row = os.path.basename(os.path.dirname(filepath)).split('_')[:2]
-            mask_filepath = filepath.replace(
-                os.path.basename(filepath), f'L7_{path}_{row}_newmask2015.TIF'
-            )
-            mask_filepaths.append(mask_filepath)
-
-        mask = self._merge_files(mask_filepaths, query)
-        mask_mapping = {64: 1, 128: 2, 192: 3, 255: 4}
-
-        for k, v in mask_mapping.items():
-            mask[mask == k] = v
-
-        sample = {
-            'crs': self.crs,
-            'bbox': query,
-            'image': image.float(),
-            'mask': mask.long(),
-        }
-
-        if self.transforms is not None:
-            sample = self.transforms(sample)
-
-        return sample
-
     def plot(
         self,
         sample: dict[str, Tensor],
@@ -221,7 +269,7 @@ def plot(
         """Plot a sample from the dataset.
 
         Args:
-            sample: a sample returned by :meth:`__getitem__`
+            sample: a sample returned by :meth:`RasterDataset.__getitem__`
             show_titles: flag indicating whether to show titles above each panel
             suptitle: optional string to use as a suptitle
 
@@ -232,9 +280,9 @@ def plot(
             RGBBandsMissingError: If *bands* does not include all RGB bands.
         """
         rgb_indices = []
-        for band in self.rgb_bands:
-            if band in self.bands:
-                rgb_indices.append(self.bands.index(band))
+        for band in self.image.rgb_bands:
+            if band in self.image.bands:
+                rgb_indices.append(self.image.bands.index(band))
             else:
                 raise RGBBandsMissingError()
 