From 70ee8a108c14ca7edac0091dedfd48eae1065c3b Mon Sep 17 00:00:00 2001 From: Zongheng Yang Date: Sun, 10 Dec 2023 19:18:20 -0800 Subject: [PATCH] GCP: Fix subnet filtering. (#2854) * Fix GCP _list_subnets() filtering. * Format * updates --- sky/skylet/providers/gcp/config.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/sky/skylet/providers/gcp/config.py b/sky/skylet/providers/gcp/config.py index 781737c6cfc5..704cb34f1570 100644 --- a/sky/skylet/providers/gcp/config.py +++ b/sky/skylet/providers/gcp/config.py @@ -818,9 +818,7 @@ def get_usable_vpc_and_subnet( if len(vpcnets_all) == 1: # Skip checking any firewall rules if the user has specified a VPC. logger.info(f"Using user-specified VPC {specific_vpc_to_use!r}.") - subnets = _list_subnets( - config, compute, filter=f'(name="{specific_vpc_to_use}")' - ) + subnets = _list_subnets(config, compute, network=specific_vpc_to_use) if not subnets: _skypilot_log_error_and_exit_for_failover( f"No subnet for region {config['provider']['region']} found for specified VPC {specific_vpc_to_use!r}. " @@ -866,7 +864,7 @@ def get_usable_vpc_and_subnet( _create_rules(config, compute, FIREWALL_RULES_TEMPLATE, SKYPILOT_VPC_NAME, proj_id) usable_vpc_name = SKYPILOT_VPC_NAME - subnets = _list_subnets(config, compute, filter=f'(name="{usable_vpc_name}")') + subnets = _list_subnets(config, compute, network=usable_vpc_name) if not subnets: _skypilot_log_error_and_exit_for_failover( f"No subnet for region {config['provider']['region']} found for generated VPC {usable_vpc_name!r}. " @@ -994,19 +992,32 @@ def _list_vpcnets(config, compute, filter=None): def _list_subnets( - config, compute, filter=None + config, compute, network=None ) -> List["google.cloud.compute_v1.types.compute.Subnetwork"]: response = ( compute.subnetworks() .list( project=config["provider"]["project_id"], region=config["provider"]["region"], - filter=filter, ) .execute() ) - return response["items"] if "items" in response else [] + items = response["items"] if "items" in response else [] + if network is None: + return items + + # Filter by network (VPC) name. + # + # Note we do not directly use the filter (network=<...>) arg of the list() + # call above, because it'd involve constructing a long URL of the following + # format and passing it as the filter value: + # 'https://www.googleapis.com/compute/v1/projects//global/networks/' + matched_items = [] + for item in items: + if network == _network_interface_to_vpc_name(item): + matched_items.append(item) + return matched_items def _get_subnet(config, subnet_id, compute):