Skip to content

Commit

Permalink
First pass of typehint cleanup complete - no errors but many untyped …
Browse files Browse the repository at this point in the history
…functions.
  • Loading branch information
SpacemanPaul committed Apr 23, 2024
1 parent bfb517a commit a3c880c
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 35 deletions.
10 changes: 5 additions & 5 deletions datacube_ows/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

_LOG = logging.getLogger(__name__)

RAW_CFG = None | str | int | float | list["RAW_CFG"] | dict[str, "RAW_CFG"]
RAW_CFG = None | str | int | float | bool | list["RAW_CFG"] | dict[str, "RAW_CFG"]

CFG_DICT = dict[str, RAW_CFG]

Expand Down Expand Up @@ -79,7 +79,7 @@ def cfg_expand(cfg_unexpanded: RAW_CFG,
except Exception:
json_obj = None
if json_obj is None:
raise ConfigException("Could not find json file %s" % raw_path)
raise ConfigException(f"Could not find json file {raw_path}")
return cfg_expand(json_obj, cwd=cwd, inclusions=ninclusions)
elif cfg_unexpanded["type"] == "python":
# Python Expansion
Expand Down Expand Up @@ -666,9 +666,9 @@ def __init__(self, cfg: CFG_DICT, product_cfg: "datacube_ows.ows_configuration.O
self.pq_fuse_func: Optional[FunctionWrapper] = FunctionWrapper(self.product, cast(Mapping[str, Any], cfg["fuse_func"]))
else:
self.pq_fuse_func = None
self.pq_ignore_time = cfg.get("ignore_time", False)
self.pq_ignore_time = bool(cfg.get("ignore_time", False))
self.ignore_info_flags = cast(list[str], cfg.get("ignore_info_flags", []))
self.pq_manual_merge = cfg.get("manual_merge", False)
self.pq_manual_merge = bool(cfg.get("manual_merge", False))
self.declare_unready("pq_products")
self.declare_unready("flags_def")
self.declare_unready("info_mask")
Expand Down Expand Up @@ -750,7 +750,7 @@ def __init__(self, flag_band: FlagBand,
self.ignore_time = flag_band.pq_ignore_time
self.declare_unready("products")
self.declare_unready("low_res_products")
self.manual_merge = flag_band.pq_manual_merge
self.manual_merge = bool(flag_band.pq_manual_merge)
self.fuse_func = flag_band.pq_fuse_func
# pyre-ignore[16]
self.main_product = self.products_match(layer.product_names)
Expand Down
57 changes: 37 additions & 20 deletions datacube_ows/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

import datetime
import logging
from typing import Iterable
from typing import Iterable, cast, Mapping
from uuid import UUID

import datacube
import numpy
import xarray
from rasterio.enums import Resampling

from sqlalchemy.engine import Row

from odc.geo.geom import Geometry
from odc.geo.geobox import GeoBox
Expand All @@ -26,12 +28,12 @@
class ProductBandQuery:
def __init__(self,
products: list[datacube.model.Product],
bands: list[datacube.model.Measurement],
bands: Iterable[str],
main: bool = False, manual_merge: bool = False, ignore_time: bool = False,
fuse_func: datacube.api.core.FuserFunction | None = None
):
self.products = products
self.bands = bands
self.bands = set(bands)
self.manual_merge = manual_merge
self.fuse_func = fuse_func
self.ignore_time = ignore_time
Expand Down Expand Up @@ -66,7 +68,7 @@ def style_queries(cls, style: StyleDef, resource_limited: bool = False) -> list[
pq_products = fp.products
queries.append(cls(
pq_products,
tuple(fp.bands),
list(fp.bands),
manual_merge=fp.manual_merge,
ignore_time=fp.ignore_time,
fuse_func=fp.fuse_func
Expand All @@ -76,9 +78,9 @@ def style_queries(cls, style: StyleDef, resource_limited: bool = False) -> list[
@classmethod
def full_layer_queries(cls,
layer: OWSNamedLayer,
main_bands: list[datacube.model.Measurement] | None = None) -> list["ProductBandQuery"]:
main_bands: list[str] | None = None) -> list["ProductBandQuery"]:
if main_bands:
needed_bands = main_bands
needed_bands: Iterable[str] = main_bands
else:
needed_bands = set(layer.band_idx.band_cfg.keys())
queries = [
Expand All @@ -95,7 +97,7 @@ def full_layer_queries(cls,
pq_products = fpb.products
queries.append(cls(
pq_products,
tuple(fpb.bands),
list(fpb.bands),
manual_merge=fpb.manual_merge,
ignore_time=fpb.ignore_time,
fuse_func=fpb.fuse_func
Expand All @@ -104,7 +106,7 @@ def full_layer_queries(cls,

@classmethod
def simple_layer_query(cls, layer: OWSNamedLayer,
bands: list[datacube.model.Measurement],
bands: Iterable[str],
manual_merge: bool = False,
fuse_func: datacube.api.core.FuserFunction | None = None,
resource_limited: bool = False) -> "ProductBandQuery":
Expand All @@ -114,6 +116,7 @@ def simple_layer_query(cls, layer: OWSNamedLayer,
main_products = layer.products
return cls(main_products, bands, manual_merge=manual_merge, main=True, fuse_func=fuse_func)

PerPBQReturnType = xarray.DataArray | Iterable[UUID]

class DataStacker:
@log_call
Expand Down Expand Up @@ -159,15 +162,21 @@ def n_datasets(self,
index: datacube.index.Index,
all_time: bool = False,
point: Geometry | None = None) -> int:
return self.datasets(index,
return cast(int, self.datasets(index,
all_time=all_time, point=point,
mode=MVSelectOpts.COUNT)
mode=MVSelectOpts.COUNT))

def datasets(self, index: datacube.index.Index,
all_flag_bands: bool = False,
all_time: bool = False,
point: Geometry | None = None,
mode: MVSelectOpts = MVSelectOpts.DATASETS) -> int | Iterable[datacube.model.Dataset]:
mode: MVSelectOpts = MVSelectOpts.DATASETS) -> (int
| Iterable[Row]
| Iterable[UUID]
| xarray.DataArray
| Geometry
| None
| Mapping[ProductBandQuery, PerPBQReturnType]):
if mode == MVSelectOpts.EXTENT or all_time:
# Not returning datasets - use main product only
queries = [
Expand All @@ -194,7 +203,7 @@ def datasets(self, index: datacube.index.Index,
times = None
else:
times = self._times
results = []
results: list[tuple[ProductBandQuery, PerPBQReturnType]] = []
for query in queries:
if query.ignore_time:
qry_times = None
Expand All @@ -206,16 +215,24 @@ def datasets(self, index: datacube.index.Index,
geom=geom,
products=query.products)
if mode == MVSelectOpts.DATASETS:
result = datacube.Datacube.group_datasets(result, self.group_by)
grpd_result = datacube.Datacube.group_datasets(
cast(Iterable[datacube.model.Dataset], result),
self.group_by
)
if all_time:
return result
results.append((query, result))
return grpd_result
results.append((query, grpd_result))
elif mode == MVSelectOpts.IDS:
result_ids = cast(Iterable[UUID], result)
if all_time:
return result
results.append((query, result))
else:
return result
return result_ids
results.append((query, result_ids))
elif mode == MVSelectOpts.ALL:
return cast(Iterable[Row], result)
elif mode == MVSelectOpts.COUNT:
return cast(int, result)
else: # MVSelectOpts.EXTENT
return cast(Geometry | None, result)
return OrderedDict(results)

def create_nodata_filled_flag_bands(self, data, pbq):
Expand Down
22 changes: 17 additions & 5 deletions datacube_ows/mv_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import datetime
import json
from enum import Enum
from typing import Any, Iterable, Optional, Tuple, Union, cast
from types import UnionType
from typing import Any, Iterable, Type, TypeVar, cast
from uuid import UUID as UUID_

import pytz
from geoalchemy2 import Geometry
Expand All @@ -19,6 +21,8 @@

from datacube.index import Index
from datacube.model import Product, Dataset

from sqlalchemy.engine import Row
from sqlalchemy.engine.base import Engine
from sqlalchemy.sql.elements import ClauseElement

Expand Down Expand Up @@ -59,7 +63,6 @@ class MVSelectOpts(Enum):
COUNT = 2
EXTENT = 3
DATASETS = 4
INVALID = 9999

def sel(self, stv: Table) -> list[ClauseElement]:
if self == self.ALL:
Expand All @@ -73,16 +76,25 @@ def sel(self, stv: Table) -> list[ClauseElement]:
raise AssertionError("Invalid selection option")


selection_return_types: dict[MVSelectOpts, Type | UnionType] = {
MVSelectOpts.ALL: Iterable[Row],
MVSelectOpts.IDS: Iterable[UUID_],
MVSelectOpts.DATASETS: Iterable[Dataset],
MVSelectOpts.COUNT: int,
MVSelectOpts.EXTENT: ODCGeom | None,
}


SelectOut = Iterable[Row] | Iterable[UUID_] | Iterable[Dataset] | int | ODCGeom | None
DateOrDateTime = datetime.datetime | datetime.date
TimeSearchTerm = tuple[datetime.datetime, datetime.datetime] | tuple[datetime.date, datetime.date] | DateOrDateTime

MVSearchResult = Iterable[Iterable[Any]] | Iterable[str] | Iterable[Dataset] | int | None | ODCGeom

def mv_search(index: Index,
sel: MVSelectOpts = MVSelectOpts.IDS,
times: Iterable[TimeSearchTerm] | None = None,
geom: ODCGeom | None = None,
products: Iterable[Product] | None = None) -> MVSearchResult:
products: Iterable[Product] | None = None) -> SelectOut:
"""
Perform a dataset query via the space_time_view
Expand Down Expand Up @@ -147,7 +159,7 @@ def mv_search(index: Index,
elif sel in (MVSelectOpts.COUNT, MVSelectOpts.EXTENT):
for r in conn.execute(s):
if sel == MVSelectOpts.COUNT:
return r[0]
return cast(int, r[0])
else: # MVSelectOpts.EXTENT
geojson = r[0]
if geojson is None:
Expand Down
5 changes: 2 additions & 3 deletions integration_tests/test_mv_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,8 @@ def test_no_products():
def test_bad_set_opt():
cfg = get_config()
lyr = list(cfg.product_index.values())[0]
with cube() as dc:
with pytest.raises(AssertionError) as e:
sel = mv_search(dc.index, MVSelectOpts.INVALID, products=lyr.products)
with pytest.raises(ValueError) as e:
sel = MVSelectOpts("INVALID")


class MockGeobox:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,8 @@ def test_pbq_ctor_full(product_layer): # noqa: F811
assert "Query bands {" in str(pbqs[0])
assert "} from products [FakeODCProduct(test_odc_product)]" in str(pbqs[0])
assert str(pbqs[1]) in (
"Query bands ('wongle', 'pq') from products [FakeODCProduct(test_masking_product)]",
"Query bands ('pq', 'wongle') from products [FakeODCProduct(test_masking_product)]",
"Query bands {'wongle', 'pq'} from products [FakeODCProduct(test_masking_product)]",
"Query bands {'pq', 'wongle'} from products [FakeODCProduct(test_masking_product)]",
)


Expand Down

0 comments on commit a3c880c

Please sign in to comment.