Skip to content

Commit

Permalink
feat(gcp): add a test_connection method (#4616)
Browse files Browse the repository at this point in the history
Co-authored-by: Rubén De la Torre Vico <[email protected]>
  • Loading branch information
MrCloudSec and puchy22 authored Aug 19, 2024
1 parent a126fd8 commit 84a76f4
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 9 deletions.
63 changes: 54 additions & 9 deletions prowler/providers/gcp/gcp_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from prowler.lib.logger import logger
from prowler.lib.utils.utils import print_boxes
from prowler.providers.common.models import Audit_Metadata
from prowler.providers.common.models import Audit_Metadata, Connection
from prowler.providers.common.provider import Provider
from prowler.providers.gcp.lib.mutelist.mutelist import GCPMutelist
from prowler.providers.gcp.models import (
Expand Down Expand Up @@ -198,7 +198,8 @@ def get_output_mapping(self):
# "partition": "identity.partition",
}

def setup_session(self, credentials_file: str, service_account: str) -> Credentials:
@staticmethod
def setup_session(credentials_file: str, service_account: str) -> Credentials:
"""
Setup the GCP session with the provided credentials file or service account to impersonate
Args:
Expand All @@ -212,7 +213,11 @@ def setup_session(self, credentials_file: str, service_account: str) -> Credenti

if credentials_file:
logger.info(f"Using credentials file: {credentials_file}")
self.__set_gcp_creds_env_var__(credentials_file)
logger.info(
"GCP provider: Setting GOOGLE_APPLICATION_CREDENTIALS environment variable..."
)
client_secrets_path = os.path.abspath(credentials_file)
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = client_secrets_path

# Get default credentials
credentials, _ = default(scopes=scopes)
Expand All @@ -238,12 +243,52 @@ def setup_session(self, credentials_file: str, service_account: str) -> Credenti
)
sys.exit(1)

def __set_gcp_creds_env_var__(self, credentials_file):
logger.info(
"GCP provider: Setting GOOGLE_APPLICATION_CREDENTIALS environment variable..."
)
client_secrets_path = os.path.abspath(credentials_file)
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = client_secrets_path
@staticmethod
def test_connection(
credentials_file: str = None,
service_account: str = None,
raise_on_exception: bool = True,
) -> Connection:
"""
Test the connection to GCP with the provided credentials file or service account to impersonate.
If the connection is successful, return a Connection object with is_connected set to True. If the connection fails, return a Connection object with error set to the exception.
Raise an exception if raise_on_exception is set to True.
If the Cloud Resource Manager API has not been used before or it is disabled, log a critical message and return a Connection object with error set to the exception.
Args:
credentials_file: str
service_account: str
Returns:
Connection object with is_connected set to True if the connection is successful, or error set to the exception if the connection fails
"""
try:
session = GcpProvider.setup_session(credentials_file, service_account)
service = discovery.build("cloudresourcemanager", "v1", credentials=session)
request = service.projects().list()
request.execute()
return Connection(is_connected=True)
except HttpError as http_error:
if "Cloud Resource Manager API has not been used" in str(http_error):
logger.critical(
"Cloud Resource Manager API has not been used before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/cloudresourcemanager.googleapis.com/ then retry."
)
if raise_on_exception:
raise Exception(
"Cloud Resource Manager API has not been used before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/cloudresourcemanager.googleapis.com/ then retry."
)
else:
logger.critical(
f"{http_error.__class__.__name__}[{http_error.__traceback__.tb_lineno}]: {http_error}"
)
if raise_on_exception:
raise http_error
return Connection(error=http_error)
except Exception as error:
logger.critical(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
if raise_on_exception:
raise error
return Connection(error=error)

def print_credentials(self):
# TODO: Beautify audited profile, set "default" if there is no profile set
Expand Down
78 changes: 78 additions & 0 deletions tests/providers/gcp/gcp_provider_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from prowler.providers.gcp.gcp_provider import GcpProvider
from prowler.providers.gcp.models import GCPIdentityInfo, GCPOutputOptions, GCPProject
from tests.providers.gcp.gcp_fixtures import mock_api_client


class TestGCPProvider:
Expand All @@ -33,6 +34,13 @@ def test_gcp_provider(self):
lifecycle_state="",
)
}

mocked_service = MagicMock()

mocked_service.projects.list.return_value = MagicMock(
execute=MagicMock(return_value={"projects": projects})
)

with patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.setup_session",
return_value=None,
Expand All @@ -42,6 +50,9 @@ def test_gcp_provider(self):
), patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
), patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
):
gcp_provider = GcpProvider(arguments)
assert gcp_provider.session is None
Expand Down Expand Up @@ -79,6 +90,12 @@ def test_gcp_provider_output_options(self):
lifecycle_state="",
)
}

mocked_service = MagicMock()

mocked_service.projects.list.return_value = MagicMock(
execute=MagicMock(return_value={"projects": projects})
)
with patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.setup_session",
return_value=None,
Expand All @@ -88,6 +105,12 @@ def test_gcp_provider_output_options(self):
), patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
), patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
new=mock_api_client,
), patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
):
gcp_provider = GcpProvider(arguments)
# This is needed since the output_options requires to get the global provider to get the audit config
Expand Down Expand Up @@ -152,6 +175,12 @@ def test_is_project_matching(self):
lifecycle_state="",
)
}

mocked_service = MagicMock()

mocked_service.projects.list.return_value = MagicMock(
execute=MagicMock(return_value={"projects": projects})
)
with patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.setup_session",
return_value=None,
Expand All @@ -161,6 +190,9 @@ def test_is_project_matching(self):
), patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
), patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
):
gcp_provider = GcpProvider(arguments)

Expand Down Expand Up @@ -204,6 +236,12 @@ def test_setup_session_with_credentials_file_no_impersonate(self):
lifecycle_state="",
)
}

mocked_service = MagicMock()

mocked_service.projects.list.return_value = MagicMock(
execute=MagicMock(return_value={"projects": projects})
)
with patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
Expand All @@ -216,6 +254,9 @@ def test_setup_session_with_credentials_file_no_impersonate(self):
), patch(
"prowler.providers.gcp.gcp_provider.default",
return_value=(mocked_credentials, MagicMock()),
), patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
):
gcp_provider = GcpProvider(arguments)
assert environ["GOOGLE_APPLICATION_CREDENTIALS"] == "test_credentials_file"
Expand Down Expand Up @@ -246,6 +287,12 @@ def test_setup_session_with_credentials_file_and_impersonate(self):
lifecycle_state="",
)
}

mocked_service = MagicMock()

mocked_service.projects.list.return_value = MagicMock(
execute=MagicMock(return_value={"projects": projects})
)
with patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
Expand All @@ -258,6 +305,9 @@ def test_setup_session_with_credentials_file_and_impersonate(self):
), patch(
"prowler.providers.gcp.gcp_provider.default",
return_value=(mocked_credentials, MagicMock()),
), patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
):
gcp_provider = GcpProvider(arguments)
assert environ["GOOGLE_APPLICATION_CREDENTIALS"] == "test_credentials_file"
Expand Down Expand Up @@ -296,6 +346,12 @@ def test_print_credentials_default_options(self, capsys):
lifecycle_state="",
)
}

mocked_service = MagicMock()

mocked_service.projects.list.return_value = MagicMock(
execute=MagicMock(return_value={"projects": projects})
)
with patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
Expand All @@ -308,6 +364,9 @@ def test_print_credentials_default_options(self, capsys):
), patch(
"prowler.providers.gcp.gcp_provider.default",
return_value=(mocked_credentials, MagicMock()),
), patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
):
gcp_provider = GcpProvider(arguments)
gcp_provider.print_credentials()
Expand Down Expand Up @@ -345,6 +404,12 @@ def test_print_credentials_impersonated_service_account(self, capsys):
lifecycle_state="",
)
}

mocked_service = MagicMock()

mocked_service.projects.list.return_value = MagicMock(
execute=MagicMock(return_value={"projects": projects})
)
with patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
Expand All @@ -357,6 +422,9 @@ def test_print_credentials_impersonated_service_account(self, capsys):
), patch(
"prowler.providers.gcp.gcp_provider.default",
return_value=(mocked_credentials, MagicMock()),
), patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
):
gcp_provider = GcpProvider(arguments)
gcp_provider.print_credentials()
Expand Down Expand Up @@ -401,6 +469,13 @@ def test_print_credentials_excluded_project_ids(self, capsys):
lifecycle_state="",
),
}

mocked_service = MagicMock()

mocked_service.projects.list.return_value = MagicMock(
execute=MagicMock(return_value={"projects": projects})
)

with patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
Expand All @@ -413,6 +488,9 @@ def test_print_credentials_excluded_project_ids(self, capsys):
), patch(
"prowler.providers.gcp.gcp_provider.default",
return_value=(mocked_credentials, MagicMock()),
), patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
):
gcp_provider = GcpProvider(arguments)
gcp_provider.print_credentials()
Expand Down

0 comments on commit 84a76f4

Please sign in to comment.