Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

L7 Irish: convert to IntersectionDataset #2034

Merged
merged 7 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion tests/datasets/test_l7irish.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import shutil
from pathlib import Path
from typing import cast

import matplotlib.pyplot as plt
import pytest
Expand Down Expand Up @@ -65,7 +66,9 @@ def test_plot(self, dataset: L7Irish) -> None:
plt.close()

def test_already_extracted(self, dataset: L7Irish) -> None:
L7Irish(dataset.paths, download=True)
paths = cast(str, dataset.paths)
L7Irish(paths, download=True)
L7Irish([paths], download=True)

def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join('tests', 'data', 'l7irish', '*.tar.gz')
Expand Down
206 changes: 127 additions & 79 deletions torchgeo/datasets/l7irish.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,115 @@

import glob
import os
import re
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast

import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from rasterio.crs import CRS
from rtree.index import Index, Property
from torch import Tensor

from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import RasterDataset
from .utils import BoundingBox, download_url, extract_archive
from .geo import IntersectionDataset, RasterDataset
from .utils import BoundingBox, disambiguate_timestamp, download_url, extract_archive


class L7Irish(RasterDataset):
class L7IrishImage(RasterDataset):
    """Imagery half of the L7 Irish dataset.

    This class only declares class-level configuration; it overrides no
    methods of :class:`RasterDataset`.
    """

    is_image = True

    # Band names; ``rgb_bands`` lists red, green, blue in that order.
    all_bands = ['B10', 'B20', 'B30', 'B40', 'B50', 'B61', 'B62', 'B70', 'B80']
    rgb_bands = ['B30', 'B20', 'B10']

    # Naming scheme reference:
    # https://landsat.usgs.gov/cloud-validation/cca_irish_2015/L7_Irish_Cloud_Validation_Masks.xml
    filename_glob = 'L71*.TIF'
    # Verbose-mode pattern: 3-digit WRS path, 3-digit WRS row, an underscore
    # followed by the same row again (back-reference), then an 8-digit date.
    filename_regex = r"""
        ^L71
        (?P<wrs_path>\d{3})
        (?P<wrs_row>\d{3})
        _(?P=wrs_row)
        (?P<date>\d{8})
        \.TIF$
    """
    # strptime format of the ``date`` group above.
    date_format = '%Y%m%d'


class L7IrishMask(RasterDataset):
    """Masks from the L7 Irish dataset.

    Mask filenames do not contain an acquisition date. After the parent class
    has built its index, every entry is therefore re-inserted into a fresh
    index whose time range is taken from the image file found in the same
    directory as the mask.
    """

    # https://landsat.usgs.gov/cloud-validation/cca_irish_2015/L7_Irish_Cloud_Validation_Masks.xml
    filename_glob = 'L7_p*_r*_newmask2015.TIF'
    filename_regex = r"""
        ^L7
        _p(?P<wrs_path>\d+)
        _r(?P<wrs_row>\d+)
        _newmask2015\.TIF$
    """
    is_image = False
    classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud']

    # Lookup table from raw mask values (0, 64, 128, 192, 255) to the ordinals
    # 0-4; every other raw value maps to 0. The ordinals presumably index
    # ``classes`` -- confirm against the dataset documentation.
    ordinal_map = torch.zeros(256, dtype=torch.long)
    ordinal_map[64] = 1
    ordinal_map[128] = 2
    ordinal_map[192] = 3
    ordinal_map[255] = 4

    def __init__(
        self,
        paths: str | Iterable[str] = 'data',
        crs: CRS | None = None,
        res: float | None = None,
        bands: Sequence[str] | None = None,
        transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
        cache: bool = True,
    ) -> None:
        """Initialize a new L7IrishMask instance.

        Args:
            paths: one or more root directories to search or files to load
            crs: :term:`coordinate reference system (CRS)` to warp to
                (defaults to the CRS of the first file found)
            res: resolution of the dataset in units of CRS
                (defaults to the resolution of the first file found)
            bands: bands to return (defaults to all bands)
            transforms: a function/transform that takes an input sample
                and returns a transformed version
            cache: if True, cache file handle to speed up repeated sampling
        """
        super().__init__(paths, crs, res, bands, transforms, cache)

        # Mask filename does not include the date, grab it from the image filename
        image_regex = re.compile(L7IrishImage.filename_regex, re.VERBOSE)
        index = Index(interleaved=False, properties=Property(dimension=3))
        for hit in self.index.intersection(self.index.bounds, objects=True):
            dirname = os.path.dirname(cast(str, hit.object))
            pattern = os.path.join(dirname, L7IrishImage.filename_glob)
            minx, maxx, miny, maxy, mint, maxt = hit.bounds

            # A mask whose directory holds no image (or an image with an
            # unparsable name) keeps the time range assigned by the parent
            # class instead of aborting construction with an IndexError.
            image = next(glob.iglob(pattern), None)
            if image is not None and (
                match := image_regex.match(os.path.basename(image))
            ):
                date = match.group('date')
                mint, maxt = disambiguate_timestamp(date, L7IrishImage.date_format)

            index.insert(hit.id, (minx, maxx, miny, maxy, mint, maxt), hit.object)
        self.index = index

    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
        """Retrieve mask and metadata indexed by query.

        The sample is produced by the parent class; its ``'mask'`` entry is
        then translated through :attr:`ordinal_map`.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            sample of mask and metadata at that index

        Raises:
            IndexError: if query is not found in the index
        """
        sample = super().__getitem__(query)
        sample['mask'] = self.ordinal_map[sample['mask']]
        return sample


