From 77d4baf0fab746a7660da44ccee188a607b7873d Mon Sep 17 00:00:00 2001 From: Paul Haesler Date: Tue, 23 Apr 2024 16:42:05 +1000 Subject: [PATCH] data.py and loading.py now fully type-hinted. --- datacube_ows/data.py | 163 +++++++++++++++++------------- datacube_ows/loading.py | 48 ++++++--- datacube_ows/ows_configuration.py | 2 +- 3 files changed, 126 insertions(+), 87 deletions(-) diff --git a/datacube_ows/data.py b/datacube_ows/data.py index 54f69920..401e1d11 100644 --- a/datacube_ows/data.py +++ b/datacube_ows/data.py @@ -8,11 +8,13 @@ import re from datetime import date, datetime, timedelta from itertools import chain +from typing import Iterable, cast, Any, Mapping import numpy import numpy.ma import pytz import xarray +from datacube.model import Dataset from datacube.utils.masking import mask_to_dict from flask import render_template from odc.geo import geom @@ -22,14 +24,15 @@ from rasterio.io import MemoryFile from datacube_ows.cube_pool import cube -from datacube_ows.loading import DataStacker +from datacube_ows.loading import DataStacker, ProductBandQuery from datacube_ows.mv_index import MVSelectOpts from datacube_ows.ogc_exceptions import WMSException from datacube_ows.ogc_utils import (dataset_center_time, solar_date, tz_for_geometry, xarray_image_as_png) -from datacube_ows.config_utils import ConfigException -from datacube_ows.ows_configuration import get_config +from datacube_ows.config_utils import ConfigException, RAW_CFG, CFG_DICT +from datacube_ows.ows_configuration import get_config, OWSNamedLayer, OWSConfig +from datacube_ows.styles import StyleDef from datacube_ows.query_profiler import QueryProfiler from datacube_ows.resource_limits import ResourceLimited from datacube_ows.utils import default_to_utc, log_call @@ -38,22 +41,15 @@ _LOG = logging.getLogger(__name__) +FlaskResponse = tuple[str | bytes, int, dict[str, str]] -def datasets_in_xarray(xa): - if xa is None: - return 0 - return sum(len(xa.values[i]) for i in range(0, len(xa.values))) - -def bbox_to_geom(bbox, crs): - return geom.box(bbox.left, bbox.bottom, bbox.right, bbox.top, crs) - - -def user_date_sorter(layer, odc_dates, geom, user_dates): +def user_date_sorter(layer: OWSNamedLayer, odc_dates: list[datetime], + geometry: geom.Geometry, user_dates: list[datetime]) -> xarray.DataArray: # TODO: Make more elegant. Just a little bit elegant would do. 
     result = []
     if layer.time_resolution.is_solar():
-        tz = tz_for_geometry(geom)
+        tz = tz_for_geometry(geometry)
     else:
         tz = None
@@ -97,8 +93,9 @@ def check_date(time_res, user_date, odc_date):
 class EmptyResponse(Exception):
     pass
 
+
 @log_call
-def get_map(args):
+def get_map(args: dict[str, str]) -> FlaskResponse:
     # pylint: disable=too-many-nested-blocks, too-many-branches, too-many-statements, too-many-locals
     # Parse GET parameters
     params = GetMapParameters(args)
@@ -131,22 +128,23 @@ def get_map(args):
                 stacker.resource_limited = True
                 qprof["resource_limited"] = str(e)
             if qprof.active:
-                q_ds_dict = stacker.datasets(dc.index, mode=MVSelectOpts.DATASETS)
+                q_ds_dict = cast(dict[ProductBandQuery, xarray.DataArray],
+                                 stacker.datasets(dc.index, mode=MVSelectOpts.DATASETS))
                 qprof["datasets"] = []
-                for q, dss in q_ds_dict.items():
-                    query_res = {}
+                for q, dsxr in q_ds_dict.items():
+                    query_res: dict[str, Any] = {}
                     query_res["query"] = str(q)
                     query_res["datasets"] = [
                         [
                             f"{ds.id} ({ds.type.name})"
                             for ds in tdss
                         ]
-                        for tdss in dss.values
+                        for tdss in dsxr.values
                     ]
                     qprof["datasets"].append(query_res)
             if stacker.resource_limited and not params.product.low_res_product_names:
                 qprof.start_event("extent-in-query")
-                extent = stacker.datasets(dc.index, mode=MVSelectOpts.EXTENT)
+                extent = cast(geom.Geometry | None, stacker.datasets(dc.index, mode=MVSelectOpts.EXTENT))
                 qprof.end_event("extent-in-query")
                 if extent is None:
                     qprof["write_action"] = "No extent: Write Empty"
@@ -169,7 +167,7 @@ def get_map(args):
                     qprof["n_summary_datasets"] = stacker.datasets(dc.index, mode=MVSelectOpts.COUNT)
                     qprof.end_event("count-summary-datasets")
                 qprof.start_event("fetch-datasets")
-                datasets = stacker.datasets(dc.index)
+                datasets = cast(dict[ProductBandQuery, xarray.DataArray], stacker.datasets(dc.index))
                 for flagband, dss in datasets.items():
                     if not dss.any():
                         _LOG.warning("Flag band %s returned no data", str(flagband))
@@ -181,20 +179,24 @@ def get_map(args):
                 qprof.start_event("load-data")
                 data = stacker.data(datasets)
                 qprof.end_event("load-data")
+                if not data:
+                    qprof["write_action"] = "No Data: Write Empty"
+                    raise EmptyResponse()
                 _LOG.debug("load stop %s %s", datetime.now().time(), args["requestid"])
                 qprof.start_event("build-masks")
                 td_masks = []
                 for npdt in data.time.values:
                     td = data.sel(time=npdt)
-                    td_ext_mask = None
+                    td_ext_mask_man: numpy.ndarray | None = None
+                    td_ext_mask: xarray.DataArray | None = None
                     band = ""
                     for band in params.style.needed_bands:
                         if band not in params.style.flag_bands:
                             if params.product.data_manual_merge:
-                                if td_ext_mask is None:
-                                    td_ext_mask = ~numpy.isnan(td[band])
+                                if td_ext_mask_man is None:
+                                    td_ext_mask_man = ~numpy.isnan(td[band])
                                 else:
-                                    td_ext_mask &= ~numpy.isnan(td[band])
+                                    td_ext_mask_man &= ~numpy.isnan(td[band])
                             else:
                                 for f in params.product.extent_mask_func:
                                     if td_ext_mask is None:
@@ -202,7 +204,7 @@ def get_map(args):
                                     else:
                                         td_ext_mask &= f(td, band)
                     if params.product.data_manual_merge:
-                        td_ext_mask = xarray.DataArray(td_ext_mask)
+                        td_ext_mask = xarray.DataArray(td_ext_mask_man)
                     if td_ext_mask is None:
                         td_ext_mask = xarray.DataArray(
                             ~numpy.zeros(
@@ -214,21 +216,17 @@ def get_map(args):
                     td_masks.append(td_ext_mask)
                 extent_mask = xarray.concat(td_masks, dim=data.time)
                 qprof.end_event("build-masks")
-                if not data:
-                    qprof["write_action"] = "No Data: Write Empty"
-                    raise EmptyResponse()
-                else:
-                    qprof["write_action"] = "Write Data"
-                    if mdh and mdh.preserve_user_date_order:
-                        sorter = user_date_sorter(
-                            params.product,
-                            data.time.values,
-                            params.geobox.geographic_extent,
-                            params.times)
-                        data = data.sortby(sorter)
-                        extent_mask = extent_mask.sortby(sorter)
-
-                    body = _write_png(data, params.style, extent_mask, qprof)
+                qprof["write_action"] = "Write Data"
+                if mdh and mdh.preserve_user_date_order:
+                    sorter = user_date_sorter(
+                        params.product,
+                        data.time.values,
+                        params.geobox.geographic_extent,
+                        params.times)
+                    data = data.sortby(sorter)
+                    extent_mask = extent_mask.sortby(sorter)
+
+                body = _write_png(data, params.style, extent_mask, qprof)
         except EmptyResponse:
             qprof.start_event("write")
             body = _write_empty(params.geobox)
@@ -240,9 +238,10 @@ def get_map(args):
     return png_response(body, extra_headers=params.product.resource_limits.wms_cache_rules.cache_headers(n_datasets))
 
 
-def png_response(body, cfg=None, extra_headers=None):
+def png_response(body: bytes, cfg: OWSConfig | None = None, extra_headers: dict[str, str] | None = None) -> FlaskResponse:
     if not cfg:
         cfg = get_config()
+    assert cfg is not None  # For type checker
     if extra_headers is None:
         extra_headers = {}
     headers = {"Content-Type": "image/png"}
@@ -252,7 +251,8 @@ def png_response(body, cfg=None, extra_headers=None):
 
 
 @log_call
-def _write_png(data, style, extent_mask, qprof):
+def _write_png(data: xarray.Dataset, style: StyleDef, extent_mask: xarray.DataArray,
+               qprof: QueryProfiler) -> bytes:
     qprof.start_event("combine-masks")
     mask = style.to_mask(data, extent_mask)
     qprof.end_event("combine-masks")
@@ -272,7 +272,7 @@ def _write_png(data, style, extent_mask, qprof):
 
 
 @log_call
-def _write_empty(geobox):
+def _write_empty(geobox: GeoBox) -> bytes:
     with MemoryFile() as memfile:
         with memfile.open(driver='PNG',
                           width=geobox.width,
@@ -285,7 +285,7 @@ def _write_empty(geobox):
     return memfile.read()
 
 
-def get_coordlist(geo, layer_name):
+def get_coordlist(geo: geom.Geometry, layer_name: str) -> list[tuple[float | int, float | int]]:
     if geo.type == 'Polygon':
         coordinates_list = [geo.json["coordinates"]]
     elif geo.type == 'MultiPolygon':
@@ -308,7 +308,7 @@ def get_coordlist(geo, layer_name):
 
 
 @log_call
-def _write_polygon(geobox, polygon, zoom_fill, layer):
+def _write_polygon(geobox: GeoBox, polygon: geom.Geometry, zoom_fill: list[int], layer: OWSNamedLayer) -> bytes:
     geobox_ext = geobox.extent
     if geobox_ext.within(polygon):
         data = numpy.full([geobox.height, geobox.width], fill_value=1, dtype="uint8")
@@ -334,7 +334,9 @@ def _write_polygon(geobox, polygon, zoom_fill, layer):
 
 
 @log_call
-def get_s3_browser_uris(datasets, pt=None, s3url="", s3bucket=""):
+def get_s3_browser_uris(datasets: dict[ProductBandQuery, xarray.DataArray],
+                        pt: geom.Geometry | None = None,
+                        s3url: str = "", s3bucket: str = "") -> set[str]:
     uris = []
     last_crs = None
     for pbq, dss in datasets.items():
@@ -357,7 +359,7 @@ def get_s3_browser_uris(datasets, pt=None, s3url="", s3bucket=""):
     regex = re.compile(r"s3:\/\/(?P<bucket>[a-zA-Z0-9_\-\.]+)\/(?P<prefix>[\S]+)/[a-zA-Z0-9_\-\.]+.yaml")
 
     # convert to browsable link
-    def convert(uri):
+    def convert(uri: str) -> str:
         uri_format = "http://{bucket}.s3-website-ap-southeast-2.amazonaws.com/?prefix={prefix}"
         uri_format_prod = str(s3url) + "/?prefix={prefix}"
         result = regex.match(uri)
@@ -377,8 +379,8 @@ def convert(uri):
 
 
 @log_call
-def _make_band_dict(prod_cfg, pixel_dataset):
-    band_dict = {}
+def _make_band_dict(prod_cfg: OWSNamedLayer, pixel_dataset: xarray.Dataset) -> dict[str, dict[str, bool | str] | str]:
+    band_dict: dict[str, dict[str, bool | str] | str] = {}
     for k, v in pixel_dataset.data_vars.items():
         band_val = pixel_dataset[k].item()
         flag_def = pixel_dataset[k].attrs.get("flags_definition")
@@ -388,7 +390,7 @@ def _make_band_dict(prod_cfg, pixel_dataset):
             except TypeError as te:
                 logging.warning('Working around for float bands')
                 flag_dict = mask_to_dict(flag_def, int(band_val))
-            ret_val = {}
+            ret_val: dict[str, bool | str] = {}
             for flag, val in flag_dict.items():
                 if not val:
                     continue
@@ -400,6 +402,7 @@ def _make_band_dict(prod_cfg, pixel_dataset):
         else:
             try:
                 band_lbl = prod_cfg.band_idx.band_label(k)
+                assert band_lbl is not None  # for type checker
                 if band_val == pixel_dataset[k].nodata or numpy.isnan(band_val):
                     band_dict[band_lbl] = "n/a"
                 else:
@@ -410,7 +413,7 @@ def _make_band_dict(prod_cfg, pixel_dataset):
 
 
 @log_call
-def _make_derived_band_dict(pixel_dataset, style_index):
+def _make_derived_band_dict(pixel_dataset: xarray.Dataset, style_index: dict[str, StyleDef]) -> dict[str, int | float]:
     """Creates a dict of values for bands derived by styles.
 
     This only works for styles with an `index_function` defined.
@@ -431,24 +434,25 @@ def _make_derived_band_dict(pixel_dataset, style_index):
     return derived_band_dict
 
 
-def geobox_is_point(geobox):
-    # TODO: Not 100% sure why this function is needed.
+def geobox_is_point(geobox: GeoBox) -> bool:
     return geobox.height == 1 and geobox.width == 1
 
 
 @log_call
-def feature_info(args):
+def feature_info(args: dict[str, str]) -> FlaskResponse:
     # pylint: disable=too-many-nested-blocks, too-many-branches, too-many-statements, too-many-locals
     # Parse GET parameters
     params = GetFeatureInfoParameters(args)
-    feature_json = {}
+    feature_json: CFG_DICT = {}
 
     geo_point = img_coords_to_geopoint(params.geobox, params.i, params.j)
     # shrink geobox to point
     # Prepare to extract feature info
     if geobox_is_point(params.geobox):
+        # request geobox is already 1x1
        geo_point_geobox = params.geobox
     else:
+        # Make a 1x1 pixel geobox
         geo_point_geobox = GeoBox.from_geopolygon(
             geo_point, params.geobox.resolution, crs=params.geobox.crs)
     tz = tz_for_geometry(geo_point_geobox.geographic_extent)
@@ -458,7 +462,7 @@ def feature_info(args):
     with cube() as dc:
         if not dc:
             raise WMSException("Database connectivity failure")
-        all_time_datasets = stacker.datasets(dc.index, all_time=True, point=geo_point)
+        all_time_datasets = cast(xarray.DataArray, stacker.datasets(dc.index, all_time=True, point=geo_point))
 
         # Taking the data as a single point so our indexes into the data should be 0,0
         h_coord = cfg.published_CRSs[params.crsid]["horizontal_coord"]
@@ -473,8 +477,11 @@ def feature_info(args):
         # Group datasets by time, load only datasets that match the idx_date
         global_info_written = False
         feature_json["data"] = []
-        fi_date_index = {}
-        time_datasets = stacker.datasets(dc.index, all_flag_bands=True, point=geo_point)
+        fi_date_index: dict[datetime, RAW_CFG] = {}
+        time_datasets = cast(
+            dict[ProductBandQuery, xarray.DataArray],
+            stacker.datasets(dc.index, all_flag_bands=True, point=geo_point)
+        )
         data = stacker.data(time_datasets, skip_corrections=True)
         if data is not None:
             for dt in data.time.values:
@@ -499,13 +506,14 @@ def feature_info(args):
                     # Capture lat/long coordinates
                    feature_json["lon"], feature_json["lat"] = ptg.coords[0]
 
-                    date_info = {}
 
-                    ds = None
+                    date_info: CFG_DICT = {}
+
+                    ds: Dataset | None = None
                     for pbq, dss in time_datasets.items():
                         if pbq.main:
                             ds = dss.sel(time=dt).values.tolist()[0]
                             break
+                    assert ds is not None
                     if params.product.multi_product:
                         if "platform" in ds.metadata_doc:
                             date_info["source_product"] = "%s (%s)" % (ds.type.name, ds.metadata_doc["platform"]["code"])
@@ -513,29 +521,31 @@ def feature_info(args):
                             date_info["source_product"] = ds.type.name
 
                     # Extract data pixel
-                    pixel_ds = td.isel(**isel_kwargs)
+                    pixel_ds: xarray.Dataset = td.isel(**isel_kwargs)  # type: ignore[arg-type]
 
                     # Get accurate timestamp from dataset
+                    assert ds.time is not None  # For type checker
                     if params.product.time_resolution.is_summary():
                         date_info["time"] = ds.time.begin.strftime("%Y-%m-%d")
                     else:
                         date_info["time"] = dataset_center_time(ds).strftime("%Y-%m-%d %H:%M:%S %Z")
                     # Collect raw band values for pixel and derived bands from styles
-                    date_info["bands"] = _make_band_dict(params.product, pixel_ds)
-                    derived_band_dict = _make_derived_band_dict(pixel_ds, params.product.style_index)
+                    date_info["bands"] = cast(RAW_CFG, _make_band_dict(params.product, pixel_ds))
+                    derived_band_dict = cast(RAW_CFG, _make_derived_band_dict(pixel_ds, params.product.style_index))
                     if derived_band_dict:
                         date_info["band_derived"] = derived_band_dict
                     # Add any custom-defined fields.
                     for k, f in params.product.feature_info_custom_includes.items():
                         date_info[k] = f(date_info["bands"])
 
-                    feature_json["data"].append(date_info)
-                    fi_date_index[dt] = feature_json["data"][-1]
+                    cast(list[RAW_CFG], feature_json["data"]).append(date_info)
+                    fi_date_index[dt] = cast(dict[str, list[RAW_CFG]], feature_json)["data"][-1]
         feature_json["data_available_for_dates"] = []
         pt_native = None
         for d in all_time_datasets.coords["time"].values:
             dt_datasets = all_time_datasets.sel(time=d)
             for ds in dt_datasets.values.item():
+                assert ds is not None  # For type checker
                 if pt_native is None:
                     pt_native = geo_point.to_crs(ds.crs)
                 elif pt_native.crs != ds.crs:
@@ -544,18 +554,21 @@ def feature_info(args):
                 # tolist() converts a numpy datetime64 to a python datatime
                 dt = Timestamp(stacker.group_by.group_by_func(ds)).to_pydatetime()
                 if params.product.time_resolution.is_subday():
-                    feature_json["data_available_for_dates"].append(dt.isoformat())
+                    cast(list[RAW_CFG], feature_json["data_available_for_dates"]).append(dt.isoformat())
                 else:
-                    feature_json["data_available_for_dates"].append(dt.strftime("%Y-%m-%d"))
+                    cast(list[RAW_CFG], feature_json["data_available_for_dates"]).append(dt.strftime("%Y-%m-%d"))
                 break
         if time_datasets:
-            feature_json["data_links"] = sorted(get_s3_browser_uris(time_datasets, pt_native, s3_url, s3_bucket))
+            feature_json["data_links"] = cast(
+                RAW_CFG,
+                sorted(get_s3_browser_uris(time_datasets, pt_native, s3_url, s3_bucket)))
         else:
             feature_json["data_links"] = []
         if params.product.feature_info_include_utc_dates:
-            unsorted_dates = []
+            unsorted_dates: list[str] = []
             for tds in all_time_datasets:
                 for ds in tds.values.item():
+                    assert ds is not None and ds.time is not None  # for type checker
                     if params.product.time_resolution.is_solar():
                         unsorted_dates.append(ds.center_time.strftime("%Y-%m-%d"))
                     elif params.product.time_resolution.is_subday():
@@ -566,7 +579,7 @@ def feature_info(args):
                 d.center_time.strftime("%Y-%m-%d") for d in all_time_datasets)
 
     # --- End code section requiring datacube.
-    result = {
+    result: CFG_DICT = {
         "type": "FeatureCollection",
         "features": [
             {
@@ -585,12 +598,16 @@ def feature_info(args):
     return json_response(result, cfg)
 
 
-def json_response(result, cfg=None):
+def json_response(result: CFG_DICT, cfg: OWSConfig | None = None) -> FlaskResponse:
     if not cfg:
         cfg = get_config()
+    assert cfg is not None  # for type checker
     return json.dumps(result), 200, cfg.response_headers({"Content-Type": "application/json"})
 
 
-def html_json_response(result, cfg):
+def html_json_response(result: CFG_DICT, cfg: OWSConfig | None = None) -> FlaskResponse:
+    if not cfg:
+        cfg = get_config()
+    assert cfg is not None  # for type checker
     html_content = render_template("html_feature_info.html", result=result)
     return html_content, 200, cfg.response_headers({"Content-Type": "text/html"})
diff --git a/datacube_ows/loading.py b/datacube_ows/loading.py
index c5375d36..9503b7ea 100644
--- a/datacube_ows/loading.py
+++ b/datacube_ows/loading.py
@@ -176,7 +176,7 @@ def datasets(self, index: datacube.index.Index,
                        | xarray.DataArray
                        | Geometry
                        | None
-                       | Mapping[ProductBandQuery, PerPBQReturnType]):
+                       | dict[ProductBandQuery, PerPBQReturnType]):
         if mode == MVSelectOpts.EXTENT or all_time:
             # Not returning datasets - use main product only
             queries = [
@@ -235,17 +235,17 @@ def datasets(self, index: datacube.index.Index,
             return cast(Geometry | None, result)
         return OrderedDict(results)
 
-    def create_nodata_filled_flag_bands(self, data, pbq):
+    def create_nodata_filled_flag_bands(self, data: xarray.Dataset, pbq: ProductBandQuery) -> xarray.Dataset:
         var = None
         for var in data.data_vars.variables.keys():
             break
         if var is None:
             raise WMSException("Cannot add default flag data as there is no non-flag data available")
-        template = getattr(data, var)
+        template = cast(xarray.DataArray, getattr(data, cast(str, var)))
         data_new_bands = {}
         for band in pbq.bands:
             default_value = pbq.products[0].measurements[band].nodata
-            new_data = numpy.ndarray(template.shape, dtype="uint8")
+            new_data: numpy.ndarray = numpy.ndarray(template.shape, dtype="uint8")
             new_data.fill(default_value)
             qry_result = template.copy(data=new_data)
             data_new_bands[band] = qry_result
@@ -255,10 +255,12 @@ def create_nodata_filled_flag_bands(self, data, pbq):
         return data
 
     @log_call
-    def data(self, datasets_by_query, skip_corrections=False):
+    def data(self,
+             datasets_by_query: dict[ProductBandQuery, xarray.DataArray],
+             skip_corrections: bool = False) -> xarray.Dataset | None:
         # pylint: disable=too-many-locals, consider-using-enumerate
         # datasets is an XArray DataArray of datasets grouped by time.
-        data = None
+        data: xarray.Dataset | None = None
         for pbq, datasets in datasets_by_query.items():
             if data is not None and len(data.time) == 0:
                 # No data, so no need for masking data.
@@ -269,6 +271,8 @@ def data(self, datasets_by_query, skip_corrections=False):
                 qry_result = self.manual_data_stack(datasets, measurements, pbq.bands, skip_corrections, fuse_func=fuse_func)
             else:
                 qry_result = self.read_data(datasets, measurements, self._geobox, resampling=self._resampling, fuse_func=fuse_func)
+            if qry_result is None:
+                continue
             if data is None:
                 data = qry_result
                 continue
@@ -301,18 +305,24 @@ def data(self, datasets_by_query, skip_corrections=False):
                     # Time-aware mask product has no data, but main product does.
                    data = self.create_nodata_filled_flag_bands(data, pbq)
                    continue
+            assert data is not None
             qry_result.coords["time"] = data.coords["time"]
-            data = xarray.combine_by_coords([data, qry_result], join="exact")
+            data = cast(xarray.Dataset, xarray.combine_by_coords([data, qry_result], join="exact"))
         return data
 
     @log_call
-    def manual_data_stack(self, datasets, measurements, bands, skip_corrections, fuse_func):
+    def manual_data_stack(self,
+                          datasets: xarray.DataArray,
+                          measurements: Mapping[str, datacube.model.Measurement],
+                          bands: set[str],
+                          skip_corrections: bool,
+                          fuse_func: datacube.api.core.FuserFunction | None) -> xarray.Dataset | None:
         # pylint: disable=too-many-locals, too-many-branches
         # manual merge
         if self.style:
-            flag_bands = set(filter(lambda b: b in self.style.flag_bands, bands))
-            non_flag_bands = set(filter(lambda b: b not in self.style.flag_bands, bands))
+            flag_bands: Iterable[str] = set(filter(lambda b: b in self.style.flag_bands, bands))  # type: ignore[arg-type]
+            non_flag_bands: Iterable[str] = set(filter(lambda b: b not in self.style.flag_bands, bands))  # type: ignore[arg-type]
         else:
             non_flag_bands = bands
             flag_bands = set()
@@ -354,7 +364,13 @@ def manual_data_stack(self, datasets, measurements, bands, skip_corrections, fus
     # Read data for given datasets and measurements per the output_geobox
     # TODO: Make skip_broken passed in via config
     @log_call
-    def read_data(self, datasets, measurements, geobox, skip_broken = True, resampling="nearest", fuse_func=None):
+    def read_data(self,
+                  datasets: xarray.DataArray,
+                  measurements: Mapping[str, datacube.model.Measurement],
+                  geobox: GeoBox,
+                  skip_broken: bool = True,
+                  resampling: Resampling = "nearest",
+                  fuse_func: datacube.api.core.FuserFunction | None = None) -> xarray.Dataset:
         CredentialManager.check_cred()
         try:
             return datacube.Datacube.load_data(
@@ -368,10 +384,16 @@ def read_data(self, datasets, measurements, geobox, skip_broken = True, resampli
         except Exception as e:
             _LOG.error("Error (%s) in load_data: %s", e.__class__.__name__, str(e))
             raise
-    # Read data for single datasets and measurements per the output_geobox
+
     # TODO: Make skip_broken passed in via config
     @log_call
-    def read_data_for_single_dataset(self, dataset, measurements, geobox, skip_broken = True, resampling="nearest", fuse_func=None):
+    def read_data_for_single_dataset(self,
+                                     dataset: datacube.model.Dataset,
+                                     measurements: Mapping[str, datacube.model.Measurement],
+                                     geobox: GeoBox,
+                                     skip_broken: bool = True,
+                                     resampling: Resampling = "nearest",
+                                     fuse_func: datacube.api.core.FuserFunction | None = None) -> xarray.Dataset:
         datasets = [dataset]
         dc_datasets = datacube.Datacube.group_datasets(datasets, self._product.time_resolution.dataset_groupby())
         CredentialManager.check_cred()
diff --git a/datacube_ows/ows_configuration.py b/datacube_ows/ows_configuration.py
index fc786c1e..15bbd982 100644
--- a/datacube_ows/ows_configuration.py
+++ b/datacube_ows/ows_configuration.py
@@ -158,7 +158,7 @@ def locale_band(self, name_alias):
                 return b
         raise ConfigException(f"Unknown band: {name_alias} in layer {self.product.name}")
 
-    def band_label(self, name_alias):
+    def band_label(self, name_alias) -> str | None:
         canonical_name = self.band(name_alias)
         return self.read_local_metadata(canonical_name)
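
A note for reviewers on the two narrowing idioms this patch leans on. DataStacker.datasets() deliberately returns a union type whose actual member depends on its mode/all_time arguments, so call sites pin the expected member down with typing.cast(), while Optional results (get_config(), band_label()) are narrowed with an assert. The following sketch is illustrative only; Stacker and lookup are hypothetical stand-ins, not names from the datacube-ows codebase:

    from typing import cast


    class Stacker:
        # Like DataStacker.datasets(), the declared return type is a union
        # whose actual member is determined by the `mode` argument.
        def datasets(self, mode: str) -> int | dict[str, list[str]]:
            if mode == "count":
                return 3
            return {"main": ["ds1", "ds2", "ds3"]}


    def lookup(name: str) -> str | None:
        # Stand-in for an Optional-returning accessor such as band_label().
        return {"red": "Red band"}.get(name)


    stacker = Stacker()

    # cast() tells the type checker which union member this call site
    # receives; it is a no-op at runtime, so it should only be used where
    # the arguments already guarantee the type.
    ds_dict = cast(dict[str, list[str]], stacker.datasets(mode="datasets"))
    for query, dss in ds_dict.items():
        print(query, len(dss))

    # assert narrows str | None to str and, unlike cast(), fails fast at
    # runtime if the assumption is ever wrong.
    label = lookup("red")
    assert label is not None  # for type checker
    print(label.upper())

The trade-off is visible in the patch itself: asserts are used where a None would indicate a real bug (e.g. get_config() after startup), and casts where the mode/all_time arguments to datasets() already fix the return type.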