diff --git a/tests/datasets/test_l8biome.py b/tests/datasets/test_l8biome.py
index 96c3209ca0a..99b09501450 100644
--- a/tests/datasets/test_l8biome.py
+++ b/tests/datasets/test_l8biome.py
@@ -5,6 +5,7 @@
 import os
 import shutil
 from pathlib import Path
+from typing import cast
 
 import matplotlib.pyplot as plt
 import pytest
@@ -65,7 +66,9 @@ def test_plot(self, dataset: L8Biome) -> None:
         plt.close()
 
     def test_already_extracted(self, dataset: L8Biome) -> None:
-        L8Biome(dataset.paths, download=True)
+        paths = cast(str, dataset.paths)
+        L8Biome(paths, download=True)
+        L8Biome([paths], download=True)
 
     def test_already_downloaded(self, tmp_path: Path) -> None:
         pathname = os.path.join('tests', 'data', 'l8biome', '*.tar.gz')
diff --git a/torchgeo/datasets/l8biome.py b/torchgeo/datasets/l8biome.py
index 05e2ef0b58d..c200b5c63bc 100644
--- a/torchgeo/datasets/l8biome.py
+++ b/torchgeo/datasets/l8biome.py
@@ -6,19 +6,79 @@
 import glob
 import os
 from collections.abc import Callable, Iterable, Sequence
-from typing import Any, cast
+from typing import Any
 
 import matplotlib.pyplot as plt
+import torch
 from matplotlib.figure import Figure
 from rasterio.crs import CRS
 from torch import Tensor
 
 from .errors import DatasetNotFoundError, RGBBandsMissingError
-from .geo import RasterDataset
+from .geo import IntersectionDataset, RasterDataset
 from .utils import BoundingBox, download_url, extract_archive
 
 
