diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_gcp.py b/sky/clouds/service_catalog/data_fetchers/fetch_gcp.py index 3a4b177fa2a..e95e1c83a6e 100644 --- a/sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +++ b/sky/clouds/service_catalog/data_fetchers/fetch_gcp.py @@ -10,7 +10,7 @@ import multiprocessing import os import textwrap -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Set import google.auth from googleapiclient import discovery @@ -83,8 +83,8 @@ tpu_client = discovery.build('tpu', 'v1') SINGLE_THREADED = False -ZONES: List[str] = [] -EXCLUDED_REGIONS: List[str] = [] +ZONES: Set[str] = set() +EXCLUDED_REGIONS: Set[str] = set() def get_skus(service_id: str) -> List[Dict[str, Any]]: @@ -148,18 +148,17 @@ def filter_zones(func: Callable[[], List[str]]) -> Callable[[], List[str]]: removes any zones present in the global EXCLUDED_REGIONS (if defined). """ - def wrapper(*arguments, - **keyword_args) -> List[str]: # Renamed args to arguments + def wrapper(*arguments, **kwargs) -> List[str]: # Renamed args to arguments # Get the original zones from the decorated function - original_zones = set(func(*arguments, **keyword_args)) + original_zones = set(func(*arguments, **kwargs)) # Intersect with ZONES if defined if ZONES: - original_zones &= set(ZONES) + original_zones &= ZONES # Remove zones from EXCLUDED_REGIONS if defined if EXCLUDED_REGIONS: - original_zones -= set(EXCLUDED_REGIONS) + original_zones -= EXCLUDED_REGIONS return list(original_zones) @@ -533,10 +532,10 @@ def get_catalog_df(region_prefix: str) -> pd.DataFrame: args = parser.parse_args() SINGLE_THREADED = args.single_threaded - ZONES = args.zones - EXCLUDED_REGIONS = args.exclude + ZONES = set(args.zones) + EXCLUDED_REGIONS = set(args.exclude) - region_prefix_filter = '' if args.all_regions else 'us-' + region_prefix_filter = '' if args.zones or args.all_regions else 'us-' catalog_df = get_catalog_df(region_prefix_filter) os.makedirs('gcp', exist_ok=True)