Skip to content

Commit

Permalink
refactor: moves auto-detect logic to entrypoint base
Browse files Browse the repository at this point in the history
All reusable input/output logic should be defined in
entrypoint_base.py for easier management and readability

Signed-off-by: Jennifer Power <[email protected]>
  • Loading branch information
jpower432 committed May 8, 2024
1 parent 52ac672 commit 8c99300
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 78 deletions.
114 changes: 72 additions & 42 deletions tests/trestlebot/entrypoints/test_entrypoint_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@

import argparse
from io import StringIO
from typing import Optional
from typing import Dict, Optional
from unittest.mock import patch

import pytest

from tests.testutils import args_dict_to_list
from trestlebot.entrypoints.entrypoint_base import (
EntrypointBase,
EntrypointInvalidArgException,
Expand All @@ -20,35 +21,46 @@
from trestlebot.provider import GitProvider, GitProviderException


@pytest.fixture
def valid_args_dict() -> Dict[str, str]:
return {
"branch": "main",
"committer-name": "test",
"committer-email": "[email protected]",
"working-dir": ".",
"file-patterns": ".",
"target-branch": "main",
}


def setup_base_cli() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Test parser")
EntrypointBase(parser=parser)
return parser.parse_args()


@patch.dict("os.environ", {"GITHUB_ACTIONS": "true"})
def test_set_git_provider_with_github() -> None:
def test_base_cli_with_github(valid_args_dict: Dict[str, str]) -> None:
"""Test set_git_provider function in Entrypoint Base for GitHub Actions"""
provider: Optional[GitProvider]
args = argparse.Namespace(
target_branch="main",
with_token=StringIO("fake_token"),
git_provider_type="",
git_server_url="",
)
provider = EntrypointBase.set_git_provider(args=args)
assert isinstance(provider, GitHub)
with patch("sys.argv", ["trestlebot", *args_dict_to_list(valid_args_dict)]):
args = setup_base_cli()
vars(args)["with_token"] = StringIO("fake_token")
provider = EntrypointBase.set_git_provider(args=args)
assert isinstance(provider, GitHub)


@patch.dict(
"os.environ",
{"GITHUB_ACTIONS": "true", "TRESTLEBOT_REPO_ACCESS_TOKEN": "fake_token"},
)
def test_set_git_provider_with_github_no_stdin() -> None:
def test_base_cli_with_github_no_stdin(valid_args_dict: Dict[str, str]) -> None:
"""Test set_git_provider function in Entrypoint Base for GitHub Actions"""
provider: Optional[GitProvider]
args = argparse.Namespace(
target_branch="main",
with_token=None,
git_provider_type="",
git_server_url="",
)
provider = EntrypointBase.set_git_provider(args=args)
assert isinstance(provider, GitHub)
with patch("sys.argv", ["trestlebot", *args_dict_to_list(valid_args_dict)]):
args = setup_base_cli()
provider = EntrypointBase.set_git_provider(args=args)
assert isinstance(provider, GitHub)


@patch.dict(
Expand All @@ -60,33 +72,38 @@ def test_set_git_provider_with_github_no_stdin() -> None:
"CI_SERVER_HOST": "test-gitlab.com",
},
)
def test_set_git_provider_with_gitlab() -> None:
def test_base_cli_with_gitlab(valid_args_dict: Dict[str, str]) -> None:
"""Test set_git_provider function in Entrypoint Base for GitLab CI"""
provider: Optional[GitProvider]
args = argparse.Namespace(
target_branch="main",
with_token=StringIO("fake_token"),
git_provider_type="",
git_server_url="",
)
provider = EntrypointBase.set_git_provider(args=args)
assert isinstance(provider, GitLab)
with patch("sys.argv", ["trestlebot", *args_dict_to_list(valid_args_dict)]):
args = setup_base_cli()
vars(args)["with_token"] = StringIO("fake_token")
provider = EntrypointBase.set_git_provider(args=args)
assert isinstance(provider, GitLab)


