-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #116 from allenai/favyen/azure-sentinel1-sentinel2
Add Azure sentinel-1 sentinel-2 data sources
- Loading branch information
Showing
7 changed files
with
876 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,349 @@ | ||
"""Sentinel-1 on Planetary Computer.""" | ||
|
||
import os | ||
import tempfile | ||
from typing import Any | ||
|
||
import affine | ||
import numpy.typing as npt | ||
import planetary_computer | ||
import pystac | ||
import pystac_client | ||
import rasterio | ||
import requests | ||
import shapely | ||
from rasterio.enums import Resampling | ||
from upath import UPath | ||
|
||
from rslearn.config import LayerConfig, QueryConfig, RasterLayerConfig | ||
from rslearn.const import WGS84_PROJECTION | ||
from rslearn.data_sources import DataSource, Item | ||
from rslearn.data_sources.raster_source import is_raster_needed | ||
from rslearn.data_sources.utils import match_candidate_items_to_window | ||
from rslearn.dataset import Window | ||
from rslearn.dataset.materialize import RasterMaterializer | ||
from rslearn.log_utils import get_logger | ||
from rslearn.tile_stores import TileStore, TileStoreWithLayer | ||
from rslearn.utils.geometry import PixelBounds, Projection, STGeometry | ||
|
||
logger = get_logger(__name__) | ||
|
||
|
||
class Sentinel1(DataSource, TileStore): | ||
"""A data source for Sentinel-1 data on Microsoft Planetary Computer. | ||
This uses the radiometrically corrected data. | ||
See https://planetarycomputer.microsoft.com/dataset/sentinel-1-rtc. | ||
The PC_SDK_SUBSCRIPTION_KEY environment variable can be set but is not needed. | ||
""" | ||
|
||
STAC_ENDPOINT = "https://planetarycomputer.microsoft.com/api/stac/v1" | ||
|
||
COLLECTION_NAME = "sentinel-1-rtc" | ||
|
||
def __init__( | ||
self, | ||
config: RasterLayerConfig, | ||
query: dict[str, Any] | None = None, | ||
sort_by: str | None = None, | ||
sort_ascending: bool = True, | ||
timeout: int = 10, | ||
): | ||
"""Initialize a new Sentinel1 instance. | ||
Args: | ||
config: the LayerConfig of the layer containing this data source. | ||
query: optional query argument to STAC searches. | ||
sort_by: sort by this property in the STAC items. | ||
sort_ascending: whether to sort ascending (or descending). | ||
timeout: timeout for API requests in seconds. | ||
""" | ||
self.config = config | ||
self.query = query | ||
self.sort_by = sort_by | ||
self.sort_ascending = sort_ascending | ||
self.timeout = timeout | ||
|
||
self.client = pystac_client.Client.open( | ||
self.STAC_ENDPOINT, modifier=planetary_computer.sign_inplace | ||
) | ||
self.collection = self.client.get_collection(self.COLLECTION_NAME) | ||
|
||
@staticmethod | ||
def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Sentinel1": | ||
"""Creates a new Sentinel1instance from a configuration dictionary.""" | ||
if config.data_source is None: | ||
raise ValueError("config.data_source is required") | ||
d = config.data_source.config_dict | ||
kwargs: dict[str, Any] = dict( | ||
config=config, | ||
) | ||
|
||
simple_optionals = ["query", "sort_by", "sort_ascending", "timeout"] | ||
for k in simple_optionals: | ||
if k in d: | ||
kwargs[k] = d[k] | ||
|
||
return Sentinel1(**kwargs) | ||
|
||
def _stac_item_to_item(self, stac_item: pystac.Item) -> Item: | ||
shp = shapely.geometry.shape(stac_item.geometry) | ||
|
||
# Get time range. | ||
metadata = stac_item.common_metadata | ||
if metadata.start_datetime is not None and metadata.end_datetime is not None: | ||
time_range = ( | ||
metadata.start_datetime, | ||
metadata.end_datetime, | ||
) | ||
elif stac_item.datetime is not None: | ||
time_range = (stac_item.datetime, stac_item.datetime) | ||
else: | ||
raise ValueError( | ||
f"item {stac_item.id} unexpectedly missing start_datetime, end_datetime, and datetime" | ||
) | ||
|
||
geom = STGeometry(WGS84_PROJECTION, shp, time_range) | ||
return Item(stac_item.id, geom) | ||
|
||
def get_item_by_name(self, name: str) -> Item: | ||
"""Gets an item by name. | ||
Args: | ||
name: the name of the item to get | ||
Returns: | ||
the item object | ||
""" | ||
stac_item = self.collection.get_item(name) | ||
return self._stac_item_to_item(stac_item) | ||
|
||
def get_items( | ||
self, geometries: list[STGeometry], query_config: QueryConfig | ||
) -> list[list[list[Item]]]: | ||
"""Get a list of items in the data source intersecting the given geometries. | ||
Args: | ||
geometries: the spatiotemporal geometries | ||
query_config: the query configuration | ||
Returns: | ||
List of groups of items that should be retrieved for each geometry. | ||
""" | ||
groups = [] | ||
for geometry in geometries: | ||
# Get potentially relevant items from the collection by performing one search | ||
# for each requested geometry. | ||
wgs84_geometry = geometry.to_projection(WGS84_PROJECTION) | ||
logger.debug("performing STAC search for geometry %s", wgs84_geometry) | ||
result = self.client.search( | ||
collections=[self.COLLECTION_NAME], | ||
intersects=shapely.to_geojson(wgs84_geometry.shp), | ||
datetime=wgs84_geometry.time_range, | ||
query=self.query, | ||
) | ||
stac_items = [item for item in result.item_collection()] | ||
logger.debug("STAC search yielded %d items", len(stac_items)) | ||
|
||
if self.sort_by is not None: | ||
stac_items.sort( | ||
key=lambda stac_item: stac_item.properties[self.sort_by], | ||
reverse=not self.sort_ascending, | ||
) | ||
|
||
candidate_items = [ | ||
self._stac_item_to_item(stac_item) for stac_item in stac_items | ||
] | ||
cur_groups = match_candidate_items_to_window( | ||
geometry, candidate_items, query_config | ||
) | ||
groups.append(cur_groups) | ||
|
||
return groups | ||
|
||
def deserialize_item(self, serialized_item: Any) -> Item: | ||
"""Deserializes an item from JSON-decoded data.""" | ||
assert isinstance(serialized_item, dict) | ||
return Item.deserialize(serialized_item) | ||
|
||
def ingest( | ||
self, | ||
tile_store: TileStoreWithLayer, | ||
items: list[Item], | ||
geometries: list[list[STGeometry]], | ||
) -> None: | ||
"""Ingest items into the given tile store. | ||
Args: | ||
tile_store: the tile store to ingest into | ||
items: the items to ingest | ||
geometries: a list of geometries needed for each item | ||
""" | ||
for item in items: | ||
stac_item = self.collection.get_item(item.name) | ||
|
||
for band_name, asset in stac_item.assets.items(): | ||
if not is_raster_needed([band_name], self.config.band_sets): | ||
continue | ||
if tile_store.is_raster_ready(item.name, [band_name]): | ||
continue | ||
|
||
asset_url = asset.href | ||
with tempfile.TemporaryDirectory() as tmp_dir: | ||
local_fname = os.path.join(tmp_dir, "geotiff.tif") | ||
logger.debug( | ||
"azure_sentinel1 download item %s asset %s to %s", | ||
item.name, | ||
band_name, | ||
local_fname, | ||
) | ||
with requests.get( | ||
asset_url, stream=True, timeout=self.timeout | ||
) as r: | ||
r.raise_for_status() | ||
with open(local_fname, "wb") as f: | ||
for chunk in r.iter_content(chunk_size=8192): | ||
f.write(chunk) | ||
|
||
logger.debug( | ||
"azure_sentinel1 ingest item %s asset %s", item.name, band_name | ||
) | ||
tile_store.write_raster_file( | ||
item.name, [band_name], UPath(local_fname) | ||
) | ||
logger.debug( | ||
"azure_sentinel1 done ingesting item %s asset %s", | ||
item.name, | ||
band_name, | ||
) | ||
|
||
def is_raster_ready( | ||
self, layer_name: str, item_name: str, bands: list[str] | ||
) -> bool: | ||
"""Checks if this raster has been written to the store. | ||
Args: | ||
layer_name: the layer name or alias. | ||
item_name: the item. | ||
bands: the list of bands identifying which specific raster to read. | ||
Returns: | ||
whether there is a raster in the store matching the source, item, and | ||
bands. | ||
""" | ||
return True | ||
|
||
def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]: | ||
"""Get the sets of bands that have been stored for the specified item. | ||
Args: | ||
layer_name: the layer name or alias. | ||
item_name: the item. | ||
Returns: | ||
a list of lists of bands that are in the tile store (with one raster | ||
stored corresponding to each inner list). If no rasters are ready for | ||
this item, returns empty list. | ||
""" | ||
stac_item = self.collection.get_item(item_name) | ||
bands = [[band_name] for band_name in stac_item.assets.keys()] | ||
return bands | ||
|
||
def get_raster_bounds( | ||
self, layer_name: str, item_name: str, bands: list[str], projection: Projection | ||
) -> PixelBounds: | ||
"""Get the bounds of the raster in the specified projection. | ||
Args: | ||
layer_name: the layer name or alias. | ||
item_name: the item to check. | ||
bands: the list of bands identifying which specific raster to read. These | ||
bands must match the bands of a stored raster. | ||
projection: the projection to get the raster's bounds in. | ||
Returns: | ||
the bounds of the raster in the projection. | ||
""" | ||
item = self.get_item_by_name(item_name) | ||
geom = item.geometry.to_projection(projection) | ||
return ( | ||
int(geom.shp.bounds[0]), | ||
int(geom.shp.bounds[1]), | ||
int(geom.shp.bounds[2]), | ||
int(geom.shp.bounds[3]), | ||
) | ||
|
||
def read_raster( | ||
self, | ||
layer_name: str, | ||
item_name: str, | ||
bands: list[str], | ||
projection: Projection, | ||
bounds: PixelBounds, | ||
resampling: Resampling = Resampling.bilinear, | ||
) -> npt.NDArray[Any]: | ||
"""Read raster data from the store. | ||
Args: | ||
layer_name: the layer name or alias. | ||
item_name: the item to read. | ||
bands: the list of bands identifying which specific raster to read. These | ||
bands must match the bands of a stored raster. | ||
projection: the projection to read in. | ||
bounds: the bounds to read. | ||
resampling: the resampling method to use in case reprojection is needed. | ||
Returns: | ||
the raster data | ||
""" | ||
assert len(bands) == 1 | ||
band_name = bands[0] | ||
stac_item = self.collection.get_item(item_name) | ||
asset_url = stac_item.assets[band_name].href | ||
|
||
# Construct the transform to use for the warped dataset. | ||
wanted_transform = affine.Affine( | ||
projection.x_resolution, | ||
0, | ||
bounds[0] * projection.x_resolution, | ||
0, | ||
projection.y_resolution, | ||
bounds[1] * projection.y_resolution, | ||
) | ||
|
||
with rasterio.open(asset_url) as src: | ||
with rasterio.vrt.WarpedVRT( | ||
src, | ||
crs=projection.crs, | ||
transform=wanted_transform, | ||
width=bounds[2] - bounds[0], | ||
height=bounds[3] - bounds[1], | ||
resampling=resampling, | ||
) as vrt: | ||
return vrt.read() | ||
|
||
def materialize( | ||
self, | ||
window: Window, | ||
item_groups: list[list[Item]], | ||
layer_name: str, | ||
layer_cfg: LayerConfig, | ||
) -> None: | ||
"""Materialize data for the window. | ||
Args: | ||
window: the window to materialize | ||
item_groups: the items from get_items | ||
layer_name: the name of this layer | ||
layer_cfg: the config of this layer | ||
""" | ||
assert isinstance(layer_cfg, RasterLayerConfig) | ||
RasterMaterializer().materialize( | ||
TileStoreWithLayer(self, layer_name), | ||
window, | ||
layer_name, | ||
layer_cfg, | ||
item_groups, | ||
) |
Oops, something went wrong.