diff --git a/sky/backends/wheel_utils.py b/sky/backends/wheel_utils.py index 254094ad81a..44d47b52926 100644 --- a/sky/backends/wheel_utils.py +++ b/sky/backends/wheel_utils.py @@ -39,29 +39,30 @@ f'{version.parse(sky.__version__)}-*.whl') -def _get_latest_wheel_and_remove_all_others() -> pathlib.Path: - wheel_name = (f'**/{_WHEEL_PATTERN}') +def _remove_stale_wheels(latest_wheel_dir: pathlib.Path) -> None: + """Remove all wheels except the latest one.""" + for f in WHEEL_DIR.iterdir(): + if f != latest_wheel_dir: + if f.is_dir() and not f.is_symlink(): + shutil.rmtree(f, ignore_errors=True) + + +def _get_latest_wheel() -> pathlib.Path: + wheel_name = f'**/{_WHEEL_PATTERN}' try: latest_wheel = max(WHEEL_DIR.glob(wheel_name), key=os.path.getctime) except ValueError: raise FileNotFoundError( 'Could not find built SkyPilot wheels with glob pattern ' f'{wheel_name} under {WHEEL_DIR!r}') from None - - latest_wheel_dir_name = latest_wheel.parent - # Cleanup older wheels. - for f in WHEEL_DIR.iterdir(): - if f != latest_wheel_dir_name: - if f.is_dir() and not f.is_symlink(): - shutil.rmtree(f, ignore_errors=True) return latest_wheel -def _build_sky_wheel(): - """Build a wheel for SkyPilot.""" - with tempfile.TemporaryDirectory() as tmp_dir: +def _build_sky_wheel() -> pathlib.Path: + """Build a wheel for SkyPilot and return the path to the wheel.""" + with tempfile.TemporaryDirectory() as tmp_dir_str: # prepare files - tmp_dir = pathlib.Path(tmp_dir) + tmp_dir = pathlib.Path(tmp_dir_str) sky_tmp_dir = tmp_dir / 'sky' sky_tmp_dir.mkdir() for item in SKY_PACKAGE_PATH.iterdir(): @@ -129,6 +130,7 @@ def _build_sky_wheel(): wheel_dir = WHEEL_DIR / hash_of_latest_wheel wheel_dir.mkdir(parents=True, exist_ok=True) shutil.move(str(wheel_path), wheel_dir) + return wheel_dir / wheel_path.name def build_sky_wheel() -> Tuple[pathlib.Path, str]: @@ -161,13 +163,22 @@ def _get_latest_modification_time(path: pathlib.Path) -> float: last_modification_time = _get_latest_modification_time(SKY_PACKAGE_PATH) last_wheel_modification_time = _get_latest_modification_time(WHEEL_DIR) - # only build wheels if the wheel is outdated - if last_wheel_modification_time < last_modification_time: + # Only build wheels if the wheel is outdated or wheel does not exist + # for the requested version. + if (last_wheel_modification_time < last_modification_time) or not any( + WHEEL_DIR.glob(f'**/{_WHEEL_PATTERN}')): if not WHEEL_DIR.exists(): WHEEL_DIR.mkdir(parents=True, exist_ok=True) - _build_sky_wheel() - - latest_wheel = _get_latest_wheel_and_remove_all_others() + latest_wheel = _build_sky_wheel() + else: + latest_wheel = _get_latest_wheel() + + # We remove all wheels except the latest one for garbage collection. + # Otherwise stale wheels will accumulate over time. + # TODO(romilb): If the user switches versions every alternate launch, + # the wheel will be rebuilt every time. At the risk of adding + # complexity, we can consider TTL caching wheels by version here. + _remove_stale_wheels(latest_wheel.parent) wheel_hash = latest_wheel.parent.name