Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

L8 Biome: convert to IntersectionDataset #2058

Merged
merged 1 commit into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion tests/datasets/test_l8biome.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import shutil
from pathlib import Path
from typing import cast

import matplotlib.pyplot as plt
import pytest
Expand Down Expand Up @@ -65,7 +66,9 @@ def test_plot(self, dataset: L8Biome) -> None:
plt.close()

def test_already_extracted(self, dataset: L8Biome) -> None:
L8Biome(dataset.paths, download=True)
paths = cast(str, dataset.paths)
L8Biome(paths, download=True)
L8Biome([paths], download=True)

def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join('tests', 'data', 'l8biome', '*.tar.gz')
Expand Down
157 changes: 80 additions & 77 deletions torchgeo/datasets/l8biome.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,79 @@
import glob
import os
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast
from typing import Any

import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor

from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import RasterDataset
from .geo import IntersectionDataset, RasterDataset
from .utils import BoundingBox, download_url, extract_archive


class L8Biome(RasterDataset):
class L8BiomeImage(RasterDataset):
"""Images from the L8 Biome dataset."""

# https://gisgeography.com/landsat-file-naming-convention/
filename_glob = 'LC8*.TIF'
filename_regex = r"""
^LC8
(?P<wrs_path>\d{3})
(?P<wrs_row>\d{3})
(?P<date>\d{7})
(?P<gsi>[A-Z]{3})
(?P<version>\d{2})
\.TIF$
"""
date_format = '%Y%j'
is_image = True
rgb_bands = ['B4', 'B3', 'B2']
all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11']


class L8BiomeMask(RasterDataset):
"""Masks from the L8 Biome dataset."""

# https://gisgeography.com/landsat-file-naming-convention/
filename_glob = 'LC8*_fixedmask.TIF'
filename_regex = r"""
^LC8
(?P<wrs_path>\d{3})
(?P<wrs_row>\d{3})
(?P<date>\d{7})
(?P<gsi>[A-Z]{3})
(?P<version>\d{2})
_fixedmask
\.TIF$
"""
date_format = '%Y%j'
is_image = False
classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud']
ordinal_map = torch.zeros(256, dtype=torch.long)
ordinal_map[64] = 1
ordinal_map[128] = 2
ordinal_map[192] = 3
ordinal_map[255] = 4

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
"""Retrieve image/mask and metadata indexed by query.

Args:
query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
Returns:
sample of image, mask and metadata at that index
Raises:
IndexError: if query is not found in the index
"""
sample = super().__getitem__(query)
sample['mask'] = self.ordinal_map[sample['mask']]
return sample


class L8Biome(IntersectionDataset):
"""L8 Biome dataset.

The `L8 Biome <https://landsat.usgs.gov/landsat-8-cloud-cover-assessment-validation-data>`__
Expand Down Expand Up @@ -70,31 +130,12 @@ class L8Biome(RasterDataset):
'wetlands': '1f86cc354631ca9a50ce54b7cab3f557',
}

classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud']

# https://gisgeography.com/landsat-file-naming-convention/
filename_glob = 'LC8*.TIF'
filename_regex = r"""
^LC8
(?P<wrs_path>\d{3})
(?P<wrs_row>\d{3})
(?P<date>\d{7})
(?P<gsi>[A-Z]{3})
(?P<version>\d{2})
\.TIF$
"""
date_format = '%Y%j'

separate_files = False
rgb_bands = ['B4', 'B3', 'B2']
all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11']

def __init__(
self,
paths: str | Iterable[str],
crs: CRS | None = CRS.from_epsg(3857),
res: float | None = None,
bands: Sequence[str] = all_bands,
bands: Sequence[str] = L8BiomeImage.all_bands,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
download: bool = False,
Expand Down Expand Up @@ -124,18 +165,25 @@ def __init__(

self._verify()

super().__init__(
paths, crs=crs, res=res, bands=bands, transforms=transforms, cache=cache
)
self.image = L8BiomeImage(paths, crs, res, bands, transforms, cache)
self.mask = L8BiomeMask(paths, crs, res, None, transforms, cache)

super().__init__(self.image, self.mask)

def _verify(self) -> None:
"""Verify the integrity of the dataset."""
# Check if the extracted files already exist
if self.files:
if not isinstance(self.paths, str):
return

for classname in [L8BiomeImage, L8BiomeMask]:
pathname = os.path.join(self.paths, '**', classname.filename_glob)
if not glob.glob(pathname, recursive=True):
break
else:
return

# Check if the tar.gz files have already been downloaded
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, '*.tar.gz')
if glob.glob(pathname):
self._extract()
Expand Down Expand Up @@ -163,51 +211,6 @@ def _extract(self) -> None:
for tarfile in glob.iglob(pathname):
extract_archive(tarfile)

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
"""Retrieve image/mask and metadata indexed by query.

Args:
query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

Returns:
sample of image, mask and metadata at that index

Raises:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
filepaths = cast(list[str], [hit.object for hit in hits])

if not filepaths:
raise IndexError(
f'query: {query} not found in index with bounds: {self.bounds}'
)

image = self._merge_files(filepaths, query, self.band_indexes)

mask_filepaths = []
for filepath in filepaths:
mask_filepath = filepath.replace('.TIF', '_fixedmask.TIF')
mask_filepaths.append(mask_filepath)

mask = self._merge_files(mask_filepaths, query)
mask_mapping = {64: 1, 128: 2, 192: 3, 255: 4}

for k, v in mask_mapping.items():
mask[mask == k] = v

sample = {
'crs': self.crs,
'bbox': query,
'image': image.float(),
'mask': mask.long(),
}

if self.transforms is not None:
sample = self.transforms(sample)

return sample

def plot(
self,
sample: dict[str, Tensor],
Expand All @@ -217,7 +220,7 @@ def plot(
"""Plot a sample from the dataset.

Args:
sample: a sample returned by :meth:`__getitem__`
sample: a sample returned by :meth:`RasterDataset.__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle

Expand All @@ -228,9 +231,9 @@ def plot(
RGBBandsMissingError: If *bands* does not include all RGB bands.
"""
rgb_indices = []
for band in self.rgb_bands:
if band in self.bands:
rgb_indices.append(self.bands.index(band))
for band in self.image.rgb_bands:
if band in self.image.bands:
rgb_indices.append(self.image.bands.index(band))
else:
raise RGBBandsMissingError()

Expand Down
Loading