Skip to content

Commit

Permalink
L8 Biome: convert to IntersectionDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed May 13, 2024
1 parent 94bd5c7 commit 30c4b5c
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 78 deletions.
5 changes: 4 additions & 1 deletion tests/datasets/test_l8biome.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import shutil
from pathlib import Path
from typing import cast

import matplotlib.pyplot as plt
import pytest
Expand Down Expand Up @@ -65,7 +66,9 @@ def test_plot(self, dataset: L8Biome) -> None:
plt.close()

def test_already_extracted(self, dataset: L8Biome) -> None:
    """Re-create the dataset with ``download=True`` from the fixture's paths.

    There are no assertions; the test passes when none of the constructor
    calls raises. The paths are supplied both as a plain string and as a
    single-element list.
    """
    # NOTE(review): in this diff view the call below looks like the
    # pre-change line (the file header reports 1 deletion) - confirm
    # whether it should still be present next to the cast-based calls.
    L8Biome(dataset.paths, download=True)
    # cast() only informs the static type checker; it performs no runtime
    # conversion. Presumably the fixture stores a single str root - confirm.
    paths = cast(str, dataset.paths)
    L8Biome(paths, download=True)
    L8Biome([paths], download=True)

def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join('tests', 'data', 'l8biome', '*.tar.gz')
Expand Down
157 changes: 80 additions & 77 deletions torchgeo/datasets/l8biome.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,79 @@
import glob
import os
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast
from typing import Any

import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor

from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import RasterDataset
from .geo import IntersectionDataset, RasterDataset
from .utils import BoundingBox, download_url, extract_archive


class L8Biome(RasterDataset):
class L8BiomeImage(RasterDataset):
    """Image rasters of the L8 Biome dataset.

    This class only declares how image files are discovered and which bands
    they carry; all loading behaviour comes from the parent class.
    """

    # Band names B1 through B11, and the subset used for RGB rendering
    # (listed in red, green, blue order).
    all_bands = [f'B{number}' for number in range(1, 12)]
    rgb_bands = ['B4', 'B3', 'B2']
    is_image = True

    # The 7-digit ``date`` group is parsed as four-digit year + day of year.
    date_format = '%Y%j'

    # https://gisgeography.com/landsat-file-naming-convention/
    filename_glob = 'LC8*.TIF'
    # The pattern text is kept exactly as written (flush-left) so the literal
    # is unchanged. Presumably the parent class compiles it in verbose mode,
    # making the line breaks insignificant - confirm.
    filename_regex = r"""
^LC8
(?P<wrs_path>\d{3})
(?P<wrs_row>\d{3})
(?P<date>\d{7})
(?P<gsi>[A-Z]{3})
(?P<version>\d{2})
\.TIF$
"""


class L8BiomeMask(RasterDataset):
    """Mask rasters of the L8 Biome dataset.

    Raw mask pixel values are translated to consecutive class ordinals
    through :attr:`ordinal_map` every time a sample is retrieved.
    """

    is_image = False
    classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud']

    # The 7-digit ``date`` group is parsed as four-digit year + day of year.
    date_format = '%Y%j'

    # https://gisgeography.com/landsat-file-naming-convention/
    filename_glob = 'LC8*_fixedmask.TIF'
    # The pattern text is kept exactly as written (flush-left) so the literal
    # is unchanged. Presumably the parent class compiles it in verbose mode,
    # making the line breaks insignificant - confirm.
    filename_regex = r"""
^LC8
(?P<wrs_path>\d{3})
(?P<wrs_row>\d{3})
(?P<date>\d{7})
(?P<gsi>[A-Z]{3})
(?P<version>\d{2})
_fixedmask
\.TIF$
"""

    # 256-entry lookup table: raw values 64, 128, 192 and 255 become ordinals
    # 1-4; every other raw value (including 0) maps to ordinal 0. The
    # ordinals appear to index into ``classes``.
    ordinal_map = torch.zeros(256, dtype=torch.long)
    ordinal_map[[64, 128, 192, 255]] = torch.tensor([1, 2, 3, 4])

    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
        """Retrieve the mask and metadata indexed by query.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            sample produced by the parent class, with ``'mask'`` remapped
            through :attr:`ordinal_map`

        Raises:
            IndexError: if query is not found in the index
        """
        result = super().__getitem__(query)
        raw_mask = result['mask']
        # Indexing the table with the raw mask translates every pixel at once.
        result['mask'] = self.ordinal_map[raw_mask]
        return result


