"""Sentinel-1 on Planetary Computer."""

import os
import tempfile
from typing import Any

import planetary_computer
import pystac
import pystac_client
import requests
import shapely
from upath import UPath

from rslearn.config import QueryConfig, RasterLayerConfig
from rslearn.const import WGS84_PROJECTION
from rslearn.data_sources import DataSource, Item
from rslearn.data_sources.raster_source import is_raster_needed
from rslearn.data_sources.utils import match_candidate_items_to_window
from rslearn.log_utils import get_logger
from rslearn.tile_stores import TileStoreWithLayer
from rslearn.utils.geometry import STGeometry

logger = get_logger(__name__)


class Sentinel1(DataSource):
    """A data source for Sentinel-1 data on Microsoft Planetary Computer.

    This uses the radiometrically corrected data.

    See https://planetarycomputer.microsoft.com/dataset/sentinel-1-rtc.

    The PC_SDK_SUBSCRIPTION_KEY environment variable can be set but is not needed.
    """

    STAC_ENDPOINT = "https://planetarycomputer.microsoft.com/api/stac/v1"

    COLLECTION_NAME = "sentinel-1-rtc"

    def __init__(
        self,
        config: RasterLayerConfig,
        query: dict[str, Any] | None = None,
        sort_by: str | None = None,
        sort_ascending: bool = True,
        timeout: int = 10,
    ):
        """Initialize a new Sentinel1 instance.

        Args:
            config: the LayerConfig of the layer containing this data source.
            query: optional query argument to STAC searches.
            sort_by: sort by this property in the STAC items.
            sort_ascending: whether to sort ascending (or descending).
            timeout: timeout for API requests in seconds.
        """
        self.config = config
        self.query = query
        self.sort_by = sort_by
        self.sort_ascending = sort_ascending
        self.timeout = timeout

        # sign_inplace attaches short-lived SAS tokens to asset URLs so they
        # can be downloaded anonymously.
        self.client = pystac_client.Client.open(
            self.STAC_ENDPOINT, modifier=planetary_computer.sign_inplace
        )
        self.collection = self.client.get_collection(self.COLLECTION_NAME)

    @staticmethod
    def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Sentinel1":
        """Creates a new Sentinel1 instance from a configuration dictionary."""
        if config.data_source is None:
            raise ValueError("config.data_source is required")
        d = config.data_source.config_dict
        kwargs: dict[str, Any] = dict(
            config=config,
        )

        simple_optionals = ["query", "sort_by", "sort_ascending", "timeout"]
        for k in simple_optionals:
            if k in d:
                kwargs[k] = d[k]

        return Sentinel1(**kwargs)

    def _stac_item_to_item(self, stac_item: pystac.Item) -> Item:
        """Convert a pystac.Item into an rslearn Item in WGS84."""
        shp = shapely.geometry.shape(stac_item.geometry)

        # Get time range. Prefer start/end datetimes when both are present;
        # otherwise fall back to the single item datetime.
        metadata = stac_item.common_metadata
        if metadata.start_datetime is not None and metadata.end_datetime is not None:
            time_range = (
                metadata.start_datetime,
                metadata.end_datetime,
            )
        elif stac_item.datetime is not None:
            time_range = (stac_item.datetime, stac_item.datetime)
        else:
            raise ValueError(
                f"item {stac_item.id} unexpectedly missing start_datetime, end_datetime, and datetime"
            )

        geom = STGeometry(WGS84_PROJECTION, shp, time_range)
        return Item(stac_item.id, geom)

    def get_item_by_name(self, name: str) -> Item:
        """Gets an item by name.

        Args:
            name: the name of the item to get

        Returns:
            the item object

        Raises:
            ValueError: if the collection has no item with this name.
        """
        stac_item = self.collection.get_item(name)
        # pystac returns None for a missing item; raise a clear error instead
        # of letting _stac_item_to_item fail with an AttributeError.
        if stac_item is None:
            raise ValueError(
                f"no item {name} in collection {self.COLLECTION_NAME}"
            )
        return self._stac_item_to_item(stac_item)

    def get_items(
        self, geometries: list[STGeometry], query_config: QueryConfig
    ) -> list[list[list[Item]]]:
        """Get a list of items in the data source intersecting the given geometries.

        Args:
            geometries: the spatiotemporal geometries
            query_config: the query configuration

        Returns:
            List of groups of items that should be retrieved for each geometry.
        """
        groups = []
        for geometry in geometries:
            # Get potentially relevant items from the collection by performing
            # one search for each requested geometry.
            wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
            logger.debug("performing STAC search for geometry %s", wgs84_geometry)
            result = self.client.search(
                collections=[self.COLLECTION_NAME],
                intersects=shapely.to_geojson(wgs84_geometry.shp),
                datetime=wgs84_geometry.time_range,
                query=self.query,
            )
            stac_items = [item for item in result.item_collection()]
            logger.debug("STAC search yielded %d items", len(stac_items))

            if self.sort_by is not None:
                stac_items.sort(
                    key=lambda stac_item: stac_item.properties[self.sort_by],
                    reverse=not self.sort_ascending,
                )

            candidate_items = [
                self._stac_item_to_item(stac_item) for stac_item in stac_items
            ]
            cur_groups = match_candidate_items_to_window(
                geometry, candidate_items, query_config
            )
            groups.append(cur_groups)

        return groups

    def deserialize_item(self, serialized_item: Any) -> Item:
        """Deserializes an item from JSON-decoded data."""
        assert isinstance(serialized_item, dict)
        return Item.deserialize(serialized_item)

    def ingest(
        self,
        tile_store: TileStoreWithLayer,
        items: list[Item],
        geometries: list[list[STGeometry]],
    ) -> None:
        """Ingest items into the given tile store.

        Args:
            tile_store: the tile store to ingest into
            items: the items to ingest
            geometries: a list of geometries needed for each item
        """
        for item in items:
            stac_item = self.collection.get_item(item.name)

            # Each asset in sentinel-1-rtc corresponds to one polarization
            # band; is_raster_needed skips assets not requested by the layer.
            for band_name, asset in stac_item.assets.items():
                if not is_raster_needed([band_name], self.config.band_sets):
                    continue
                if tile_store.is_raster_ready(item.name, [band_name]):
                    continue

                asset_url = asset.href
                with tempfile.TemporaryDirectory() as tmp_dir:
                    local_fname = os.path.join(tmp_dir, "geotiff.tif")
                    logger.debug(
                        "azure_sentinel1 download item %s asset %s to %s",
                        item.name,
                        band_name,
                        local_fname,
                    )
                    # Stream the download in chunks to avoid holding the whole
                    # GeoTIFF in memory.
                    with requests.get(
                        asset_url, stream=True, timeout=self.timeout
                    ) as r:
                        r.raise_for_status()
                        with open(local_fname, "wb") as f:
                            for chunk in r.iter_content(chunk_size=8192):
                                f.write(chunk)

                    logger.debug(
                        "azure_sentinel1 ingest item %s asset %s", item.name, band_name
                    )
                    tile_store.write_raster_file(
                        item.name, [band_name], UPath(local_fname)
                    )
                    logger.debug(
                        "azure_sentinel1 done ingesting item %s asset %s",
                        item.name,
                        band_name,
                    )
