diff --git a/sky/optimizer.py b/sky/optimizer.py index 08efdd670a37..d295985f0d0f 100644 --- a/sky/optimizer.py +++ b/sky/optimizer.py @@ -19,6 +19,7 @@ from sky.adaptors import common as adaptors_common from sky.utils import env_options from sky.utils import log_utils +from sky.utils import subprocess_utils from sky.utils import ux_utils if typing.TYPE_CHECKING: @@ -279,7 +280,20 @@ def _estimate_nodes_cost_or_time( list(node.resources)[0]: list(node.resources) } + # Fetch reservations in advance and in parallel to speed up the + # reservation info fetching. num_resources = len(list(node.resources)) + num_available_reserved_nodes_per_resource = {} + + def get_reservations_available_resources( + resources: resources_lib.Resources): + num_available_reserved_nodes_per_resource[resources] = sum( + resources.get_reservations_available_resources().values()) + + launchable_resource_list: List[resources_lib.Resources] = sum( + launchable_resources.values(), []) + subprocess_utils.run_in_parallel( + get_reservations_available_resources, launchable_resource_list) for orig_resources, launchable_list in launchable_resources.items(): if num_resources == 1 and node.time_estimator_func is None: @@ -302,15 +316,16 @@ def _estimate_nodes_cost_or_time( else: estimated_runtime = node.estimate_runtime( orig_resources) + for resources in launchable_list: if do_print: logger.debug(f'resources: {resources}') if minimize_cost: cost_per_node = resources.get_cost(estimated_runtime) - num_available_reserved_nodes = sum( - resources.get_reservations_available_resources( - ).values()) + num_available_reserved_nodes = ( + num_available_reserved_nodes_per_resource[resources] + ) # We consider the cost of the unused reservation # resources to be 0 since we are already paying for