Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] make per-cloud catalog lookup parallel #4483

Merged
merged 12 commits into from
Jan 13, 2025
9 changes: 6 additions & 3 deletions sky/clouds/service_catalog/__init__.py
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@
from sky.clouds.service_catalog.constants import CATALOG_SCHEMA_VERSION
from sky.clouds.service_catalog.constants import HOSTED_CATALOG_DIR_URL
from sky.utils import resources_utils
from sky.utils import subprocess_utils

if typing.TYPE_CHECKING:
from sky.clouds import cloud
@@ -31,8 +32,7 @@ def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs):
if single:
clouds = [clouds] # type: ignore

results = []
for cloud in clouds:
def _execute_catalog_method(cloud: str):
try:
cloud_module = importlib.import_module(
f'sky.clouds.service_catalog.{cloud.lower()}_catalog')
@@ -46,7 +46,10 @@ def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs):
raise AttributeError(
f'Module "{cloud}_catalog" does not '
f'implement the "{method_name}" method') from None
results.append(method(*args, **kwargs))
return method(*args, **kwargs)

results = subprocess_utils.maybe_parallelize_cloud_operation(
_execute_catalog_method, clouds) # type: ignore
if single:
return results[0]
return results
2 changes: 1 addition & 1 deletion sky/clouds/service_catalog/aws_catalog.py
Original file line number Diff line number Diff line change
@@ -101,7 +101,6 @@ def _get_az_mappings(aws_user_hash: str) -> Optional['pd.DataFrame']:
return az_mappings


@timeline.event
def _fetch_and_apply_az_mapping(df: common.LazyDataFrame) -> 'pd.DataFrame':
"""Maps zone IDs (use1-az1) to zone names (us-east-1x).
@@ -292,6 +291,7 @@ def get_region_zones_for_instance_type(instance_type: str,
return us_region_list + other_region_list


@timeline.event
def list_accelerators(
gpus_only: bool,
name_filter: Optional[str],
21 changes: 17 additions & 4 deletions sky/optimizer.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@
import enum
import json
import typing
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple

import colorama
import numpy as np
@@ -1254,6 +1254,18 @@ def _check_specified_clouds(dag: 'dag_lib.Dag') -> None:
f'{colorama.Fore.YELLOW}{msg}{colorama.Style.RESET_ALL}')


def _make_resource_finder(
    resources: resources_lib.Resources,
    num_nodes: int,
) -> Callable[[clouds.Cloud], Tuple[clouds.Cloud, resources_lib.Resources]]:
    """Build a one-argument callable that queries a single cloud.

    The returned callable closes over `resources` and `num_nodes`, so it can
    be handed to a helper that only supplies the cloud.

    Args:
        resources: The resources request forwarded to every cloud.
        num_nodes: The node count forwarded to every cloud.

    Returns:
        A callable that takes a cloud and returns a `(cloud, result)` pair,
        where `result` is whatever
        `cloud.get_feasible_launchable_resources(resources, num_nodes)`
        returns. The cloud is echoed back so each result stays paired with
        the cloud that produced it.
    """

    def _find_for_cloud(
        cloud: clouds.Cloud
    ) -> Tuple[clouds.Cloud, resources_lib.Resources]:
        feasible = cloud.get_feasible_launchable_resources(
            resources, num_nodes)
        return cloud, feasible

    return _find_for_cloud


def _fill_in_launchable_resources(
task: task_lib.Task,
blocked_resources: Optional[Iterable[resources_lib.Resources]],
@@ -1293,9 +1305,10 @@ def _fill_in_launchable_resources(
if resources.cloud is not None else enabled_clouds)
# If clouds provide hints, store them for later printing.
hints: Dict[clouds.Cloud, str] = {}
for cloud in clouds_list:
feasible_resources = cloud.get_feasible_launchable_resources(
resources, num_nodes=task.num_nodes)

feasible_list = subprocess_utils.maybe_parallelize_cloud_operation(
_make_resource_finder(resources, task.num_nodes), clouds_list)
for cloud, feasible_resources in feasible_list:
if feasible_resources.hint is not None:
hints[cloud] = feasible_resources.hint
if len(feasible_resources.resources_list) > 0:
20 changes: 20 additions & 0 deletions sky/utils/subprocess_utils.py
Original file line number Diff line number Diff line change
@@ -293,3 +293,23 @@ def kill_process_daemon(process_pid: int) -> None:
# Disable input
stdin=subprocess.DEVNULL,
)


def maybe_parallelize_cloud_operation(
        func: Callable,
        clouds: List[Any],
        num_threads: Optional[int] = None) -> List[Any]:
    """Apply a function to each cloud, in parallel if there are several.

    Args:
        func: Callable invoked with one entry of `clouds` as its only
            argument.
        clouds: The clouds to operate on.
        num_threads: Upper bound on the number of parallel workers. When
            None, one worker per cloud is used. Values larger than the
            number of clouds are capped, since the extra workers could
            never receive any work.

    Returns:
        An empty list when `clouds` is empty, a one-element list holding
        `func(clouds[0])` when there is exactly one cloud, and otherwise
        whatever `run_in_parallel` returns for `func` over `clouds`.
    """
    count = len(clouds)
    if count == 0:
        return []
    # Short-circuit the single-cloud setup: no pool is needed, and `func`
    # runs directly in the calling thread.
    if count == 1:
        return [func(clouds[0])]
    # Cloud operations are assumed to be IO-bound, so by default the
    # parallelism equals the number of clouds. This stays safe even if the
    # assumption does not hold, because the number of clouds is small and
    # bounded.
    processes = count if num_threads is None else min(num_threads, count)
    return run_in_parallel(func, clouds, processes)