Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set print_tasks up to use the get_areas function #11

Merged
merged 2 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/print_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from dep_tools.namers import DepItemPath
from dep_tools.azure import get_container_client

from grid import grid
from run_task import MANGROVES_BASE_PRODUCT, MANGROVES_DATASET_ID
from run_task import MANGROVES_BASE_PRODUCT, MANGROVES_DATASET_ID, get_areas


def main(
Expand All @@ -30,8 +29,9 @@ def main(
elif len(years) > 2:
raise ValueError(f"{datetime} is not a valid value for --datetime")

areas = get_areas()
grid_subset = (
grid.loc[grid.code.isin(region_codes)] if region_codes is not None else grid
areas.loc[areas.code.isin(region_codes)] if region_codes is not None else areas
)

itempath = DepItemPath(MANGROVES_BASE_PRODUCT, dataset_id, version, datetime)
Expand Down
9 changes: 6 additions & 3 deletions src/run_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import typer
from typing_extensions import Annotated
from typing import Optional
from xarray import DataArray
from xrspatial.classify import reclassify
import xrspatial.multispectral as ms
Expand Down Expand Up @@ -45,16 +46,18 @@ def process(self, xr: DataArray) -> DataArray:
return set_stac_properties(xr, ds)


def get_areas(region_code: str, region_index: str) -> gpd.GeoDataFrame:
def get_areas(
region_code: Optional[str] = None, region_index: Optional[str] = None
) -> gpd.GeoDataFrame:
with fsspec.open(GRID_URL) as f:
grid = gpd.read_parquet(f)
areas = None

# None would be better for default but typer doesn't support it (str|None)
if region_code != "":
if region_code is not None or region_code != "":
areas = grid[grid.index.get_level_values("code").isin([region_code])]

if region_index != "":
if region_index is not None or region_index != "":
areas = grid[grid.index == (region_code, region_index)]

return areas
Expand Down