Skip to content

Commit

Permalink
Support GCP Shared VPC for some subnets (#1933)
Browse files Browse the repository at this point in the history
  • Loading branch information
r4victor authored Oct 31, 2024
1 parent 404c77d commit 0723df6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
18 changes: 13 additions & 5 deletions src/dstack/_internal/core/backends/gcp/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_availability_zones(


def check_vpc(
network_client: compute_v1.NetworksClient,
subnetworks_client: compute_v1.SubnetworksClient,
routers_client: compute_v1.RoutersClient,
project_id: str,
regions: List[str],
Expand All @@ -71,14 +71,21 @@ def check_vpc(
if shared_vpc_project_id:
vpc_project_id = shared_vpc_project_id
try:
network_client.get(project=vpc_project_id, network=vpc_name)
for region in regions:
get_vpc_subnet_or_error(
subnetworks_client=subnetworks_client,
vpc_project_id=vpc_project_id,
vpc_name=vpc_name,
region=region,
)
except google.api_core.exceptions.NotFound:
raise ComputeError(f"Failed to find VPC {vpc_name} in project {vpc_project_id}")
raise ComputeError(f"Failed to find Shared VPC project {vpc_project_id}")

if allocate_public_ip:
return

if nat_check:
# We may have no permissions to check NAT in a shared VPC
if nat_check and shared_vpc_project_id is None:
regions_without_nat = []
for region in regions:
if not has_vpc_nat_access(routers_client, vpc_project_id, vpc_name, region):
Expand Down Expand Up @@ -230,7 +237,8 @@ def get_vpc_subnet_or_error(
if network_name == vpc_name and subnet_region == region:
return subnet_resource_name
raise ComputeError(
f"No usable subnetwork found in region {region} for VPC {vpc_name} in project {vpc_project_id}"
f"No usable subnetwork found in region {region} for VPC {vpc_name} in project {vpc_project_id}."
f" Ensure that VPC {vpc_name} exists and has usable subnetworks."
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,11 @@ def _get_regions_element(
return element

def _check_config(self, config: GCPConfigInfoWithCredsPartial, credentials: Credentials):
network_client = compute_v1.NetworksClient(credentials=credentials)
subnetworks_client = compute_v1.SubnetworksClient(credentials=credentials)
routers_client = compute_v1.RoutersClient(credentials=credentials)
self._check_tags_config(config)
self._check_vpc_config(
network_client=network_client,
subnetworks_client=subnetworks_client,
routers_client=routers_client,
config=config,
)
Expand All @@ -245,14 +245,14 @@ def _check_tags_config(self, config: GCPConfigInfoWithCredsPartial):
def _check_vpc_config(
self,
config: GCPConfigInfoWithCredsPartial,
network_client: compute_v1.NetworksClient,
subnetworks_client: compute_v1.SubnetworksClient,
routers_client: compute_v1.RoutersClient,
):
allocate_public_ip = config.public_ips if config.public_ips is not None else True
nat_check = config.nat_check if config.nat_check is not None else True
try:
resources.check_vpc(
network_client=network_client,
subnetworks_client=subnetworks_client,
routers_client=routers_client,
project_id=config.project_id,
regions=config.regions or DEFAULT_REGIONS,
Expand Down

0 comments on commit 0723df6

Please sign in to comment.