-class L8Biome(RasterDataset):
+class L8BiomeImage(RasterDataset):
+    """Images from the L8 Biome dataset."""
+
+    # https://gisgeography.com/landsat-file-naming-convention/
+    filename_glob = 'LC8*.TIF'
+    filename_regex = r"""
+        ^LC8
+        (?P<wrs_path>\d{3})
+        (?P<wrs_row>\d{3})
+        (?P<date>\d{7})
+        (?P<gsi>[A-Z]{3})
+        (?P<version>\d{2})
+        \.TIF$
+    """
+    date_format = '%Y%j'
+    is_image = True
+    rgb_bands = ['B4', 'B3', 'B2']
+    all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11']
+
+
+class L8BiomeMask(RasterDataset):
+    """Masks from the L8 Biome dataset."""
+
+    # https://gisgeography.com/landsat-file-naming-convention/
+    filename_glob = 'LC8*_fixedmask.TIF'
+    filename_regex = r"""
+        ^LC8
+        (?P<wrs_path>\d{3})
+        (?P<wrs_row>\d{3})
+        (?P<date>\d{7})
+        (?P<gsi>[A-Z]{3})
+        (?P<version>\d{2})
+        _fixedmask
+        \.TIF$
+    """
+    date_format = '%Y%j'
+    is_image = False
+    classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud']
+    ordinal_map = torch.zeros(256, dtype=torch.long)
+    ordinal_map[64] = 1
+    ordinal_map[128] = 2
+    ordinal_map[192] = 3
+    ordinal_map[255] = 4
+
+    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
+        """Retrieve image/mask and metadata indexed by query.
+
+        Args:
+            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
+        Returns:
+            sample of image, mask and metadata at that index
+        Raises:
+            IndexError: if query is not found in the index
+        """
+        sample = super().__getitem__(query)
+        sample['mask'] = self.ordinal_map[sample['mask']]
+        return sample
+
+
+class L8Biome(IntersectionDataset):
     """L8 Biome dataset.
 
     The `L8 Biome <https://landsat.usgs.gov/landsat-8-cloud-cover-assessment-validation-data>`__
@@ -70,31 +130,12 @@ class L8Biome(RasterDataset):
         'wetlands': '1f86cc354631ca9a50ce54b7cab3f557',
     }
 
-    classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud']
-
-    # https://gisgeography.com/landsat-file-naming-convention/
-    filename_glob = 'LC8*.TIF'
-    filename_regex = r"""
-        ^LC8
-        (?P<wrs_path>\d{3})
-        (?P<wrs_row>\d{3})
-        (?P<date>\d{7})
-        (?P<gsi>[A-Z]{3})
-        (?P<version>\d{2})
-        \.TIF$
-    """
-    date_format = '%Y%j'
-
-    separate_files = False
-    rgb_bands = ['B4', 'B3', 'B2']
-    all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11']
-
     def __init__(
         self,
         paths: str | Iterable[str],
         crs: CRS | None = CRS.from_epsg(3857),
         res: float | None = None,
-        bands: Sequence[str] = all_bands,
+        bands: Sequence[str] = L8BiomeImage.all_bands,
         transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
         cache: bool = True,
         download: bool = False,
@@ -124,18 +165,25 @@ def __init__(
 
         self._verify()
 
-        super().__init__(
-            paths, crs=crs, res=res, bands=bands, transforms=transforms, cache=cache
-        )
+        self.image = L8BiomeImage(paths, crs, res, bands, transforms, cache)
+        self.mask = L8BiomeMask(paths, crs, res, None, transforms, cache)
+
+        super().__init__(self.image, self.mask)
 
     def _verify(self) -> None:
         """Verify the integrity of the dataset."""
         # Check if the extracted files already exist
-        if self.files:
+        if not isinstance(self.paths, str):
+            return
+
+        for classname in [L8BiomeImage, L8BiomeMask]:
+            pathname = os.path.join(self.paths, '**', classname.filename_glob)
+            if not glob.glob(pathname, recursive=True):
+                break
+        else:
             return
 
         # Check if the tar.gz files have already been downloaded
-        assert isinstance(self.paths, str)
         pathname = os.path.join(self.paths, '*.tar.gz')
         if glob.glob(pathname):
             self._extract()
@@ -163,51 +211,6 @@ def _extract(self) -> None:
         for tarfile in glob.iglob(pathname):
             extract_archive(tarfile)
 
-    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
-        """Retrieve image/mask and metadata indexed by query.
-
-        Args:
-            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
-
-        Returns:
-            sample of image, mask and metadata at that index
-
-        Raises:
-            IndexError: if query is not found in the index
-        """
-        hits = self.index.intersection(tuple(query), objects=True)
-        filepaths = cast(list[str], [hit.object for hit in hits])
-
-        if not filepaths:
-            raise IndexError(
-                f'query: {query} not found in index with bounds: {self.bounds}'
-            )
-
-        image = self._merge_files(filepaths, query, self.band_indexes)
-
-        mask_filepaths = []
-        for filepath in filepaths:
-            mask_filepath = filepath.replace('.TIF', '_fixedmask.TIF')
-            mask_filepaths.append(mask_filepath)
-
-        mask = self._merge_files(mask_filepaths, query)
-        mask_mapping = {64: 1, 128: 2, 192: 3, 255: 4}
-
-        for k, v in mask_mapping.items():
-            mask[mask == k] = v
-
-        sample = {
-            'crs': self.crs,
-            'bbox': query,
-            'image': image.float(),
-            'mask': mask.long(),
-        }
-
-        if self.transforms is not None:
-            sample = self.transforms(sample)
-
-        return sample
-
     def plot(
         self,
         sample: dict[str, Tensor],
@@ -217,7 +220,7 @@ def plot(
         """Plot a sample from the dataset.
 
         Args:
-            sample: a sample returned by :meth:`__getitem__`
+            sample: a sample returned by :meth:`RasterDataset.__getitem__`
             show_titles: flag indicating whether to show titles above each panel
             suptitle: optional string to use as a suptitle
 
@@ -228,9 +231,9 @@ def plot(
             RGBBandsMissingError: If *bands* does not include all RGB bands.
         """
         rgb_indices = []
-        for band in self.rgb_bands:
-            if band in self.bands:
-                rgb_indices.append(self.bands.index(band))
+        for band in self.image.rgb_bands:
+            if band in self.image.bands:
+                rgb_indices.append(self.image.bands.index(band))
             else:
                 raise RGBBandsMissingError()
 