"""Sentinel-2 on Planetary Computer."""

import math
import os
import tempfile
import xml.etree.ElementTree as ET
from typing import Any

import affine
import numpy.typing as npt
import planetary_computer
import pystac
import pystac_client
import rasterio
import requests
import shapely
from rasterio.enums import Resampling
from upath import UPath

from rslearn.config import LayerConfig, QueryConfig, RasterLayerConfig
from rslearn.const import WGS84_PROJECTION
from rslearn.data_sources import DataSource, Item
from rslearn.data_sources.raster_source import is_raster_needed
from rslearn.data_sources.utils import match_candidate_items_to_window
from rslearn.dataset import Window
from rslearn.dataset.materialize import RasterMaterializer
from rslearn.log_utils import get_logger
from rslearn.tile_stores import TileStore, TileStoreWithLayer
from rslearn.utils.geometry import PixelBounds, Projection, STGeometry
from rslearn.utils.raster_format import get_raster_projection_and_bounds

from .copernicus import get_harmonize_callback

logger = get_logger(__name__)


class Sentinel2(DataSource, TileStore):
    """A data source for Sentinel-2 L2A data on Microsoft Planetary Computer.

    See https://planetarycomputer.microsoft.com/dataset/sentinel-2-l2a.

    The PC_SDK_SUBSCRIPTION_KEY environment variable can be set but is not needed.
    """

    STAC_ENDPOINT = "https://planetarycomputer.microsoft.com/api/stac/v1"

    COLLECTION_NAME = "sentinel-2-l2a"

    # Map from STAC asset key to the list of bands stored in that asset.
    BANDS = {
        "B01": ["B01"],
        "B02": ["B02"],
        "B03": ["B03"],
        "B04": ["B04"],
        "B05": ["B05"],
        "B06": ["B06"],
        "B07": ["B07"],
        "B08": ["B08"],
        "B09": ["B09"],
        "B11": ["B11"],
        "B12": ["B12"],
        "B8A": ["B8A"],
        "visual": ["R", "G", "B"],
    }

    def __init__(
        self,
        config: RasterLayerConfig,
        harmonize: bool = False,
        query: dict[str, Any] | None = None,
        sort_by: str | None = None,
        sort_ascending: bool = True,
        timeout: int = 10,
    ):
        """Initialize a new Sentinel2 instance.

        Args:
            config: the LayerConfig of the layer containing this data source.
            harmonize: harmonize pixel values across different processing baselines,
                see https://developers.google.com/earth-engine/datasets/catalog/COPERNICUS_S2_SR_HARMONIZED
            query: optional query argument to STAC searches.
            sort_by: sort by this property in the STAC items.
            sort_ascending: whether to sort ascending (or descending).
            timeout: timeout for API requests in seconds.
        """
        self.config = config
        self.harmonize = harmonize
        self.query = query
        self.sort_by = sort_by
        self.sort_ascending = sort_ascending
        self.timeout = timeout

        # sign_inplace attaches short-lived SAS tokens to asset URLs so they
        # can be downloaded anonymously.
        self.client = pystac_client.Client.open(
            self.STAC_ENDPOINT, modifier=planetary_computer.sign_inplace
        )
        self.collection = self.client.get_collection(self.COLLECTION_NAME)

    @staticmethod
    def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Sentinel2":
        """Creates a new Sentinel2 instance from a configuration dictionary."""
        if config.data_source is None:
            raise ValueError("config.data_source is required")
        d = config.data_source.config_dict
        kwargs: dict[str, Any] = dict(
            config=config,
        )

        simple_optionals = [
            "harmonize",
            "query",
            "sort_by",
            "sort_ascending",
            "timeout",
        ]
        for k in simple_optionals:
            if k in d:
                kwargs[k] = d[k]

        return Sentinel2(**kwargs)

    def _stac_item_to_item(self, stac_item: pystac.Item) -> Item:
        """Convert a pystac.Item into an rslearn Item in WGS84."""
        shp = shapely.geometry.shape(stac_item.geometry)

        # Get time range. Prefer start/end datetimes when both are present;
        # otherwise fall back to the single item datetime.
        metadata = stac_item.common_metadata
        if metadata.start_datetime is not None and metadata.end_datetime is not None:
            time_range = (
                metadata.start_datetime,
                metadata.end_datetime,
            )
        elif stac_item.datetime is not None:
            time_range = (stac_item.datetime, stac_item.datetime)
        else:
            raise ValueError(
                f"item {stac_item.id} unexpectedly missing start_datetime, end_datetime, and datetime"
            )

        geom = STGeometry(WGS84_PROJECTION, shp, time_range)
        return Item(stac_item.id, geom)

    def get_item_by_name(self, name: str) -> Item:
        """Gets an item by name.

        Args:
            name: the name of the item to get

        Returns:
            the item object

        Raises:
            ValueError: if the collection has no item with this name.
        """
        stac_item = self.collection.get_item(name)
        # pystac returns None for a missing item; raise a clear error instead
        # of letting _stac_item_to_item fail with an AttributeError.
        if stac_item is None:
            raise ValueError(
                f"no item {name} in collection {self.COLLECTION_NAME}"
            )
        return self._stac_item_to_item(stac_item)

    def get_items(
        self, geometries: list[STGeometry], query_config: QueryConfig
    ) -> list[list[list[Item]]]:
        """Get a list of items in the data source intersecting the given geometries.

        Args:
            geometries: the spatiotemporal geometries
            query_config: the query configuration

        Returns:
            List of groups of items that should be retrieved for each geometry.
        """
        groups = []
        for geometry in geometries:
            # Get potentially relevant items from the collection by performing
            # one search for each requested geometry.
            wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
            logger.debug("performing STAC search for geometry %s", wgs84_geometry)
            result = self.client.search(
                collections=[self.COLLECTION_NAME],
                intersects=shapely.to_geojson(wgs84_geometry.shp),
                datetime=wgs84_geometry.time_range,
                query=self.query,
            )
            stac_items = [item for item in result.item_collection()]
            logger.debug("STAC search yielded %d items", len(stac_items))

            if self.sort_by is not None:
                stac_items.sort(
                    key=lambda stac_item: stac_item.properties[self.sort_by],
                    reverse=not self.sort_ascending,
                )

            candidate_items = [
                self._stac_item_to_item(stac_item) for stac_item in stac_items
            ]
            cur_groups = match_candidate_items_to_window(
                geometry, candidate_items, query_config
            )
            groups.append(cur_groups)

        return groups

    def deserialize_item(self, serialized_item: Any) -> Item:
        """Deserializes an item from JSON-decoded data."""
        assert isinstance(serialized_item, dict)
        return Item.deserialize(serialized_item)

    def _get_product_xml(self, stac_item: pystac.Item) -> ET.Element:
        """Download and parse the product metadata XML for a STAC item."""
        asset_url = stac_item.assets["product-metadata"].href
        response = requests.get(asset_url, timeout=self.timeout)
        response.raise_for_status()
        return ET.fromstring(response.content)

    def ingest(
        self,
        tile_store: TileStoreWithLayer,
        items: list[Item],
        geometries: list[list[STGeometry]],
    ) -> None:
        """Ingest items into the given tile store.

        Args:
            tile_store: the tile store to ingest into
            items: the items to ingest
            geometries: a list of geometries needed for each item
        """
        for item in items:
            stac_item = self.collection.get_item(item.name)

            for asset_key, band_names in self.BANDS.items():
                if not is_raster_needed(band_names, self.config.band_sets):
                    continue
                if tile_store.is_raster_ready(item.name, band_names):
                    continue

                asset_url = stac_item.assets[asset_key].href

                with tempfile.TemporaryDirectory() as tmp_dir:
                    fname = os.path.join(tmp_dir, f"{asset_key}.tif")
                    logger.debug(
                        "azure_sentinel2 start downloading item %s asset %s",
                        item.name,
                        asset_key,
                    )
                    # Stream the download in chunks to avoid holding the whole
                    # GeoTIFF in memory.
                    with requests.get(
                        asset_url, stream=True, timeout=self.timeout
                    ) as r:
                        r.raise_for_status()
                        with open(fname, "wb") as f:
                            for chunk in r.iter_content(chunk_size=8192):
                                f.write(chunk)

                    logger.debug(
                        "azure_sentinel2 ingest item %s asset %s", item.name, asset_key
                    )

                    # Harmonize values if needed.
                    # TCI does not need harmonization.
                    harmonize_callback = None
                    if self.harmonize and asset_key != "visual":
                        harmonize_callback = get_harmonize_callback(
                            self._get_product_xml(stac_item)
                        )

                    if harmonize_callback is not None:
                        # In this case we need to read the array, convert the
                        # pixel values, and pass the modified array directly to
                        # the TileStore.
                        with rasterio.open(fname) as src:
                            array = src.read()
                            projection, bounds = get_raster_projection_and_bounds(src)
                        array = harmonize_callback(array)
                        tile_store.write_raster(
                            item.name, band_names, projection, bounds, array
                        )

                    else:
                        tile_store.write_raster_file(
                            item.name, band_names, UPath(fname)
                        )

                    logger.debug(
                        "azure_sentinel2 done ingesting item %s asset %s",
                        item.name,
                        asset_key,
                    )

    def is_raster_ready(
        self, layer_name: str, item_name: str, bands: list[str]
    ) -> bool:
        """Checks if this raster has been written to the store.

        Args:
            layer_name: the layer name or alias.
            item_name: the item.
            bands: the list of bands identifying which specific raster to read.

        Returns:
            whether there is a raster in the store matching the source, item, and
            bands.
        """
        # Data on Planetary Computer is always accessible directly, so in the
        # TileStore role this source reports every raster as ready.
        return True

    def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]:
        """Get the sets of bands that have been stored for the specified item.

        Args:
            layer_name: the layer name or alias.
            item_name: the item.

        Returns:
            a list of lists of bands that are in the tile store (with one raster
            stored corresponding to each inner list). If no rasters are ready for
            this item, returns empty list.
        """
        return list(self.BANDS.values())

    def _get_asset_by_band(self, bands: list[str]) -> str:
        """Get the name of the asset based on the bands."""
        for asset_key, asset_bands in self.BANDS.items():
            if bands == asset_bands:
                return asset_key

        raise ValueError(f"no raster with bands {bands}")

    def get_raster_bounds(
        self, layer_name: str, item_name: str, bands: list[str], projection: Projection
    ) -> PixelBounds:
        """Get the bounds of the raster in the specified projection.

        Args:
            layer_name: the layer name or alias.
            item_name: the item to check.
            bands: the list of bands identifying which specific raster to read. These
                bands must match the bands of a stored raster.
            projection: the projection to get the raster's bounds in.

        Returns:
            the bounds of the raster in the projection.
        """
        item = self.get_item_by_name(item_name)
        geom = item.geometry.to_projection(projection)
        # Use floor for the min edge and ceil for the max edge so the returned
        # bounds fully cover the item geometry. int() truncates toward zero,
        # which shrinks the box for negative coordinates and drops the
        # fractional part of the max edge.
        return (
            math.floor(geom.shp.bounds[0]),
            math.floor(geom.shp.bounds[1]),
            math.ceil(geom.shp.bounds[2]),
            math.ceil(geom.shp.bounds[3]),
        )

    def read_raster(
        self,
        layer_name: str,
        item_name: str,
        bands: list[str],
        projection: Projection,
        bounds: PixelBounds,
        resampling: Resampling = Resampling.bilinear,
    ) -> npt.NDArray[Any]:
        """Read raster data from the store.

        Args:
            layer_name: the layer name or alias.
            item_name: the item to read.
            bands: the list of bands identifying which specific raster to read. These
                bands must match the bands of a stored raster.
            projection: the projection to read in.
            bounds: the bounds to read.
            resampling: the resampling method to use in case reprojection is needed.

        Returns:
            the raster data
        """
        asset_key = self._get_asset_by_band(bands)
        stac_item = self.collection.get_item(item_name)
        asset_url = stac_item.assets[asset_key].href

        # Construct the transform to use for the warped dataset. The offsets
        # map the requested pixel bounds back to projection units.
        wanted_transform = affine.Affine(
            projection.x_resolution,
            0,
            bounds[0] * projection.x_resolution,
            0,
            projection.y_resolution,
            bounds[1] * projection.y_resolution,
        )

        with rasterio.open(asset_url) as src:
            with rasterio.vrt.WarpedVRT(
                src,
                crs=projection.crs,
                transform=wanted_transform,
                width=bounds[2] - bounds[0],
                height=bounds[3] - bounds[1],
                resampling=resampling,
            ) as vrt:
                return vrt.read()

    def materialize(
        self,
        window: Window,
        item_groups: list[list[Item]],
        layer_name: str,
        layer_cfg: LayerConfig,
    ) -> None:
        """Materialize data for the window.

        Args:
            window: the window to materialize
            item_groups: the items from get_items
            layer_name: the name of this layer
            layer_cfg: the config of this layer
        """
        assert isinstance(layer_cfg, RasterLayerConfig)
        # Serve as our own tile store so materialization streams directly from
        # the remote assets without a prior ingest step.
        RasterMaterializer().materialize(
            TileStoreWithLayer(self, layer_name),
            window,
            layer_name,
            layer_cfg,
            item_groups,
        )