class L8Biome(IntersectionDataset):
"""L8 Biome dataset.
The `L8 Biome <https://landsat.usgs.gov/landsat-8-cloud-cover-assessment-validation-data>`__
Expand Down Expand Up @@ -70,31 +130,12 @@ class L8Biome(RasterDataset):
'wetlands': '1f86cc354631ca9a50ce54b7cab3f557',
}

classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud']

# https://gisgeography.com/landsat-file-naming-convention/
filename_glob = 'LC8*.TIF'
filename_regex = r"""
^LC8
(?P<wrs_path>\d{3})
(?P<wrs_row>\d{3})
(?P<date>\d{7})
(?P<gsi>[A-Z]{3})
(?P<version>\d{2})
\.TIF$
"""
date_format = '%Y%j'

separate_files = False
rgb_bands = ['B4', 'B3', 'B2']
all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11']

def __init__(
self,
paths: str | Iterable[str],
crs: CRS | None = CRS.from_epsg(3857),
res: float | None = None,
bands: Sequence[str] = all_bands,
bands: Sequence[str] = L8BiomeImage.all_bands,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
download: bool = False,
Expand Down Expand Up @@ -124,18 +165,25 @@ def __init__(

self._verify()

super().__init__(
paths, crs=crs, res=res, bands=bands, transforms=transforms, cache=cache
)
self.image = L8BiomeImage(paths, crs, res, bands, transforms, cache)
self.mask = L8BiomeMask(paths, crs, res, None, transforms, cache)

super().__init__(self.image, self.mask)

def _verify(self) -> None:
"""Verify the integrity of the dataset."""
# Check if the extracted files already exist
if self.files:
if not isinstance(self.paths, str):
return

for classname in [L8BiomeImage, L8BiomeMask]:
pathname = os.path.join(self.paths, '**', classname.filename_glob)
if not glob.glob(pathname, recursive=True):
break
else:
return

# Check if the tar.gz files have already been downloaded
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, '*.tar.gz')
if glob.glob(pathname):
self._extract()
Expand Down Expand Up @@ -163,51 +211,6 @@ def _extract(self) -> None:
for tarfile in glob.iglob(pathname):
extract_archive(tarfile)

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
    """Retrieve image/mask and metadata indexed by query.

    Args:
        query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

    Returns:
        sample of image, mask and metadata at that index

    Raises:
        IndexError: if query is not found in the index
    """
    matches = self.index.intersection(tuple(query), objects=True)
    image_paths = cast(list[str], [match.object for match in matches])

    if not image_paths:
        raise IndexError(
            f'query: {query} not found in index with bounds: {self.bounds}'
        )

    image = self._merge_files(image_paths, query, self.band_indexes)

    # Each mask path is derived from its image path by swapping the
    # extension for the ``_fixedmask`` variant.
    mask_paths = [path.replace('.TIF', '_fixedmask.TIF') for path in image_paths]
    mask = self._merge_files(mask_paths, query)

    # Rewrite the raw mask values in place to the ordinals 1-4.
    for raw_value, ordinal in {64: 1, 128: 2, 192: 3, 255: 4}.items():
        mask[mask == raw_value] = ordinal

    sample = {
        'crs': self.crs,
        'bbox': query,
        'image': image.float(),
        'mask': mask.long(),
    }

    if self.transforms is None:
        return sample
    return self.transforms(sample)

def plot(
self,
sample: dict[str, Tensor],
Expand All @@ -217,7 +220,7 @@ def plot(
"""Plot a sample from the dataset.
Args:
sample: a sample returned by :meth:`__getitem__`
sample: a sample returned by :meth:`RasterDataset.__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle
Expand All @@ -228,9 +231,9 @@ def plot(
RGBBandsMissingError: If *bands* does not include all RGB bands.
"""
rgb_indices = []
for band in self.rgb_bands:
if band in self.bands:
rgb_indices.append(self.bands.index(band))
for band in self.image.rgb_bands:
if band in self.image.bands:
rgb_indices.append(self.image.bands.index(band))
else:
raise RGBBandsMissingError()

Expand Down

0 comments on commit 30c4b5c

Please sign in to comment.