Skip to content

Commit

Permalink
fix: updates with-token flag to use const parameter
Browse files Browse the repository at this point in the history
This returns the with-token flag to its original implementation
to avoid breaking changes, but uses the const instead of default parameter.
This will make sure sys.stdin is only set if the flag is present.

Signed-off-by: Jennifer Power <[email protected]>
  • Loading branch information
jpower432 committed May 6, 2024
1 parent 79aad6e commit d40757f
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 97 deletions.
180 changes: 87 additions & 93 deletions tests/trestlebot/entrypoints/test_entrypoint_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,15 @@
@patch.dict("os.environ", {"GITHUB_ACTIONS": "true"})
def test_set_git_provider_with_github() -> None:
"""Test set_git_provider function in Entrypoint Base for GitHub Actions"""
with patch("sys.stdin", return_value=StringIO("fake_token")):
provider: Optional[GitProvider]
args = argparse.Namespace(
target_branch="main",
with_token=True,
git_provider_type="",
git_server_url="",
)
provider = EntrypointBase.set_git_provider(args=args)
assert isinstance(provider, GitHub)
provider: Optional[GitProvider]
args = argparse.Namespace(
target_branch="main",
with_token=StringIO("fake_token"),
git_provider_type="",
git_server_url="",
)
provider = EntrypointBase.set_git_provider(args=args)
assert isinstance(provider, GitHub)


@patch.dict(
Expand All @@ -41,16 +40,15 @@ def test_set_git_provider_with_github() -> None:
)
def test_set_git_provider_with_github_no_stdin() -> None:
"""Test set_git_provider function in Entrypoint Base for GitHub Actions"""
with patch("sys.stdin", return_value=StringIO("fake_token")):
provider: Optional[GitProvider]
args = argparse.Namespace(
target_branch="main",
with_token=False,
git_provider_type="",
git_server_url="",
)
provider = EntrypointBase.set_git_provider(args=args)
assert isinstance(provider, GitHub)
provider: Optional[GitProvider]
args = argparse.Namespace(
target_branch="main",
with_token=None,
git_provider_type="",
git_server_url="",
)
provider = EntrypointBase.set_git_provider(args=args)
assert isinstance(provider, GitHub)


@patch.dict(
Expand All @@ -64,63 +62,60 @@ def test_set_git_provider_with_github_no_stdin() -> None:
)
def test_set_git_provider_with_gitlab() -> None:
"""Test set_git_provider function in Entrypoint Base for GitLab CI"""
with patch("sys.stdin", return_value=StringIO("fake_token")):
provider: Optional[GitProvider]
args = argparse.Namespace(
target_branch="main",
with_token=True,
git_provider_type="",
git_server_url="",
)
provider = EntrypointBase.set_git_provider(args=args)
assert isinstance(provider, GitLab)
provider: Optional[GitProvider]
args = argparse.Namespace(
target_branch="main",
with_token=StringIO("fake_token"),
git_provider_type="",
git_server_url="",
)
provider = EntrypointBase.set_git_provider(args=args)
assert isinstance(provider, GitLab)


@patch.dict("os.environ", {"GITHUB_ACTIONS": "false", "GITLAB_CI": "true"})
def test_set_git_provider_with_gitlab_with_failure() -> None:
"""Trigger error with GitLab provider with insufficient environment variables"""
with patch("sys.stdin", return_value=StringIO("fake_token")):
args = argparse.Namespace(
target_branch="main",
with_token=True,
git_provider_type="",
git_server_url="",
)
with pytest.raises(
GitProviderException,
match="Set CI_SERVER_PROTOCOL and CI SERVER HOST environment variables",
):
EntrypointBase.set_git_provider(args=args)
args = argparse.Namespace(
target_branch="main",
with_token=StringIO("fake_token"),
git_provider_type="",
git_server_url="",
)
with pytest.raises(
GitProviderException,
match="Set CI_SERVER_PROTOCOL and CI SERVER HOST environment variables",
):
EntrypointBase.set_git_provider(args=args)


@patch.dict("os.environ", {"GITHUB_ACTIONS": "false"})
def test_set_git_provider_with_none() -> None:
"""Test set_git_provider function when no git provider is set"""
with patch("sys.stdin", return_value=StringIO("fake_token")):
provider: Optional[GitProvider]
args = argparse.Namespace(
target_branch="main",
with_token=True,
git_provider_type="",
git_server_url="",
)