@@ -41,17 +41,22 @@ def get_harmonize_callback( None if no callback is needed, or the callback to subtract the new offset """ offset = None - for el in tree.iter("RADIO_ADD_OFFSET"): - if el.text is None: - raise ValueError(f"text is missing in {el}") - value = int(el.text) - if offset is None: - offset = value - assert offset <= 0 - # For now assert the offset is always -1000. - assert offset == -1000 - else: - assert offset == value + + # The metadata will use different tag for L1C / L2A. + # L1C: RADIO_ADD_OFFSET + # L2A: BOA_ADD_OFFSET + for potential_tag in ["RADIO_ADD_OFFSET", "BOA_ADD_OFFSET"]: + for el in tree.iter(potential_tag): + if el.text is None: + raise ValueError(f"text is missing in {el}") + value = int(el.text) + if offset is None: + offset = value + assert offset <= 0 + # For now assert the offset is always -1000. + assert offset == -1000 + else: + assert offset == value if offset is None or offset == 0: return None diff --git a/rslearn/tile_stores/tile_store.py b/rslearn/tile_stores/tile_store.py index 42ebefb..479bf53 100644 --- a/rslearn/tile_stores/tile_store.py +++ b/rslearn/tile_stores/tile_store.py @@ -1,6 +1,5 @@ """Base class for tile stores.""" -from abc import ABC, abstractmethod from typing import Any import numpy.typing as npt @@ -10,13 +9,12 @@ from rslearn.utils import Feature, PixelBounds, Projection -class TileStore(ABC): +class TileStore: """An abstract class for a tile store. A tile store supports operations to read and write raster and vector data. """ - @abstractmethod def set_dataset_path(self, ds_path: UPath) -> None: """Set the dataset path. @@ -28,7 +26,6 @@ def set_dataset_path(self, ds_path: UPath) -> None: """ pass - @abstractmethod def is_raster_ready( self, layer_name: str, item_name: str, bands: list[str] ) -> bool: @@ -43,9 +40,8 @@ def is_raster_ready( whether there is a raster in the store matching the source, item, and bands. 
""" - pass + raise NotImplementedError - @abstractmethod def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]: """Get the sets of bands that have been stored for the specified item. @@ -58,9 +54,8 @@ def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]: stored corresponding to each inner list). If no rasters are ready for this item, returns empty list. """ - pass + raise NotImplementedError - @abstractmethod def get_raster_bounds( self, layer_name: str, item_name: str, bands: list[str], projection: Projection ) -> PixelBounds: @@ -76,9 +71,8 @@ def get_raster_bounds( Returns: the bounds of the raster in the projection. """ - pass + raise NotImplementedError - @abstractmethod def read_raster( self, layer_name: str, @@ -102,9 +96,8 @@ def read_raster( Returns: the raster data """ - pass + raise NotImplementedError - @abstractmethod def write_raster( self, layer_name: str, @@ -124,9 +117,8 @@ def write_raster( bounds: the bounds of the array. array: the raster data. """ - pass + raise NotImplementedError - @abstractmethod def write_raster_file( self, layer_name: str, item_name: str, bands: list[str], fname: UPath ) -> None: @@ -138,9 +130,8 @@ def write_raster_file( bands: the list of bands in the array. fname: the raster file. """ - pass + raise NotImplementedError - @abstractmethod def is_vector_ready(self, layer_name: str, item_name: str) -> bool: """Checks if this vector item has been written to the store. @@ -151,9 +142,8 @@ def is_vector_ready(self, layer_name: str, item_name: str) -> bool: Returns: whether the vector data from the item has been stored. 
""" - pass + raise NotImplementedError - @abstractmethod def read_vector( self, layer_name: str, @@ -172,9 +162,8 @@ def read_vector( Returns: the vector data """ - pass + raise NotImplementedError - @abstractmethod def write_vector( self, layer_name: str, item_name: str, features: list[Feature] ) -> None: @@ -185,7 +174,7 @@ def write_vector( item_name: the item to write. features: the vector data. """ - pass + raise NotImplementedError class TileStoreWithLayer: From d4fe0061444d819ff6f44b04803d5fffe5be3cab Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Wed, 15 Jan 2025 15:56:49 -0800 Subject: [PATCH 2/3] start working on test for sentinel-1 --- rslearn/data_sources/azure_sentinel1.py | 143 +++++++++++++++++- .../data_sources/test_azure_sentinel1.py | 49 ++++++ 2 files changed, 188 insertions(+), 4 deletions(-) create mode 100644 tests/integration/data_sources/test_azure_sentinel1.py diff --git a/rslearn/data_sources/azure_sentinel1.py b/rslearn/data_sources/azure_sentinel1.py index 054cce3..38f0347 100644 --- a/rslearn/data_sources/azure_sentinel1.py +++ b/rslearn/data_sources/azure_sentinel1.py @@ -4,26 +4,32 @@ import tempfile from typing import Any +import affine +import numpy.typing as npt import planetary_computer import pystac import pystac_client +import rasterio import requests import shapely +from rasterio.enums import Resampling from upath import UPath -from rslearn.config import QueryConfig, RasterLayerConfig +from rslearn.config import LayerConfig, QueryConfig, RasterLayerConfig from rslearn.const import WGS84_PROJECTION from rslearn.data_sources import DataSource, Item from rslearn.data_sources.raster_source import is_raster_needed from rslearn.data_sources.utils import match_candidate_items_to_window +from rslearn.dataset import Window +from rslearn.dataset.materialize import RasterMaterializer from rslearn.log_utils import get_logger -from rslearn.tile_stores import TileStoreWithLayer -from rslearn.utils.geometry import STGeometry +from 
rslearn.tile_stores import TileStore, TileStoreWithLayer +from rslearn.utils.geometry import PixelBounds, Projection, STGeometry logger = get_logger(__name__) -class Sentinel1(DataSource): +class Sentinel1(DataSource, TileStore): """A data source for Sentinel-1 data on Microsoft Planetary Computer. This uses the radiometrically corrected data. @@ -212,3 +218,132 @@ def ingest( item.name, band_name, ) + + def is_raster_ready( + self, layer_name: str, item_name: str, bands: list[str] + ) -> bool: + """Checks if this raster has been written to the store. + + Args: + layer_name: the layer name or alias. + item_name: the item. + bands: the list of bands identifying which specific raster to read. + + Returns: + whether there is a raster in the store matching the source, item, and + bands. + """ + return True + + def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]: + """Get the sets of bands that have been stored for the specified item. + + Args: + layer_name: the layer name or alias. + item_name: the item. + + Returns: + a list of lists of bands that are in the tile store (with one raster + stored corresponding to each inner list). If no rasters are ready for + this item, returns empty list. + """ + stac_item = self.collection.get_item(item_name) + bands = [[band_name] for band_name in stac_item.assets.keys()] + return bands + + def get_raster_bounds( + self, layer_name: str, item_name: str, bands: list[str], projection: Projection + ) -> PixelBounds: + """Get the bounds of the raster in the specified projection. + + Args: + layer_name: the layer name or alias. + item_name: the item to check. + bands: the list of bands identifying which specific raster to read. These + bands must match the bands of a stored raster. + projection: the projection to get the raster's bounds in. + + Returns: + the bounds of the raster in the projection. 
+ """ + item = self.get_item_by_name(item_name) + geom = item.geometry.to_projection(projection) + return ( + int(geom.shp.bounds[0]), + int(geom.shp.bounds[1]), + int(geom.shp.bounds[2]), + int(geom.shp.bounds[3]), + ) + + def read_raster( + self, + layer_name: str, + item_name: str, + bands: list[str], + projection: Projection, + bounds: PixelBounds, + resampling: Resampling = Resampling.bilinear, + ) -> npt.NDArray[Any]: + """Read raster data from the store. + + Args: + layer_name: the layer name or alias. + item_name: the item to read. + bands: the list of bands identifying which specific raster to read. These + bands must match the bands of a stored raster. + projection: the projection to read in. + bounds: the bounds to read. + resampling: the resampling method to use in case reprojection is needed. + + Returns: + the raster data + """ + assert len(bands) == 1 + band_name = bands[0] + stac_item = self.collection.get_item(item_name) + asset_url = stac_item.assets[band_name].href + + # Construct the transform to use for the warped dataset. + wanted_transform = affine.Affine( + projection.x_resolution, + 0, + bounds[0] * projection.x_resolution, + 0, + projection.y_resolution, + bounds[1] * projection.y_resolution, + ) + + with rasterio.open(asset_url) as src: + with rasterio.vrt.WarpedVRT( + src, + crs=projection.crs, + transform=wanted_transform, + width=bounds[2] - bounds[0], + height=bounds[3] - bounds[1], + resampling=resampling, + ) as vrt: + return vrt.read() + + def materialize( + self, + window: Window, + item_groups: list[list[Item]], + layer_name: str, + layer_cfg: LayerConfig, + ) -> None: + """Materialize data for the window. 
+ + Args: + window: the window to materialize + item_groups: the items from get_items + layer_name: the name of this layer + layer_cfg: the config of this layer + """ + assert isinstance(layer_cfg, RasterLayerConfig) + RasterMaterializer().materialize( + TileStoreWithLayer(self, layer_name), + window, + layer_name, + layer_cfg, + item_groups, + ) diff --git a/tests/integration/data_sources/test_azure_sentinel1.py b/tests/integration/data_sources/test_azure_sentinel1.py new file mode 100644 index 0000000..b36bf17 --- /dev/null +++ b/tests/integration/data_sources/test_azure_sentinel1.py @@ -0,0 +1,49 @@ +import pathlib + +from upath import UPath + +from rslearn.config import ( + BandSetConfig, + DType, + LayerType, + QueryConfig, + RasterLayerConfig, + SpaceMode, +) +from rslearn.data_sources.azure_sentinel1 import Sentinel1 +from rslearn.tile_stores import DefaultTileStore, TileStoreWithLayer +from rslearn.utils import STGeometry + + +class TestSentinel1: + """Tests the Sentinel1 data source.""" + + def test_local( + self, tmp_path: pathlib.Path, seattle2020: STGeometry, use_rtree_index: bool + ) -> None: + """Test ingesting an item corresponding to seattle2020 to local filesystem.""" + layer_config = RasterLayerConfig( + LayerType.RASTER, + [BandSetConfig(config_dict={}, dtype=DType.UINT8, bands=["vv"])], + ) + query_config = QueryConfig(space_mode=SpaceMode.INTERSECTS) + s1_query_dict = {"sar:polarizations": ["VV", "VH"]} + data_source = Sentinel1( + config=layer_config, + query=s1_query_dict, + ) + + print("get items") + item_groups = data_source.get_items([seattle2020], query_config)[0] + item = item_groups[0][0] + + tile_store_dir = UPath(tmp_path) + tile_store = DefaultTileStore(str(tile_store_dir)) + tile_store.set_dataset_path(tile_store_dir) + + print("ingest") + layer_name = "layer" + data_source.ingest( + TileStoreWithLayer(tile_store, layer_name), item_groups[0], [[seattle2020]] + ) + assert tile_store.is_raster_ready(layer_name, item.name, ["vv"]) From 
# tests/integration/data_sources/test_azure_sentinel1.py
def test_ingest_seattle(tmp_path: pathlib.Path, seattle2020: STGeometry) -> None:
    """Test ingesting an item corresponding to seattle2020 to local filesystem."""
    band_set = BandSetConfig(config_dict={}, dtype=DType.UINT8, bands=["vv"])
    layer_config = RasterLayerConfig(LayerType.RASTER, [band_set])
    query_config = QueryConfig(space_mode=SpaceMode.INTERSECTS)

    # The asset band is vv but in the STAC metadata it is capitalized.
    # We search for a VV+VH image since that is the standard one for GRD/IW.
    s1_query_dict = {"sar:polarizations": {"eq": ["VV", "VH"]}}
    data_source = Sentinel1(
        config=layer_config,
        query=s1_query_dict,
    )

    print("get items")
    groups_for_geometry = data_source.get_items([seattle2020], query_config)[0]
    first_item = groups_for_geometry[0][0]

    store_path = UPath(tmp_path)
    tile_store = DefaultTileStore(str(store_path))
    tile_store.set_dataset_path(store_path)

    print("ingest")
    layer_name = "layer"
    data_source.ingest(
        TileStoreWithLayer(tile_store, layer_name),
        groups_for_geometry[0],
        [[seattle2020]],
    )
    assert tile_store.is_raster_ready(layer_name, first_item.name, ["vv"])


# tests/integration/data_sources/test_azure_sentinel2.py
TEST_BAND = "B04"


def test_ingest_seattle(tmp_path: pathlib.Path, seattle2020: STGeometry) -> None:
    """Test ingesting an item corresponding to seattle2020 to local filesystem."""
    band_set = BandSetConfig(config_dict={}, dtype=DType.UINT8, bands=[TEST_BAND])
    layer_config = RasterLayerConfig(LayerType.RASTER, [band_set])
    query_config = QueryConfig(space_mode=SpaceMode.INTERSECTS)
    data_source = Sentinel2(config=layer_config)

    print("get items")
    groups_for_geometry = data_source.get_items([seattle2020], query_config)[0]
    first_item = groups_for_geometry[0][0]

    store_path = UPath(tmp_path)
    tile_store = DefaultTileStore(str(store_path))
    tile_store.set_dataset_path(store_path)

    print("ingest")
    layer_name = "layer"
    data_source.ingest(
        TileStoreWithLayer(tile_store, layer_name),
        groups_for_geometry[0],
        [[seattle2020]],
    )
    assert tile_store.is_raster_ready(layer_name, first_item.name, [TEST_BAND])