Skip to content

Commit

Permalink
[GCP] GCE DWS Support (#3574)
Browse files Browse the repository at this point in the history
* [GCP] initial take for dws support with migs

* fix lint errors

* dependency and format fix

* refactor mig instance creation

* fix

* remove unnecessary instance creation code for mig

* Fix deletion

* Fix instance template logic

* Restart

* format

* format

* move to REST APIs instead of python APIs

* add multi-node back

* Fix multi-node

* Avoid spot

* format

* format

* fix scheduling

* fix cancel

* Add smoke test

* revert some changes

* fix smoke

* Fix

* fix

* Fix smoke

* [GCP] Changing the config name for DWS support and fix for resize request cancellation (#5)

* Fix config fields

* fix cancel

* Add loggings

* remove useless codes

---------

Co-authored-by: Zhanghao Wu <[email protected]>
Co-authored-by: Zhanghao Wu <[email protected]>
  • Loading branch information
3 people authored Jun 11, 2024
1 parent 9a1aa5e commit e6ee397
Show file tree
Hide file tree
Showing 10 changed files with 662 additions and 71 deletions.
24 changes: 24 additions & 0 deletions docs/source/reference/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,30 @@ Available fields and semantics:
- projects/my-project/reservations/my-reservation2
# Managed instance group / DWS (optional).
#
# SkyPilot supports launching instances in a managed instance group (MIG)
# which schedules GPU instance creation through DWS, offering better
# availability. This feature is only applied when a resource request
# contains GPU instances.
managed_instance_group:
# Duration for a created instance to be kept alive (in seconds, required).
#
# This is required for the DWS to work properly. After the
# specified duration, the instance will be terminated.
run_duration: 3600
# Timeout for provisioning an instance by DWS (in seconds, optional).
#
# This timeout determines how long SkyPilot will wait for a managed
# instance group to create the requested resources before giving up,
# deleting the MIG and failing over to other locations. Larger timeouts
# may increase the chance of getting a resource, but will block failover
# to other zones/regions/clouds.
#
# Default: 900
provision_timeout: 900
# Identity to use for all GCP instances (optional).
#
# LOCAL_CREDENTIALS: The user's local credential files will be uploaded to
Expand Down
34 changes: 26 additions & 8 deletions sky/clouds/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sky import clouds
from sky import exceptions
from sky import sky_logging
from sky import skypilot_config
from sky.adaptors import gcp
from sky.clouds import service_catalog
from sky.clouds.utils import gcp_utils
Expand Down Expand Up @@ -179,20 +180,31 @@ class GCP(clouds.Cloud):
def _unsupported_features_for_resources(
cls, resources: 'resources.Resources'
) -> Dict[clouds.CloudImplementationFeatures, str]:
unsupported = {}
if gcp_utils.is_tpu_vm_pod(resources):
return {
unsupported = {
clouds.CloudImplementationFeatures.STOP: (
'TPU VM pods cannot be stopped. Please refer to: https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#stopping_your_resources'
'TPU VM pods cannot be stopped. Please refer to: '
'https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#stopping_your_resources'
)
}
if gcp_utils.is_tpu(resources) and not gcp_utils.is_tpu_vm(resources):
# TPU node does not support multi-node.
return {
clouds.CloudImplementationFeatures.MULTI_NODE:
('TPU node does not support multi-node. Please set '
'num_nodes to 1.')
}
return {}
unsupported[clouds.CloudImplementationFeatures.MULTI_NODE] = (
'TPU node does not support multi-node. Please set '
'num_nodes to 1.')
# TODO(zhwu): We probably need to store the MIG requirement in resources
# because `skypilot_config` may change for an existing cluster.
# Clusters created with MIG (only GPU clusters) cannot be stopped.
if (skypilot_config.get_nested(
('gcp', 'managed_instance_group'), None) is not None and
resources.accelerators):
unsupported[clouds.CloudImplementationFeatures.STOP] = (
'Managed Instance Group (MIG) does not support stopping yet.')
unsupported[clouds.CloudImplementationFeatures.SPOT_INSTANCE] = (
'Managed Instance Group with DWS does not support '
'spot instances.')
return unsupported

@classmethod
def max_cluster_name_length(cls) -> Optional[int]:
Expand Down Expand Up @@ -493,6 +505,12 @@ def make_deploy_resources_variables(

resources_vars['tpu_node_name'] = tpu_node_name

managed_instance_group_config = skypilot_config.get_nested(
('gcp', 'managed_instance_group'), None)
use_mig = managed_instance_group_config is not None
resources_vars['gcp_use_managed_instance_group'] = use_mig
if use_mig:
resources_vars.update(managed_instance_group_config)
return resources_vars

def _get_feasible_launchable_resources(
Expand Down
12 changes: 12 additions & 0 deletions sky/provision/gcp/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,15 @@
MAX_POLLS = 60 // POLL_INTERVAL
# Stopping instances can take several minutes, so we increase the timeout
MAX_POLLS_STOP = MAX_POLLS * 8

TAG_SKYPILOT_HEAD_NODE = 'skypilot-head-node'
# Tag uniquely identifying all nodes of a cluster
TAG_RAY_CLUSTER_NAME = 'ray-cluster-name'
TAG_RAY_NODE_KIND = 'ray-node-type'
TAG_SKYPILOT_CLUSTER_NAME = 'skypilot-cluster-name'

# MIG constants
MANAGED_INSTANCE_GROUP_CONFIG = 'managed-instance-group'
DEFAULT_MANAGED_INSTANCE_GROUP_PROVISION_TIMEOUT = 900 # 15 minutes
MIG_NAME_PREFIX = 'sky-mig-'
INSTANCE_TEMPLATE_NAME_PREFIX = 'sky-it-'
65 changes: 39 additions & 26 deletions sky/provision/gcp/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@

logger = sky_logging.init_logger(__name__)

TAG_SKYPILOT_HEAD_NODE = 'skypilot-head-node'
# Tag uniquely identifying all nodes of a cluster
TAG_RAY_CLUSTER_NAME = 'ray-cluster-name'
TAG_RAY_NODE_KIND = 'ray-node-type'

_INSTANCE_RESOURCE_NOT_FOUND_PATTERN = re.compile(
r'The resource \'projects/.*/zones/.*/instances/.*\' was not found')

Expand Down Expand Up @@ -66,7 +61,7 @@ def query_instances(
assert provider_config is not None, (cluster_name_on_cloud, provider_config)
zone = provider_config['availability_zone']
project_id = provider_config['project_id']
label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}

handler: Type[
instance_utils.GCPInstance] = instance_utils.GCPComputeInstance
Expand Down Expand Up @@ -124,15 +119,15 @@ def _wait_for_operations(
logger.debug(
f'wait_for_compute_{op_type}_operation: '
f'Waiting for operation {operation["name"]} to finish...')
handler.wait_for_operation(operation, project_id, zone)
handler.wait_for_operation(operation, project_id, zone=zone)


def _get_head_instance_id(instances: List) -> Optional[str]:
head_instance_id = None
for inst in instances:
labels = inst.get('labels', {})
if (labels.get(TAG_RAY_NODE_KIND) == 'head' or
labels.get(TAG_SKYPILOT_HEAD_NODE) == '1'):
if (labels.get(constants.TAG_RAY_NODE_KIND) == 'head' or
labels.get(constants.TAG_SKYPILOT_HEAD_NODE) == '1'):
head_instance_id = inst['name']
break
return head_instance_id
Expand All @@ -158,12 +153,14 @@ def _run_instances(region: str, cluster_name_on_cloud: str,
resource: Type[instance_utils.GCPInstance]
if node_type == instance_utils.GCPNodeType.COMPUTE:
resource = instance_utils.GCPComputeInstance
elif node_type == instance_utils.GCPNodeType.MIG:
resource = instance_utils.GCPManagedInstanceGroup
elif node_type == instance_utils.GCPNodeType.TPU:
resource = instance_utils.GCPTPUVMInstance
else:
raise ValueError(f'Unknown node type {node_type}')

filter_labels = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
filter_labels = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}

# wait until all stopping instances are stopped/terminated
while True:
Expand Down Expand Up @@ -264,12 +261,16 @@ def get_order_key(node):
if config.resume_stopped_nodes and to_start_count > 0 and stopped_instances:
resumed_instance_ids = [n['name'] for n in stopped_instances]
if resumed_instance_ids:
for instance_id in resumed_instance_ids:
resource.start_instance(instance_id, project_id,
availability_zone)
resource.set_labels(project_id, availability_zone, instance_id,
labels)
to_start_count -= len(resumed_instance_ids)
resumed_instance_ids = resource.start_instances(
cluster_name_on_cloud, project_id, availability_zone,
resumed_instance_ids, labels)
# In MIG case, the resumed_instance_ids will include the previously
# PENDING and RUNNING instances. To avoid double counting, we need to
# remove them from the resumed_instance_ids.
ready_instances = set(resumed_instance_ids)
ready_instances |= set([n['name'] for n in running_instances])
ready_instances |= set([n['name'] for n in pending_instances])
to_start_count = config.count - len(ready_instances)

if head_instance_id is None:
head_instance_id = resource.create_node_tag(
Expand All @@ -281,9 +282,14 @@ def get_order_key(node):

if to_start_count > 0:
errors, created_instance_ids = resource.create_instances(
cluster_name_on_cloud, project_id, availability_zone,
config.node_config, labels, to_start_count,
head_instance_id is None)
cluster_name_on_cloud,
project_id,
availability_zone,
config.node_config,
labels,
to_start_count,
total_count=config.count,
include_head_node=head_instance_id is None)
if errors:
error = common.ProvisionerError('Failed to launch instances.')
error.errors = errors
Expand Down Expand Up @@ -387,7 +393,7 @@ def get_cluster_info(
assert provider_config is not None, cluster_name_on_cloud
zone = provider_config['availability_zone']
project_id = provider_config['project_id']
label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}

handlers: List[Type[instance_utils.GCPInstance]] = [
instance_utils.GCPComputeInstance
Expand Down Expand Up @@ -415,7 +421,7 @@ def get_cluster_info(
project_id,
zone,
{
**label_filters, TAG_RAY_NODE_KIND: 'head'
**label_filters, constants.TAG_RAY_NODE_KIND: 'head'
},
lambda h: [h.RUNNING_STATE],
)
Expand All @@ -441,14 +447,14 @@ def stop_instances(
assert provider_config is not None, cluster_name_on_cloud
zone = provider_config['availability_zone']
project_id = provider_config['project_id']
label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}

tpu_node = provider_config.get('tpu_node')
if tpu_node is not None:
instance_utils.delete_tpu_node(project_id, zone, tpu_node)

if worker_only:
label_filters[TAG_RAY_NODE_KIND] = 'worker'
label_filters[constants.TAG_RAY_NODE_KIND] = 'worker'

handlers: List[Type[instance_utils.GCPInstance]] = [
instance_utils.GCPComputeInstance
Expand Down Expand Up @@ -510,9 +516,16 @@ def terminate_instances(
if tpu_node is not None:
instance_utils.delete_tpu_node(project_id, zone, tpu_node)

label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
use_mig = provider_config.get('use_managed_instance_group', False)
if use_mig:
# Deleting the MIG will also delete the instances.
instance_utils.GCPManagedInstanceGroup.delete_mig(
project_id, zone, cluster_name_on_cloud)
return

label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
if worker_only:
label_filters[TAG_RAY_NODE_KIND] = 'worker'
label_filters[constants.TAG_RAY_NODE_KIND] = 'worker'

handlers: List[Type[instance_utils.GCPInstance]] = [
instance_utils.GCPComputeInstance
Expand Down Expand Up @@ -555,7 +568,7 @@ def open_ports(
project_id = provider_config['project_id']
firewall_rule_name = provider_config['firewall_rule']

label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
handlers: List[Type[instance_utils.GCPInstance]] = [
instance_utils.GCPComputeInstance,
instance_utils.GCPTPUVMInstance,
Expand Down
Loading

0 comments on commit e6ee397

Please sign in to comment.