Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RioXarrayDataset for in-memory geographical xarray.DataArray objects #509

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ datasets =
radiant-mlhub>=0.2.1
# rarfile 3+ required for correct Rar file detection
rarfile>=3
rioxarray
# scipy 0.9+ required for scipy.io.wavfile.read
scipy>=0.9
# zipfile-deflate64 0.2+ required for extraction bugfix:
Expand Down
47 changes: 47 additions & 0 deletions tests/datasets/test_rioxarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np
import pytest
import torch
import xarray as xr
from rasterio.crs import CRS

from torchgeo.datasets import (
BoundingBox,
IntersectionDataset,
RioXarrayDataset,
UnionDataset,
)

pytest.importorskip("rioxarray")


class TestRioXarrayDataset:
@pytest.fixture(scope="class")
def dataset(self) -> RioXarrayDataset:
xr_dataarray = xr.DataArray(
data=np.random.randn(5, 3),
coords=dict(y=[5.6, 4.5, 3.4, 2.3, 1.2], x=[6.7, 7.8, 8.9]),
dims=["y", "x"],
)
xr_dataarray.rio.set_crs(input_crs="EPSG:3857")
return RioXarrayDataset(xr_dataarray=xr_dataarray)

def test_getitem(self, dataset: RioXarrayDataset) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["crs"], CRS)

def test_and(self, dataset: RioXarrayDataset) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)

def test_or(self, dataset: RioXarrayDataset) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)

def test_invalid_query(self, dataset: RioXarrayDataset) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
IndexError, match="query: .* not found in index with bounds:"
):
dataset[query]
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from .patternnet import PatternNet
from .potsdam import Potsdam2D
from .resisc45 import RESISC45
from .rioxarray import RioXarrayDataset
from .seco import SeasonalContrastS2
from .sen12ms import SEN12MS
from .sentinel import Sentinel, Sentinel2
Expand Down Expand Up @@ -132,6 +133,7 @@
"Landsat9",
"NAIP",
"OpenBuildings",
"RioXarrayDataset",
"Sentinel",
"Sentinel2",
# VisionDataset
Expand Down
97 changes: 97 additions & 0 deletions torchgeo/datasets/rioxarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""In-memory geographical xarray.DataArray."""

import sys
from typing import Any, Callable, Dict, Optional, cast

import torch
import xarray as xr
from rasterio.crs import CRS
from rtree.index import Index, Property

from .geo import GeoDataset
from .utils import BoundingBox


class RioXarrayDataset(GeoDataset):
"""Wrapper for geographical datasets stored as an xarray.DataArray.

Relies on rioxarray.
"""

def __init__(
self,
xr_dataarray: xr.DataArray,
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
) -> None:
"""Initialize a new Dataset instance.

Args:
xr_dataarray: n-dimensional xarray.DataArray
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of dataarray)
res: resolution of the dataset in units of CRS
(defaults to the resolution of the dataarray)
transforms: a function/transform that takes an input sample
and returns a transformed version

Raises:
FileNotFoundError: if no files are found in ``root``
"""
super().__init__(transforms)

self.xr_dataarray = xr_dataarray
self.transforms = transforms

# Create an R-tree to index the dataset
self.index = Index(interleaved=False, properties=Property(dimension=3))

# Populate the dataset index
if crs is None:
crs = xr_dataarray.rio.crs
if res is None:
res = xr_dataarray.rio.resolution()[0]

(minx, miny, maxx, maxy) = xr_dataarray.rio.bounds()
if hasattr(xr_dataarray, "time"):
mint = int(xr_dataarray.time.min().data)
maxt = int(xr_dataarray.time.max().data)
else:
mint = 0
maxt = sys.maxsize
coords = (minx, maxx, miny, maxy, mint, maxt)
self.index.insert(0, coords, xr_dataarray.name)

self._crs = cast(CRS, crs)
self.res = cast(float, res)

def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
"""Retrieve image/mask and metadata indexed by query.

Args:
query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

Returns:
sample of image/mask and metadata at that index

Raises:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
items = [hit.object for hit in hits]

if not items:
raise IndexError(
f"query: {query} not found in index with bounds: {self.bounds}"
)

image = self.xr_dataarray.rio.clip_box(
minx=query.minx, miny=query.miny, maxx=query.maxx, maxy=query.maxy
)
sample = {"image": torch.tensor(image.data), "crs": self.crs, "bbox": query}

if self.transforms is not None:
sample = self.transforms(sample)

return sample