@patch.dict("os.environ", {"GITHUB_ACTIONS": "false", "GITLAB_CI": "true"})
def test_set_git_provider_with_gitlab_with_failure() -> None:
def test_base_cli_with_gitlab_with_failure(valid_args_dict: Dict[str, str]) -> None:
"""Trigger error with GitLab provider with insufficient environment variables"""
args = argparse.Namespace(
target_branch="main",
with_token=StringIO("fake_token"),
git_provider_type="",
git_server_url="",
)
with pytest.raises(
GitProviderException,
match="Set CI_SERVER_PROTOCOL and CI SERVER HOST environment variables",
):
EntrypointBase.set_git_provider(args=args)
with patch("sys.argv", ["trestlebot", *args_dict_to_list(valid_args_dict)]):
with pytest.raises(
GitProviderException,
match="Set CI_SERVER_PROTOCOL and CI SERVER HOST environment variables",
):
setup_base_cli()


@patch.dict("os.environ", {"GITHUB_ACTIONS": "false", "GITLAB_CI": "false"})
def test_base_cli_with_provide_type_set(valid_args_dict: Dict[str, str]) -> None:
"""Trigger error with GitLab provider with insufficient environment variables"""
args_dict = valid_args_dict
args_dict["git-provider-type"] = "gitlab"
args_dict["git-server-url"] = "https://mygitlab.com"
with patch("sys.argv", ["trestlebot", *args_dict_to_list(valid_args_dict)]):
args = setup_base_cli()
vars(args)["with_token"] = StringIO("fake_token")
provider = EntrypointBase.set_git_provider(args=args)
assert isinstance(provider, GitLab)


@patch.dict("os.environ", {"GITHUB_ACTIONS": "false"})
Expand All @@ -103,7 +120,7 @@ def test_set_git_provider_with_none() -> None:
with pytest.raises(
EntrypointInvalidArgException,
match="Invalid args --target-branch, --git-provider-type: "
"Could not detect Git provider from environment or inputs",
"Could not determine Git provider from inputs",
):
EntrypointBase.set_git_provider(args=args)

Expand Down Expand Up @@ -156,6 +173,19 @@ def test_set_provider_with_input() -> None:
)
with pytest.raises(
EntrypointInvalidArgException,
match="Invalid args --server-url: GitHub provider does not support custom server URLs",
match="Invalid args --git-server-url: GitHub provider does not support custom server URLs",
):
EntrypointBase.set_git_provider(args=args)

args = argparse.Namespace(
target_branch="main",
with_token=StringIO("fake_token"),
git_provider_type="",
git_server_url="https://github.com",
)
with pytest.raises(
EntrypointInvalidArgException,
match="Invalid args --git-provider-type: git-provider-type must be set when using "
"git-server-url",
):
EntrypointBase.set_git_provider(args=args)
1 change: 1 addition & 0 deletions trestlebot/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,4 @@
# Git Provider Types
GITHUB = "github"
GITLAB = "gitlab"
GITHUB_SERVER_URL = "https://github.com"
37 changes: 34 additions & 3 deletions trestlebot/entrypoints/entrypoint_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
import logging
import os
import sys
from typing import List, Optional
from typing import List, Optional, Tuple

from trestlebot import const
from trestlebot.bot import TrestleBot
from trestlebot.github import GitHubActionsResultsReporter, is_github_actions
from trestlebot.gitlab import GitLabCIResultsReporter, is_gitlab_ci
from trestlebot.gitlab import GitLabCIResultsReporter, get_gitlab_root_url, is_gitlab_ci
from trestlebot.provider import GitProvider
from trestlebot.provider_factory import GitProviderFactory
from trestlebot.reporter import BotResults, ResultsReporter
Expand Down Expand Up @@ -125,6 +125,10 @@ def _set_git_provider_args(self) -> None:
git_provider_arg_group = self.parser.add_argument_group(
"optional arguments for interacting with the git provider"
)

# Detect default args for git provider type and server url
detected_provider_type, detected_server_url = load_provider_from_environment()

