From ecca47ea78227929c58fa8090593f34d6a0c0792 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Fri, 20 Oct 2023 15:15:14 -0700 Subject: [PATCH] fix --- sky/provision/gcp/instance.py | 9 ++++--- sky/provision/gcp/instance_utils.py | 39 ++++++++++++++++++++++++++--- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index d409d55a1037..06457d11128f 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -295,9 +295,10 @@ def get_cluster_info( label_filters, lambda _: ['RUNNING'], ) - all_instances = [ - i for instances in handler_to_instances.values() for i in instances - ] + instances = {} + for res, insts in handler_to_instances.items(): + for inst in insts: + instances[inst] = res.get_instance_info(project_id, zone, inst) head_instances = _filter_instances( handlers, @@ -315,7 +316,7 @@ def get_cluster_info( break return common.ClusterInfo( - instances=all_instances, + instances=instances, head_instance_id=head_instance_id, ) diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index 9f2d186d0cc3..407597ec01d0 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -9,6 +9,7 @@ from sky import sky_logging from sky.adaptors import gcp +from sky.provision import common from sky.provision.gcp.constants import MAX_POLLS from sky.provision.gcp.constants import POLL_INTERVAL from sky.utils import ux_utils @@ -269,6 +270,15 @@ def create_node_tag(cls, wait_for_operation: bool = True) -> str: raise NotImplementedError + @classmethod + def get_instance_info( + cls, + project_id: str, + availability_zone: str, + instance_id: str, + wait_for_operation: bool = True) -> common.InstanceInfo: + raise NotImplementedError + class GCPComputeInstance(GCPInstance): """Instance handler for GCP compute instances.""" @@ -525,12 +535,11 @@ def set_labels(cls, node_id: str, labels: dict, wait_for_operation: bool = True) -> dict: - response = (cls.load_resource().instances().get( + node = cls.load_resource().instances().get( project=project_id, instance=node_id, zone=availability_zone, - ).execute()) - node = response.get('items', [])[0] + ).execute() body = { "labels": dict(node["labels"], **labels), "labelFingerprint": node["labelFingerprint"], @@ -766,6 +775,30 @@ def start_instance(cls, return result + @classmethod + def get_instance_info( + cls, + project_id: str, + availability_zone: str, + instance_id: str, + wait_for_operation: bool = True) -> common.InstanceInfo: + result = cls.load_resource().instances().get( + project=project_id, + zone=availability_zone, + instance=instance_id, + ).execute() + external_ip = (result.get("networkInterfaces", + [{}])[0].get("accessConfigs", + [{}])[0].get("natIP", None)) + internal_ip = result.get("networkInterfaces", [{}])[0].get("networkIP") + + return common.InstanceInfo( + instance_id=instance_id, + internal_ip=internal_ip, + external_ip=external_ip, + tags=result.get('labels', {}), + ) + class GCPTPUVMInstance(GCPInstance): """Instance handler for GCP TPU node."""