with pytest.raises(
EntrypointInvalidArgException,
match="Invalid args --target-branch, --git-provider-type: "
"Could not detect Git provider from environment or inputs",
):
EntrypointBase.set_git_provider(args=args)

# Now test with no target branch which is a valid case
args = argparse.Namespace(target_branch=None)
provider = EntrypointBase.set_git_provider(args=args)
assert provider is None
provider: Optional[GitProvider]
args = argparse.Namespace(
target_branch="main",
with_token=StringIO("fake_token"),
git_provider_type="",
git_server_url="",
)

with pytest.raises(
EntrypointInvalidArgException,
match="Invalid args --target-branch, --git-provider-type: "
"Could not detect Git provider from environment or inputs",
):
EntrypointBase.set_git_provider(args=args)

# Now test with no target branch which is a valid case
args = argparse.Namespace(target_branch=None)
provider = EntrypointBase.set_git_provider(args=args)
assert provider is None


def test_set_provider_with_no_token() -> None:
"""Test set_git_provider function with no token"""
args = argparse.Namespace(target_branch="main", with_token=False)
args = argparse.Namespace(target_branch="main", with_token=None)
with pytest.raises(
EntrypointInvalidArgException,
match="Invalid args --with-token: with-token flag must be set to read from standard input "
Expand All @@ -131,33 +126,32 @@ def test_set_provider_with_no_token() -> None:

def test_set_provider_with_input() -> None:
"""Test set_git_provider function with type and server url input."""
with patch("sys.stdin", return_value=StringIO("fake_token")):
provider: Optional[GitProvider]
args = argparse.Namespace(
target_branch="main",
with_token=True,
git_provider_type="github",
git_server_url="",
)
provider = EntrypointBase.set_git_provider(args=args)
assert isinstance(provider, GitHub)
args = argparse.Namespace(
target_branch="main",
with_token=True,
git_provider_type="gitlab",
git_server_url="",
)
provider = EntrypointBase.set_git_provider(args=args)
assert isinstance(provider, GitLab)

args = argparse.Namespace(
target_branch="main",
with_token=True,
git_provider_type="github",
git_server_url="https://notgithub.com",
)
with pytest.raises(
EntrypointInvalidArgException,
match="Invalid args --server-url: GitHub provider does not support custom server URLs",
):
EntrypointBase.set_git_provider(args=args)
provider: Optional[GitProvider]
args = argparse.Namespace(
target_branch="main",
with_token=StringIO("fake_token"),
git_provider_type="github",
git_server_url="",
)
provider = EntrypointBase.set_git_provider(args=args)
assert isinstance(provider, GitHub)
args = argparse.Namespace(
target_branch="main",
with_token=StringIO("fake_token"),
git_provider_type="gitlab",
git_server_url="",
)
provider = EntrypointBase.set_git_provider(args=args)
assert isinstance(provider, GitLab)

args = argparse.Namespace(
target_branch="main",
with_token=StringIO("fake_token"),
git_provider_type="github",
git_server_url="https://notgithub.com",
)
with pytest.raises(
EntrypointInvalidArgException,
match="Invalid args --server-url: GitHub provider does not support custom server URLs",
):
EntrypointBase.set_git_provider(args=args)
9 changes: 5 additions & 4 deletions trestlebot/entrypoints/entrypoint_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,9 @@ def _set_git_provider_args(self) -> None:
git_provider_arg_group.add_argument(
"--with-token",
required=False,
default=False,
action="store_true",
nargs="?",
type=argparse.FileType("r"),
const=sys.stdin,
help="Read token from standard input for authenticated requests with \
Git provider (e.g. create pull requests)",
)
Expand Down Expand Up @@ -167,7 +168,7 @@ def set_git_provider(args: argparse.Namespace) -> Optional[GitProvider]:
"""Get the git provider based on the environment and args."""
git_provider: Optional[GitProvider] = None
if args.target_branch is not None:
if not args.with_token:
if args.with_token is None:
# Attempts to read from env var
access_token = os.environ.get("TRESTLEBOT_REPO_ACCESS_TOKEN", "")
if not access_token:
Expand All @@ -177,7 +178,7 @@ def set_git_provider(args: argparse.Namespace) -> Optional[GitProvider]:
"TRESTLEBOT_REPO_ACCESS_TOKEN environment variable when using target-branch",
)
else:
access_token = sys.stdin.read()
access_token = args.with_token.read()
try:
access_token = access_token.strip()
git_provider_type = args.git_provider_type
Expand Down

0 comments on commit d40757f

Please sign in to comment.