Skip to content

Commit

Permalink
Extract project ID from GCP service connector SA credentials (#2708)
Browse files Browse the repository at this point in the history
* Detect project ID mismatch in GCP service connector SA credentials

* Detect project ID from service account creds

* Fix docstrings

---------

Co-authored-by: Safoine El Khabich <[email protected]>
Co-authored-by: Hamza Tahir <[email protected]>
  • Loading branch information
3 people authored Jun 26, 2024
1 parent 5cf196d commit d9c2a5e
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/zenml/integrations/gcp/google_credentials_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _get_authentication(self) -> Tuple["Credentials", str]:
"trying to use the linked connector, but got "
f"{type(credentials)}."
)
return credentials, connector.config.project_id
return credentials, connector.config.gcp_project_id

if self.config.service_account_path:
credentials, project_id = load_credentials_from_file(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,24 +351,72 @@ class GCPOAuth2Token(AuthenticationConfig):
class GCPBaseConfig(AuthenticationConfig):
    """GCP base configuration.

    Common base for the GCP authentication configuration classes. It exposes
    the GCP project ID through the read-only `gcp_project_id` property, which
    subclasses are expected to override.
    """

    @property
    def gcp_project_id(self) -> str:
        """Get the GCP project ID.

        This method must be implemented by subclasses to ensure that the GCP
        project ID is always available.

        Raises:
            NotImplementedError: If the method is not implemented.
        """
        # The base class has no project ID of its own, so reading the
        # property on it directly always fails.
        raise NotImplementedError


class GCPBaseProjectIDConfig(GCPBaseConfig):
    """GCP base configuration with included project ID.

    Base for the configurations in which the project ID is supplied
    explicitly through the `project_id` field.
    """

    # Explicitly configured project ID; no default value is declared here.
    project_id: str = Field(
        title="GCP Project ID where the target resource is located.",
    )

    @property
    def gcp_project_id(self) -> str:
        """Get the GCP project ID.

        Returns:
            The GCP project ID, taken from the `project_id` field.
        """
        return self.project_id

class GCPUserAccountConfig(GCPBaseConfig, GCPUserAccountCredentials):

class GCPUserAccountConfig(GCPBaseProjectIDConfig, GCPUserAccountCredentials):
    """GCP user account configuration.

    Combines the explicit `project_id` field inherited from
    `GCPBaseProjectIDConfig` with the `GCPUserAccountCredentials` fields.
    """


class GCPServiceAccountConfig(GCPBaseConfig, GCPServiceAccountCredentials):
    """GCP service account configuration.

    This configuration has no explicit project ID field: the project ID is
    read from the service account JSON.
    """

    # Cache for the project ID parsed out of the service account JSON; it is
    # filled in on the first successful `gcp_project_id` access.
    _project_id: Optional[str] = None

    @property
    def gcp_project_id(self) -> str:
        """Get the GCP project ID.

        When a service account JSON is provided, the project ID can be
        extracted from it instead of being provided explicitly.

        Returns:
            The GCP project ID.

        Raises:
            ValueError: If the project ID in the service account JSON is
                null.
        """
        cached_project_id = self._project_id
        if cached_project_id is not None:
            return cached_project_id

        project_id: Optional[str] = json.loads(
            self.service_account_json.get_secret_value()
        )["project_id"]
        # The field validator is expected to guarantee a project ID. The
        # check is explicit rather than an `assert` because assertions are
        # stripped under `python -O`, which would let `None` be returned
        # from a property declared to return `str`.
        if project_id is None:
            raise ValueError(
                "The service account JSON does not contain a valid "
                "project ID."
            )

        # Only a validated value is cached.
        self._project_id = project_id
        return project_id


class GCPExternalAccountConfig(
    GCPBaseProjectIDConfig, GCPExternalAccountCredentials
):
    """GCP external account configuration.

    Combines the explicit `project_id` field inherited from
    `GCPBaseProjectIDConfig` with the `GCPExternalAccountCredentials` fields.
    """


class GCPOAuth2TokenConfig(GCPBaseConfig, GCPOAuth2Token):
class GCPOAuth2TokenConfig(GCPBaseProjectIDConfig, GCPOAuth2Token):
"""GCP OAuth 2.0 configuration."""

service_account_email: Optional[str] = Field(
Expand Down Expand Up @@ -540,7 +588,7 @@ def _get_security_credentials(
configured project has to be the same as the project of the attached service
account.
""",
config_class=GCPBaseConfig,
config_class=GCPBaseProjectIDConfig,
),
AuthenticationMethodModel(
name="GCP User Account",
Expand Down Expand Up @@ -1006,6 +1054,7 @@ def _authenticate(
# service account authentication)

assert isinstance(cfg, GCPServiceAccountConfig)

credentials = (
gcp_service_account.Credentials.from_service_account_info(
json.loads(
Expand Down Expand Up @@ -1115,7 +1164,7 @@ def _parse_gcr_resource_id(
#
# We need to extract the project ID and registry ID from
# the provided resource ID
config_project_id = self.config.project_id
config_project_id = self.config.gcp_project_id
project_id: Optional[str] = None
# A GCR repository URI uses one of several hostnames (gcr.io, us.gcr.io,
# eu.gcr.io, asia.gcr.io etc.) and the project ID is the first part of
Expand Down Expand Up @@ -1219,9 +1268,9 @@ def _get_default_resource_id(self, resource_type: str) -> str:
authorized.
"""
if resource_type == GCP_RESOURCE_TYPE:
return self.config.project_id
return self.config.gcp_project_id
elif resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
return f"gcr.io/{self.config.project_id}"
return f"gcr.io/{self.config.gcp_project_id}"

raise RuntimeError(
f"Default resource ID not supported for '{resource_type}' resource "
Expand Down Expand Up @@ -1278,7 +1327,7 @@ def _connect_to_resource(

# Create a GCS client for the bucket
client = storage.Client(
project=self.config.project_id, credentials=credentials
project=self.config.gcp_project_id, credentials=credentials
)
return client

Expand Down Expand Up @@ -1384,7 +1433,7 @@ def _configure_local_client(
"config",
"set",
"project",
self.config.project_id,
self.config.gcp_project_id,
],
check=True,
stderr=subprocess.STDOUT,
Expand Down Expand Up @@ -1488,7 +1537,7 @@ def _auto_configure(
)

if auth_method == GCPAuthenticationMethods.IMPLICIT:
auth_config = GCPBaseConfig(
auth_config = GCPBaseProjectIDConfig(
project_id=project_id,
)
elif auth_method == GCPAuthenticationMethods.OAUTH2_TOKEN:
Expand Down Expand Up @@ -1697,7 +1746,7 @@ def _verify(

if resource_type == GCS_RESOURCE_TYPE:
gcs_client = storage.Client(
project=self.config.project_id, credentials=credentials
project=self.config.gcp_project_id, credentials=credentials
)
if not resource_id:
# List all GCS buckets
Expand Down Expand Up @@ -1736,7 +1785,7 @@ def _verify(
# List all GKE clusters
try:
clusters = gke_client.list_clusters(
parent=f"projects/{self.config.project_id}/locations/-"
parent=f"projects/{self.config.gcp_project_id}/locations/-"
)
cluster_names = [cluster.name for cluster in clusters.clusters]
except google.api_core.exceptions.GoogleAPIError as e:
Expand Down Expand Up @@ -1810,7 +1859,7 @@ def _get_connector_client(
# object
auth_method: str = GCPAuthenticationMethods.OAUTH2_TOKEN
config: GCPBaseConfig = GCPOAuth2TokenConfig(
project_id=self.config.project_id,
project_id=self.config.gcp_project_id,
token=credentials.token,
service_account_email=credentials.signer_email
if hasattr(credentials, "signer_email")
Expand Down Expand Up @@ -1884,7 +1933,7 @@ def _get_connector_client(
# List all GKE clusters
try:
clusters = gke_client.list_clusters(
parent=f"projects/{self.config.project_id}/locations/-"
parent=f"projects/{self.config.gcp_project_id}/locations/-"
)
cluster_map = {
cluster.name: cluster for cluster in clusters.clusters
Expand Down Expand Up @@ -1928,7 +1977,7 @@ def _get_connector_client(
auth_method=KubernetesAuthenticationMethods.TOKEN,
resource_type=resource_type,
config=KubernetesTokenConfig(
cluster_name=f"gke_{self.config.project_id}_{cluster_name}",
cluster_name=f"gke_{self.config.gcp_project_id}_{cluster_name}",
certificate_authority=cluster_ca_cert,
server=f"https://{cluster_server}",
token=bearer_token,
Expand Down

0 comments on commit d9c2a5e

Please sign in to comment.