Skip to content

Commit

Permalink
merge get_vpc_name and create_or_update_firewall_rule
Browse files Browse the repository at this point in the history
  • Loading branch information
cblmemo committed Sep 16, 2023
1 parent c546115 commit a23e520
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 41 deletions.
14 changes: 2 additions & 12 deletions sky/provision/gcp/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,24 +196,14 @@ def open_ports(
logger.warning(f'No instance found for cluster '
f'{cluster_name_on_cloud}.')
continue
# If multiple instances are found, they are in the same cluster,
# i.e. the same VPC. So we can just pick one.
try:
vpc_name = handler.get_vpc_name(project_id, zone, instances[0])
except gcp.http_error_exception() as e:
if _INSTANCE_RESOURCE_NOT_FOUND_PATTERN.search(e.reason) is None:
logger.warning(f'Failed to get VPC name for instance '
f'{instances[0]}. Skip opening ports for it.')
else:
logger.warning(f'Instance {instances[0]} does not exist. '
f'Skip opening ports for it.')
else:
op = handler.create_or_update_firewall_rule(
firewall_rule_name,
project_id,
zone,
instances,
cluster_name_on_cloud,
ports,
vpc_name,
)
# op is None if any error occurs.
if op is not None:
Expand Down
50 changes: 21 additions & 29 deletions sky/provision/gcp/instance_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,27 +79,19 @@ def filter(
def delete_firewall_rule(
cls,
project_id: str,
cluster_name_on_cloud: str,
firewall_rule_name: str,
) -> None:
raise NotImplementedError

@classmethod
def get_vpc_name(
cls,
project_id: str,
zone: str,
instance: str,
) -> str:
raise NotImplementedError

@classmethod
def create_or_update_firewall_rule(
cls,
firewall_rule_name: str,
project_id: str,
zone: str,
instances: List[str],
cluster_name_on_cloud: str,
ports: List[str],
vpc_name: str,
) -> Optional[dict]:
raise NotImplementedError

Expand Down Expand Up @@ -241,32 +233,32 @@ def delete_firewall_rule(
firewall=firewall_rule_name,
).execute()

@classmethod
def get_vpc_name(
cls,
project_id: str,
zone: str,
instance: str,
) -> str:
# Any errors will be handled in the caller function.
instance = cls.load_resource().instances().get(
project=project_id,
zone=zone,
instance=instance,
).execute()
# Format: projects/PROJECT_ID/global/networks/VPC_NAME
vpc_link = instance['networkInterfaces'][0]['network']
return vpc_link.split('/')[-1]

@classmethod
def create_or_update_firewall_rule(
cls,
firewall_rule_name: str,
project_id: str,
zone: str,
instances: List[str],
cluster_name_on_cloud: str,
ports: List[str],
vpc_name: str,
) -> Optional[dict]:
try:
# If we have multiple instances, they are in the same cluster,
# i.e. the same VPC. So we can just pick one.
response = cls.load_resource().instances().get(
project=project_id,
zone=zone,
instance=instances[0],
).execute()
# Format: projects/PROJECT_ID/global/networks/VPC_NAME
vpc_link = response['networkInterfaces'][0]['network']
vpc_name = vpc_link.split('/')[-1]
except gcp.http_error_exception() as e:
logger.warning(
f'Failed to get VPC name for instance {instances[0]}: '
f'{e.reason}. Skip opening ports for it.')
return None
try:
body = cls.load_resource().firewalls().get(
project=project_id, firewall=firewall_rule_name).execute()
Expand Down

0 comments on commit a23e520

Please sign in to comment.