From 03d64a5751b9a35fca994283a8fe38c05feb26b6 Mon Sep 17 00:00:00 2001 From: Alex Leith Date: Tue, 26 Sep 2023 11:06:53 +1000 Subject: [PATCH 1/2] Set print_tasks up to use the get_areas function --- src/print_tasks.py | 6 +++--- src/run_task.py | 7 ++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/print_tasks.py b/src/print_tasks.py index e0a72c9..9fa692c 100644 --- a/src/print_tasks.py +++ b/src/print_tasks.py @@ -8,8 +8,7 @@ from dep_tools.namers import DepItemPath from dep_tools.azure import get_container_client -from grid import grid -from run_task import MANGROVES_BASE_PRODUCT, MANGROVES_DATASET_ID +from run_task import MANGROVES_BASE_PRODUCT, MANGROVES_DATASET_ID, get_areas def main( @@ -30,8 +29,9 @@ def main( elif len(years) > 2: raise ValueError(f"{datetime} is not a valid value for --datetime") + areas = get_areas() grid_subset = ( - grid.loc[grid.code.isin(region_codes)] if region_codes is not None else grid + areas.loc[areas.code.isin(region_codes)] if region_codes is not None else areas ) itempath = DepItemPath(MANGROVES_BASE_PRODUCT, dataset_id, version, datetime) diff --git a/src/run_task.py b/src/run_task.py index 924d0c4..1194bf5 100755 --- a/src/run_task.py +++ b/src/run_task.py @@ -5,6 +5,7 @@ import numpy as np import typer from typing_extensions import Annotated +from typing import Optional from xarray import DataArray from xrspatial.classify import reclassify import xrspatial.multispectral as ms @@ -45,16 +46,16 @@ def process(self, xr: DataArray) -> DataArray: return set_stac_properties(xr, ds) -def get_areas(region_code: str, region_index: str) -> gpd.GeoDataFrame: +def get_areas(region_code: Optional[str] = None, region_index: Optional[str] = None) -> gpd.GeoDataFrame: with fsspec.open(GRID_URL) as f: grid = gpd.read_parquet(f) areas = None # None would be better for default but typer doesn't support it (str|None) - if region_code != "": + if region_code is not None or region_code != "": areas = grid[grid.index.get_level_values("code").isin([region_code])] - if region_index != "": + if region_index is not None or region_index != "": areas = grid[grid.index == (region_code, region_index)] return areas From b1cd84f96b7a7661ba0fbdc83a7cdce95d2d17a0 Mon Sep 17 00:00:00 2001 From: Alex Leith Date: Tue, 26 Sep 2023 11:13:45 +1000 Subject: [PATCH 2/2] Fix formatting --- src/run_task.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/run_task.py b/src/run_task.py index 1194bf5..d53303a 100755 --- a/src/run_task.py +++ b/src/run_task.py @@ -46,7 +46,9 @@ def process(self, xr: DataArray) -> DataArray: return set_stac_properties(xr, ds) -def get_areas(region_code: Optional[str] = None, region_index: Optional[str] = None) -> gpd.GeoDataFrame: +def get_areas( + region_code: Optional[str] = None, region_index: Optional[str] = None +) -> gpd.GeoDataFrame: with fsspec.open(GRID_URL) as f: grid = gpd.read_parquet(f) areas = None