Skip to content

Commit

Permalink
ENH: Add get_access_token method to clients (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 authored Mar 1, 2024
1 parent c3a2c79 commit aed9857
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 6 deletions.
21 changes: 16 additions & 5 deletions msal_requests_auth/auth/base_auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,33 @@ def __call__(
"""
Adds the token to the authorization header.
"""
token = self.get_access_token()
input_request.headers[
"Authorization"
] = f"{token['token_type']} {token['access_token']}"
return input_request

def get_access_token(self) -> Dict[str, str]:
"""
Retrieves the token dictionary from Azure AD.
Returns
-------
dict
"""
token = self._get_access_token()
if "access_token" not in token:
error = token.get("error")
description = token.get("error_description")
raise AuthenticationError(
f"Unable to get token. Error: {error} (Details: {description})."
)
input_request.headers[
"Authorization"
] = f"{token['token_type']} {token['access_token']}"
return input_request
return token

@abstractmethod
def _get_access_token(self) -> Dict[str, str]:
"""
Retrieves the token dictionary from Azure AD.
Abstract method to return the token dictionary from Azure AD.
Returns
-------
Expand Down
Empty file added test/test_base_client.py
Empty file.
36 changes: 36 additions & 0 deletions test/test_client_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,42 @@
from msal_requests_auth.exceptions import AuthenticationError


@patch("msal.ConfidentialClientApplication", autospec=True)
@patch(
"msal_requests_auth.auth.client_credential.ClientCredentialAuth._get_access_token"
)
def test_client_credential_auth__get_access_token__error(access_token_mock, cca_mock):
access_token_mock.return_value = {
"error": "BAD REQUEST",
"error_description": "Request to get token was bad.",
}
with pytest.raises(
AuthenticationError,
match=(
r"Unable to get token\. Error: BAD REQUEST "
r"\(Details: Request to get token was bad\.\)\."
),
):
ClientCredentialAuth(client=cca_mock, scopes=["TEST SCOPE"]).get_access_token()


@patch("msal.ConfidentialClientApplication", autospec=True)
@patch(
"msal_requests_auth.auth.client_credential.ClientCredentialAuth._get_access_token"
)
def test_client_credential_auth__get_access_token__valid(access_token_mock, cca_mock):
access_token_mock.return_value = {
"token_type": "Bearer",
"access_token": "TEST TOKEN",
}
assert ClientCredentialAuth(
client=cca_mock, scopes=["TEST SCOPE"]
).get_access_token() == {
"token_type": "Bearer",
"access_token": "TEST TOKEN",
}


@patch("msal.ConfidentialClientApplication", autospec=True)
def test_client_credential_auth__no_cache(cca_mock):
cca_mock.acquire_token_silent.return_value = None
Expand Down
36 changes: 35 additions & 1 deletion test/test_devide_code.py → test/test_device_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_device_code_auth__headless(pca_mock, headless):
@patch("msal.PublicClientApplication", autospec=True)
@patch("msal_requests_auth.auth.device_code.webbrowser")
@patch("msal_requests_auth.auth.device_code.pyperclip")
def test_device_code_auth__no_accounts__unable_to_get_token(
def test_device_code_auth__no_accounts__unable_to_get_token__call(
pyperclip_patch, webbrowser_patch, pca_mock
):
pca_mock.get_accounts.return_value = None
Expand Down Expand Up @@ -102,6 +102,40 @@ def test_device_code_auth__no_accounts__unable_to_get_token(
pyperclip_patch.copy.assert_called_with("TEST CODE")


@patch.dict(os.environ, {}, clear=True)
@patch("msal.PublicClientApplication", autospec=True)
@patch("msal_requests_auth.auth.device_code.DeviceCodeAuth._get_access_token")
def test_device_code_auth__get_access_token__error(access_token_mock, pca_mock):
access_token_mock.return_value = {
"error": "BAD REQUEST",
"error_description": "Request to get token was bad.",
}
with pytest.raises(
AuthenticationError,
match=(
r"Unable to get token\. Error: BAD REQUEST "
r"\(Details: Request to get token was bad\.\)\."
),
):
DeviceCodeAuth(client=pca_mock, scopes=["TEST SCOPE"]).get_access_token()


@patch.dict(os.environ, {}, clear=True)
@patch("msal.PublicClientApplication", autospec=True)
@patch("msal_requests_auth.auth.device_code.DeviceCodeAuth._get_access_token")
def test_device_code_auth__get_access_token__valid(access_token_mock, pca_mock):
access_token_mock.return_value = {
"token_type": "Bearer",
"access_token": "TEST TOKEN",
}
assert DeviceCodeAuth(
client=pca_mock, scopes=["TEST SCOPE"]
).get_access_token() == {
"token_type": "Bearer",
"access_token": "TEST TOKEN",
}


@patch.dict(os.environ, {}, clear=True)
@patch("msal.PublicClientApplication", autospec=True)
@patch("msal_requests_auth.auth.device_code.webbrowser")
Expand Down

0 comments on commit aed9857

Please sign in to comment.