class L7Irish(IntersectionDataset):
"""L7 Irish dataset.

The `L7 Irish <https://landsat.usgs.gov/landsat-7-cloud-cover-assessment-validation-data>`__
Expand Down Expand Up @@ -72,30 +167,12 @@ class L7Irish(RasterDataset):
'tropical': 'd7931419c70f3520a17361d96f1a4810',
}

classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud']

# https://landsat.usgs.gov/cloud-validation/cca_irish_2015/L7_Irish_Cloud_Validation_Masks.xml
filename_glob = 'L71*.TIF'
filename_regex = r"""
^L71
(?P<wrs_path>\d{3})
(?P<wrs_row>\d{3})
_(?P=wrs_row)
(?P<date>\d{8})
\.TIF$
"""
date_format = '%Y%m%d'

separate_files = False
rgb_bands = ['B30', 'B20', 'B10']
all_bands = ['B10', 'B20', 'B30', 'B40', 'B50', 'B61', 'B62', 'B70', 'B80']

def __init__(
self,
paths: str | Iterable[str] = 'data',
crs: CRS | None = CRS.from_epsg(3857),
res: float | None = None,
bands: Sequence[str] = all_bands,
bands: Sequence[str] = L7IrishImage.all_bands,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
download: bool = False,
Expand Down Expand Up @@ -125,18 +202,37 @@ def __init__(

self._verify()

super().__init__(
paths, crs=crs, res=res, bands=bands, transforms=transforms, cache=cache
)
self.image = L7IrishImage(paths, crs, res, bands, transforms, cache)
self.mask = L7IrishMask(paths, crs, res, None, transforms, cache)

super().__init__(self.image, self.mask)

def _merge_dataset_indices(self) -> None:
    """Fill ``self.index`` with the entries that both datasets share exactly.

    Each entry of the first dataset is compared with every entry of the
    second dataset that intersects it. A merged entry is inserted only when
    the two bounding boxes compare equal; merged entries are numbered
    consecutively from zero.
    """
    first, second = self.datasets
    next_id = 0
    for outer in first.index.intersection(first.index.bounds, objects=True):
        candidates = second.index.intersection(outer.bounds, objects=True)
        for inner in candidates:
            outer_box = BoundingBox(*outer.bounds)
            inner_box = BoundingBox(*inner.bounds)
            if outer_box == inner_box:
                merged = outer_box & inner_box
                self.index.insert(next_id, tuple(merged))
                next_id += 1

def _verify(self) -> None:
"""Verify the integrity of the dataset."""
# Check if the extracted files already exist
if self.files:
if not isinstance(self.paths, str):
return

for classname in [L7IrishImage, L7IrishMask]:
pathname = os.path.join(self.paths, '**', classname.filename_glob)
if not glob.glob(pathname, recursive=True):
break
else:
return

# Check if the tar.gz files have already been downloaded
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, '*.tar.gz')
if glob.glob(pathname):
self._extract()
Expand Down Expand Up @@ -164,54 +260,6 @@ def _extract(self) -> None:
for tarfile in glob.iglob(pathname):
extract_archive(tarfile)

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
    """Retrieve image/mask and metadata indexed by query.

    The mask path for each image is derived from the image path: the first
    two underscore-separated parts of the parent directory name are placed
    into the mask filename template, which replaces the image basename.

    Args:
        query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

    Returns:
        sample of image, mask and metadata at that index

    Raises:
        IndexError: if query is not found in the index
    """
    matches = self.index.intersection(tuple(query), objects=True)
    image_paths = cast(list[str], [entry.object for entry in matches])

    if len(image_paths) == 0:
        raise IndexError(
            f'query: {query} not found in index with bounds: {self.bounds}'
        )

    image = self._merge_files(image_paths, query, self.band_indexes)

    mask_paths = []
    for image_path in image_paths:
        scene_dir = os.path.basename(os.path.dirname(image_path))
        path, row = scene_dir.split('_')[:2]
        mask_name = f'L7_{path}_{row}_newmask2015.TIF'
        mask_paths.append(
            image_path.replace(os.path.basename(image_path), mask_name)
        )

    mask = self._merge_files(mask_paths, query)

    # Rewrite raw mask values in place as ordinals.
    for raw_value, ordinal in {64: 1, 128: 2, 192: 3, 255: 4}.items():
        mask[mask == raw_value] = ordinal

    sample = {
        'crs': self.crs,
        'bbox': query,
        'image': image.float(),
        'mask': mask.long(),
    }

    if self.transforms is None:
        return sample
    return self.transforms(sample)

def plot(
self,
sample: dict[str, Tensor],
Expand All @@ -221,7 +269,7 @@ def plot(
"""Plot a sample from the dataset.

Args:
sample: a sample returned by :meth:`__getitem__`
sample: a sample returned by :meth:`RasterDataset.__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle

Expand All @@ -232,9 +280,9 @@ def plot(
RGBBandsMissingError: If *bands* does not include all RGB bands.
"""
rgb_indices = []
for band in self.rgb_bands:
if band in self.bands:
rgb_indices.append(self.bands.index(band))
for band in self.image.rgb_bands:
if band in self.image.bands:
rgb_indices.append(self.image.bands.index(band))
else:
raise RGBBandsMissingError()

Expand Down
Loading