From 0f5048a360710b9f733d3b381715d894e821c2be Mon Sep 17 00:00:00 2001 From: Sergio Date: Fri, 9 Aug 2024 07:53:16 -0400 Subject: [PATCH] solve comments --- prowler/providers/gcp/gcp_provider.py | 68 ++++++++++++++++----------- 1 file changed, 40 insertions(+), 28 deletions(-) diff --git a/prowler/providers/gcp/gcp_provider.py b/prowler/providers/gcp/gcp_provider.py index 3af5296a156..357d56dd3ab 100644 --- a/prowler/providers/gcp/gcp_provider.py +++ b/prowler/providers/gcp/gcp_provider.py @@ -15,7 +15,7 @@ ) from prowler.lib.logger import logger from prowler.lib.utils.utils import print_boxes -from prowler.providers.common.models import Audit_Metadata +from prowler.providers.common.models import Audit_Metadata, Connection from prowler.providers.common.provider import Provider from prowler.providers.gcp.lib.mutelist.mutelist import GCPMutelist from prowler.providers.gcp.models import ( @@ -46,7 +46,7 @@ def __init__(self, arguments): self._impersonated_service_account = arguments.impersonate_service_account list_project_ids = arguments.list_project_id - self._session = GcpProvider.test_connection( + self._session = self.setup_session( credentials_file, self._impersonated_service_account ) @@ -208,37 +208,45 @@ def setup_session(credentials_file: str, service_account: str) -> Credentials: Returns: Credentials object """ - scopes = ["https://www.googleapis.com/auth/cloud-platform"] + try: + scopes = ["https://www.googleapis.com/auth/cloud-platform"] - if credentials_file: - logger.info(f"Using credentials file: {credentials_file}") - logger.info( - "GCP provider: Setting GOOGLE_APPLICATION_CREDENTIALS environment variable..." - ) - client_secrets_path = os.path.abspath(credentials_file) - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = client_secrets_path + if credentials_file: + logger.info(f"Using credentials file: {credentials_file}") + logger.info( + "GCP provider: Setting GOOGLE_APPLICATION_CREDENTIALS environment variable..." + ) + client_secrets_path = os.path.abspath(credentials_file) + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = client_secrets_path - # Get default credentials - credentials, _ = default(scopes=scopes) + # Get default credentials + credentials, _ = default(scopes=scopes) - # Refresh the credentials to ensure they are valid - credentials.refresh(Request()) + # Refresh the credentials to ensure they are valid + credentials.refresh(Request()) - logger.info(f"Initial credentials: {credentials}") + logger.info(f"Initial credentials: {credentials}") - if service_account: - # Create the impersonated credentials - credentials = impersonated_credentials.Credentials( - source_credentials=credentials, - target_principal=service_account, - target_scopes=scopes, - ) - logger.info(f"Impersonated credentials: {credentials}") + if service_account: + # Create the impersonated credentials + credentials = impersonated_credentials.Credentials( + source_credentials=credentials, + target_principal=service_account, + target_scopes=scopes, + ) + logger.info(f"Impersonated credentials: {credentials}") - return credentials + return credentials + except Exception as error: + logger.critical( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" + ) + sys.exit(1) @staticmethod - def test_connection(credentials_file: str, service_account: str) -> Credentials: + def test_connection( + credentials_file: str, service_account: str, raise_on_exception: bool = True + ) -> Connection: """ Test the connection with the provided credentials file or service account to impersonate Args: @@ -252,7 +260,7 @@ def test_connection(credentials_file: str, service_account: str) -> Credentials: service = discovery.build("cloudresourcemanager", "v1", credentials=session) request = service.projects().list() request.execute() - return session + return Connection(is_connected=True) except HttpError as http_error: if "Cloud Resource Manager API has not been used" in str(http_error): logger.critical( @@ -265,12 +273,16 @@ def test_connection(credentials_file: str, service_account: str) -> Credentials: logger.critical( f"{http_error.__class__.__name__}[{http_error.__traceback__.tb_lineno}]: {http_error}" ) - raise http_error + if raise_on_exception: + raise http_error + return Connection(error=http_error) except Exception as error: logger.critical( f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" ) - raise error + if raise_on_exception: + raise error + return Connection(error=error) def print_credentials(self): # TODO: Beautify audited profile, set "default" if there is no profile set