diff --git a/sky/clouds/service_catalog/common.py b/sky/clouds/service_catalog/common.py
index bcc579bb58c..fbbe0fdcef1 100644
--- a/sky/clouds/service_catalog/common.py
+++ b/sky/clouds/service_catalog/common.py
@@ -5,7 +5,7 @@
 import os
 import time
 import typing
-from typing import Dict, List, NamedTuple, Optional, Tuple
+from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
 
 import filelock
 import requests
@@ -118,19 +118,21 @@ def _get_modified_catalogs() -> List[str]:
 
 
 class LazyDataFrame:
-    """A lazy data frame that reads the catalog on demand.
+    """A lazy data frame that updates and reads the catalog on demand.
 
     We don't need to load the catalog for every SkyPilot call, and this class
     allows us to load the catalog only when needed.
     """
 
-    def __init__(self, filename: str):
+    def __init__(self, filename: str, update_func: Callable[[], None]):
         self._filename = filename
         self._df: Optional['pd.DataFrame'] = None
+        self._update_func = update_func
 
     def _load_df(self) -> 'pd.DataFrame':
         if self._df is None:
             try:
+                self._update_func()
                 self._df = pd.read_csv(self._filename)
             except Exception as e:  # pylint: disable=broad-except
                 # As users can manually modify the catalog, read_csv can fail.
@@ -172,58 +174,60 @@ def read_catalog(filename: str,
     meta_path = os.path.join(_ABSOLUTE_VERSIONED_CATALOG_DIR, '.meta', filename)
     os.makedirs(os.path.dirname(meta_path), exist_ok=True)
 
-    # Atomic check, to avoid conflicts with other processes.
-    # TODO(mraheja): remove pylint disabling when filelock version updated
-    # pylint: disable=abstract-class-instantiated
-    with filelock.FileLock(meta_path + '.lock'):
-
-        def _need_update() -> bool:
-            if not os.path.exists(catalog_path):
-                return True
-            if pull_frequency_hours is None:
-                return False
-            if is_catalog_modified(filename):
-                # If the catalog is modified by a user manually, we should
-                # avoid overwriting the catalog by fetching from GitHub.
-                return False
-
-            last_update = os.path.getmtime(catalog_path)
-            return last_update + pull_frequency_hours * 3600 < time.time()
-
-        if _need_update():
-            url = f'{constants.HOSTED_CATALOG_DIR_URL}/{constants.CATALOG_SCHEMA_VERSION}/{filename}'  # pylint: disable=line-too-long
-            update_frequency_str = ''
-            if pull_frequency_hours is not None:
-                update_frequency_str = f' (every {pull_frequency_hours} hours)'
-            with rich_utils.safe_status((f'Updating {cloud} catalog: '
-                                         f'{filename}'
-                                         f'{update_frequency_str}')):
-                try:
-                    r = requests.get(url)
-                    r.raise_for_status()
-                except requests.exceptions.RequestException as e:
-                    error_str = (f'Failed to fetch {cloud} catalog '
-                                 f'{filename}. ')
-                    if os.path.exists(catalog_path):
-                        logger.warning(
-                            f'{error_str}Using cached catalog files.')
-                        # Update catalog file modification time.
-                        os.utime(catalog_path, None)  # Sets to current time
+    def _need_update() -> bool:
+        if not os.path.exists(catalog_path):
+            return True
+        if pull_frequency_hours is None:
+            return False
+        if is_catalog_modified(filename):
+            # If the catalog is modified by a user manually, we should
+            # avoid overwriting the catalog by fetching from GitHub.
+            return False
+
+        last_update = os.path.getmtime(catalog_path)
+        return last_update + pull_frequency_hours * 3600 < time.time()
+
+    def _update_catalog():
+        # Atomic check, to avoid conflicts with other processes.
+        with filelock.FileLock(meta_path + '.lock'):
+            if _need_update():
+                url = f'{constants.HOSTED_CATALOG_DIR_URL}/{constants.CATALOG_SCHEMA_VERSION}/{filename}'  # pylint: disable=line-too-long
+                update_frequency_str = ''
+                if pull_frequency_hours is not None:
+                    update_frequency_str = (
+                        f' (every {pull_frequency_hours} hours)')
+                with rich_utils.safe_status((f'Updating {cloud} catalog: '
+                                             f'{filename}'
+                                             f'{update_frequency_str}')):
+                    try:
+                        r = requests.get(url)
+                        r.raise_for_status()
+                    except requests.exceptions.RequestException as e:
+                        error_str = (f'Failed to fetch {cloud} catalog '
+                                     f'{filename}. ')
+                        if os.path.exists(catalog_path):
+                            logger.warning(
+                                f'{error_str}Using cached catalog files.')
+                            # Update catalog file modification time.
+                            os.utime(catalog_path, None)  # Sets to current time
+                        else:
+                            logger.error(
+                                f'{error_str}Please check your internet '
+                                'connection.')
+                            with ux_utils.print_exception_no_traceback():
+                                raise e
                     else:
-                        logger.error(
-                            f'{error_str}Please check your internet connection.'
-                        )
-                        with ux_utils.print_exception_no_traceback():
-                            raise e
-                else:
-                    # Download successful, save the catalog to a local file.
-                    os.makedirs(os.path.dirname(catalog_path), exist_ok=True)
-                    with open(catalog_path, 'w', encoding='utf-8') as f:
-                        f.write(r.text)
-                    with open(meta_path + '.md5', 'w', encoding='utf-8') as f:
-                        f.write(hashlib.md5(r.text.encode()).hexdigest())
-
-    return LazyDataFrame(catalog_path)
+                        # Download successful, save the catalog to a local file.
+                        os.makedirs(os.path.dirname(catalog_path),
+                                    exist_ok=True)
+                        with open(catalog_path, 'w', encoding='utf-8') as f:
+                            f.write(r.text)
+                        with open(meta_path + '.md5', 'w',
+                                  encoding='utf-8') as f:
+                            f.write(hashlib.md5(r.text.encode()).hexdigest())
+                        logger.info(f'Updated {cloud} catalog.')
+
+    return LazyDataFrame(catalog_path, update_func=_update_catalog)
 
 
 def _get_instance_type(
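For readers skimming the diff, the following is a minimal, hypothetical sketch (not part of the patch) of the lazy-update pattern this change introduces: the update callback runs only on the first real read of the catalog, so constructing the frame touches neither the network nor the file. The class name _LazyFrame, the no-op callback, and the example path below are illustrative stand-ins for LazyDataFrame, _update_catalog, and the real catalog files.

from typing import Callable, Optional

import pandas as pd


class _LazyFrame:
    """Illustrative stand-in for LazyDataFrame: update + read on first use."""

    def __init__(self, filename: str, update_func: Callable[[], None]):
        self._filename = filename
        self._update_func = update_func
        self._df: Optional[pd.DataFrame] = None

    def _load(self) -> pd.DataFrame:
        if self._df is None:
            # Refresh the on-disk catalog only when it is actually needed.
            self._update_func()
            self._df = pd.read_csv(self._filename)
        return self._df

    def __getattr__(self, name: str):
        # Delegate attribute access to the underlying DataFrame, loading it
        # lazily on first access.
        return getattr(self._load(), name)


def _noop_update() -> None:
    # Placeholder: a real update callback would download the catalog under a
    # file lock when it is missing or stale, as _update_catalog does above.
    pass


# Construction is free of side effects; the first attribute access
# (e.g. df.columns) triggers the update callback and the CSV read.
df = _LazyFrame('/tmp/example_catalog.csv', update_func=_noop_update)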