diff --git a/tests/trestlebot/entrypoints/test_entrypoint_base.py b/tests/trestlebot/entrypoints/test_entrypoint_base.py index c765f1e8..d1a0a56e 100644 --- a/tests/trestlebot/entrypoints/test_entrypoint_base.py +++ b/tests/trestlebot/entrypoints/test_entrypoint_base.py @@ -23,16 +23,15 @@ @patch.dict("os.environ", {"GITHUB_ACTIONS": "true"}) def test_set_git_provider_with_github() -> None: """Test set_git_provider function in Entrypoint Base for GitHub Actions""" - with patch("sys.stdin", return_value=StringIO("fake_token")): - provider: Optional[GitProvider] - args = argparse.Namespace( - target_branch="main", - with_token=True, - git_provider_type="", - git_server_url="", - ) - provider = EntrypointBase.set_git_provider(args=args) - assert isinstance(provider, GitHub) + provider: Optional[GitProvider] + args = argparse.Namespace( + target_branch="main", + with_token=StringIO("fake_token"), + git_provider_type="", + git_server_url="", + ) + provider = EntrypointBase.set_git_provider(args=args) + assert isinstance(provider, GitHub) @patch.dict( @@ -41,16 +40,15 @@ def test_set_git_provider_with_github() -> None: ) def test_set_git_provider_with_github_no_stdin() -> None: """Test set_git_provider function in Entrypoint Base for GitHub Actions""" - with patch("sys.stdin", return_value=StringIO("fake_token")): - provider: Optional[GitProvider] - args = argparse.Namespace( - target_branch="main", - with_token=False, - git_provider_type="", - git_server_url="", - ) - provider = EntrypointBase.set_git_provider(args=args) - assert isinstance(provider, GitHub) + provider: Optional[GitProvider] + args = argparse.Namespace( + target_branch="main", + with_token=None, + git_provider_type="", + git_server_url="", + ) + provider = EntrypointBase.set_git_provider(args=args) + assert isinstance(provider, GitHub) @patch.dict( @@ -64,63 +62,60 @@ def test_set_git_provider_with_github_no_stdin() -> None: ) def test_set_git_provider_with_gitlab() -> None: """Test set_git_provider function in Entrypoint Base for GitLab CI""" - with patch("sys.stdin", return_value=StringIO("fake_token")): - provider: Optional[GitProvider] - args = argparse.Namespace( - target_branch="main", - with_token=True, - git_provider_type="", - git_server_url="", - ) - provider = EntrypointBase.set_git_provider(args=args) - assert isinstance(provider, GitLab) + provider: Optional[GitProvider] + args = argparse.Namespace( + target_branch="main", + with_token=StringIO("fake_token"), + git_provider_type="", + git_server_url="", + ) + provider = EntrypointBase.set_git_provider(args=args) + assert isinstance(provider, GitLab) @patch.dict("os.environ", {"GITHUB_ACTIONS": "false", "GITLAB_CI": "true"}) def test_set_git_provider_with_gitlab_with_failure() -> None: """Trigger error with GitLab provider with insufficient environment variables""" - with patch("sys.stdin", return_value=StringIO("fake_token")): - args = argparse.Namespace( - target_branch="main", - with_token=True, - git_provider_type="", - git_server_url="", - ) - with pytest.raises( - GitProviderException, - match="Set CI_SERVER_PROTOCOL and CI SERVER HOST environment variables", - ): - EntrypointBase.set_git_provider(args=args) + args = argparse.Namespace( + target_branch="main", + with_token=StringIO("fake_token"), + git_provider_type="", + git_server_url="", + ) + with pytest.raises( + GitProviderException, + match="Set CI_SERVER_PROTOCOL and CI SERVER HOST environment variables", + ): + EntrypointBase.set_git_provider(args=args) @patch.dict("os.environ", {"GITHUB_ACTIONS": "false"}) def test_set_git_provider_with_none() -> None: """Test set_git_provider function when no git provider is set""" - with patch("sys.stdin", return_value=StringIO("fake_token")): - provider: Optional[GitProvider] - args = argparse.Namespace( - target_branch="main", - with_token=True, - git_provider_type="", - git_server_url="", - ) - - with pytest.raises( - EntrypointInvalidArgException, - match="Invalid args --target-branch, --git-provider-type: " - "Could not detect Git provider from environment or inputs", - ): - EntrypointBase.set_git_provider(args=args) - - # Now test with no target branch which is a valid case - args = argparse.Namespace(target_branch=None) - provider = EntrypointBase.set_git_provider(args=args) - assert provider is None + provider: Optional[GitProvider] + args = argparse.Namespace( + target_branch="main", + with_token=StringIO("fake_token"), + git_provider_type="", + git_server_url="", + ) + + with pytest.raises( + EntrypointInvalidArgException, + match="Invalid args --target-branch, --git-provider-type: " + "Could not detect Git provider from environment or inputs", + ): + EntrypointBase.set_git_provider(args=args) + + # Now test with no target branch which is a valid case + args = argparse.Namespace(target_branch=None) + provider = EntrypointBase.set_git_provider(args=args) + assert provider is None def test_set_provider_with_no_token() -> None: """Test set_git_provider function with no token""" - args = argparse.Namespace(target_branch="main", with_token=False) + args = argparse.Namespace(target_branch="main", with_token=None) with pytest.raises( EntrypointInvalidArgException, match="Invalid args --with-token: with-token flag must be set to read from standard input " @@ -131,33 +126,32 @@ def test_set_provider_with_no_token() -> None: def test_set_provider_with_input() -> None: """Test set_git_provider function with type and server url input.""" - with patch("sys.stdin", return_value=StringIO("fake_token")): - provider: Optional[GitProvider] - args = argparse.Namespace( - target_branch="main", - with_token=True, - git_provider_type="github", - git_server_url="", - ) - provider = EntrypointBase.set_git_provider(args=args) - assert isinstance(provider, GitHub) - args = argparse.Namespace( - target_branch="main", - with_token=True, - git_provider_type="gitlab", - git_server_url="", - ) - provider = EntrypointBase.set_git_provider(args=args) - assert isinstance(provider, GitLab) - - args = argparse.Namespace( - target_branch="main", - with_token=True, - git_provider_type="github", - git_server_url="https://notgithub.com", - ) - with pytest.raises( - EntrypointInvalidArgException, - match="Invalid args --server-url: GitHub provider does not support custom server URLs", - ): - EntrypointBase.set_git_provider(args=args) + provider: Optional[GitProvider] + args = argparse.Namespace( + target_branch="main", + with_token=StringIO("fake_token"), + git_provider_type="github", + git_server_url="", + ) + provider = EntrypointBase.set_git_provider(args=args) + assert isinstance(provider, GitHub) + args = argparse.Namespace( + target_branch="main", + with_token=StringIO("fake_token"), + git_provider_type="gitlab", + git_server_url="", + ) + provider = EntrypointBase.set_git_provider(args=args) + assert isinstance(provider, GitLab) + + args = argparse.Namespace( + target_branch="main", + with_token=StringIO("fake_token"), + git_provider_type="github", + git_server_url="https://notgithub.com", + ) + with pytest.raises( + EntrypointInvalidArgException, + match="Invalid args --server-url: GitHub provider does not support custom server URLs", + ): + EntrypointBase.set_git_provider(args=args) diff --git a/trestlebot/entrypoints/entrypoint_base.py b/trestlebot/entrypoints/entrypoint_base.py index 53ad6e22..158796e5 100644 --- a/trestlebot/entrypoints/entrypoint_base.py +++ b/trestlebot/entrypoints/entrypoint_base.py @@ -135,8 +135,9 @@ def _set_git_provider_args(self) -> None: git_provider_arg_group.add_argument( "--with-token", required=False, - default=False, - action="store_true", + nargs="?", + type=argparse.FileType("r"), + const=sys.stdin, help="Read token from standard input for authenticated requests with \ Git provider (e.g. create pull requests)", ) @@ -167,7 +168,7 @@ def set_git_provider(args: argparse.Namespace) -> Optional[GitProvider]: """Get the git provider based on the environment and args.""" git_provider: Optional[GitProvider] = None if args.target_branch is not None: - if not args.with_token: + if args.with_token is None: # Attempts to read from env var access_token = os.environ.get("TRESTLEBOT_REPO_ACCESS_TOKEN", "") if not access_token: @@ -177,7 +178,7 @@ def set_git_provider(args: argparse.Namespace) -> Optional[GitProvider]: "TRESTLEBOT_REPO_ACCESS_TOKEN environment variable when using target-branch", ) else: - access_token = sys.stdin.read() + access_token = args.with_token.read() try: access_token = access_token.strip() git_provider_type = args.git_provider_type