git_provider_arg_group.add_argument(
"--target-branch",
type=str,
Expand Down Expand Up @@ -152,13 +156,15 @@ def _set_git_provider_args(self) -> None:
"--git-provider-type",
required=False,
choices=[const.GITHUB, const.GITLAB],
default=detected_provider_type,
help="Optional supported Git provider to identify. "
"Defaults to auto detection based on pre-defined CI environment variables.",
)
git_provider_arg_group.add_argument(
"--git-server-url",
type=str,
required=False,
default=detected_server_url,
help="Optional git server url for supported type. "
"Defaults to auto detection based on pre-defined CI environment variables.",
)
Expand All @@ -183,11 +189,16 @@ def set_git_provider(args: argparse.Namespace) -> Optional[GitProvider]:
access_token = access_token.strip()
git_provider_type = args.git_provider_type
git_server_url = args.git_server_url
if git_server_url and not git_provider_type:
raise EntrypointInvalidArgException(
"--git-provider-type",
"git-provider-type must be set when using git-server-url",
)
git_provider = GitProviderFactory.provider_factory(
access_token, git_provider_type, git_server_url
)
except ValueError as e:
raise EntrypointInvalidArgException("--server-url", str(e))
raise EntrypointInvalidArgException("--git-server-url", str(e))
except RuntimeError as e:
raise EntrypointInvalidArgException(
"--target-branch, --git-provider-type", str(e)
Expand Down Expand Up @@ -233,6 +244,26 @@ def run_base(self, args: argparse.Namespace, pre_tasks: List[TaskBase]) -> None:
results_reporter.report_results(results)


def load_provider_from_environment() -> Tuple[str, str]:
"""
Detect the Git provider from the environment.
Returns:
A tuple with the provider type string and server url string
Note:
The environment variables are expected to be pre-defined
and set through the CI environment and not set by the user.
"""
if is_github_actions():
logging.debug("Detected GitHub Actions environment")
return const.GITHUB, const.GITHUB_SERVER_URL
elif is_gitlab_ci():
logging.debug("Detected GitLab CI environment")
return const.GITLAB, get_gitlab_root_url()
return "", ""


def comma_sep_to_list(string: str) -> List[str]:
"""Convert comma-sep string to list of strings and strip."""
string = string.strip() if string else ""
Expand Down
44 changes: 11 additions & 33 deletions trestlebot/provider_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from typing import Optional

from trestlebot import const
from trestlebot.github import GitHub, is_github_actions
from trestlebot.gitlab import GitLab, get_gitlab_root_url, is_gitlab_ci
from trestlebot.github import GitHub
from trestlebot.gitlab import GitLab
from trestlebot.provider import GitProvider


Expand All @@ -17,65 +17,43 @@ class GitProviderFactory:
"""Factory class for creating Git provider objects"""

@staticmethod
def provider_factory(
access_token: str, type: str = "", server_url: str = ""
) -> GitProvider:
def provider_factory(access_token: str, type: str, server_url: str) -> GitProvider:
"""
Factory class for creating Git provider objects
Args:
access_token: Access token for the Git provider
type: Type of Git provider. Supported values are "github" or "gitlab"
server_url: URL of the Git provider server
server_url: URL of the Git provider server.
Returns:
a GitProvider object
Notes:
If type is not provided, the method will attempt to detect the Git provider from the
environment.
Raises:
ValueError: If the server URL is provided for GitHub provider
RuntimeError: If the Git provider cannot be detected
Note: The GitHub provider currently only support GitHub and not
GitHub Enterprise. So the server value must be https://github.com.
"""

git_provider: Optional[GitProvider] = None

if type == const.GITHUB:
logger.debug("Creating GitHub provider")
if server_url and server_url != "https://github.com":
if server_url and server_url != const.GITHUB_SERVER_URL:
raise ValueError("GitHub provider does not support custom server URLs")
git_provider = GitHub(access_token=access_token)
elif type == const.GITLAB:
logger.debug("Creating GitLab provider")
if not server_url:
# No server URL will use default https://gitlab.com
git_provider = GitLab(api_token=access_token)
else:
git_provider = GitLab(api_token=access_token, server_url=server_url)
else:
logger.debug(
"No type or server_url provided."
"Detecting Git provider from environment."
)
git_provider = GitProviderFactory._detect_from_environment(access_token)

if git_provider is None:
raise RuntimeError(
"Could not detect Git provider from environment or inputs"
)

return git_provider
raise RuntimeError("Could not determine Git provider from inputs")

@staticmethod
def _detect_from_environment(access_token: str) -> Optional[GitProvider]:
"""Detect the Git provider from the environment"""
git_provider: Optional[GitProvider] = None
if is_github_actions():
logging.debug("Detected GitHub Actions environment")
git_provider = GitHub(access_token=access_token)
elif is_gitlab_ci():
logging.debug("Detected GitLab CI environment")
server_api_url = get_gitlab_root_url()
git_provider = GitLab(api_token=access_token, server_url=server_api_url)
return git_provider

0 comments on commit 8c99300

Please sign in to comment.