diff --git a/README.md b/README.md index 02704d3..86187a7 100644 --- a/README.md +++ b/README.md @@ -146,7 +146,7 @@ user data and OpenID Connect claim names. Hence the underlying data source must same names as the [standard claims of OpenID Connect](http://openid.net/specs/openid-connect-core-1_0.html#StandardClaims). ```python -from oic.oic.message import AuthorizationRequest +from pyop.message import AuthorizationRequest from pyop.util import should_fragment_encode diff --git a/src/pyop/authz_state.py b/src/pyop/authz_state.py index 07104e0..acaea9b 100644 --- a/src/pyop/authz_state.py +++ b/src/pyop/authz_state.py @@ -3,8 +3,8 @@ import uuid from oic.extension.message import TokenIntrospectionResponse -from oic.oic.message import AuthorizationRequest +from .message import AuthorizationRequest from .access_token import AccessToken from .exceptions import InvalidAccessToken from .exceptions import InvalidAuthorizationCode @@ -80,7 +80,7 @@ def __init__(self, subject_identifier_factory, authorization_code_db=None, acces self.subject_identifiers = subject_identifier_db if subject_identifier_db is not None else {} def create_authorization_code(self, authorization_request, subject_identifier, scope=None): - # type: (oic.oic.message.AuthorizationRequest, str, Optional[List[str]]) -> str + # type: (AuthorizationRequest, str, Optional[List[str]]) -> str """ Creates an authorization code bound to the authorization request and the authenticated user identified by the subject identifier. @@ -106,7 +106,7 @@ def create_authorization_code(self, authorization_request, subject_identifier, s return authorization_code def create_access_token(self, authorization_request, subject_identifier, scope=None): - # type: (oic.oic.message.AuthorizationRequest, str, Optional[List[str]]) -> se_leg_op.access_token.AccessToken + # type: (AuthorizationRequest, str, Optional[List[str]]) -> se_leg_op.access_token.AccessToken """ Creates an access token bound to the authentication request and the authenticated user identified by the subject identifier. @@ -315,7 +315,7 @@ def get_user_id_for_subject_identifier(self, subject_identifier): raise InvalidSubjectIdentifier('{} unknown'.format(subject_identifier)) def get_authorization_request_for_code(self, authorization_code): - # type: (str) -> oic.oic.message.AuthorizationRequest + # type: (str) -> AuthorizationRequest if authorization_code not in self.authorization_codes: raise InvalidAuthorizationCode('{} unknown'.format(authorization_code)) @@ -323,14 +323,14 @@ def get_authorization_request_for_code(self, authorization_code): self.authorization_codes[authorization_code][self.KEY_AUTHORIZATION_REQUEST]) def get_authorization_request_for_access_token(self, access_token_value): - # type: (str) -> oic.oic.message.AuthorizationRequest + # type: (str) -> if access_token_value not in self.access_tokens: raise InvalidAccessToken('{} unknown'.format(access_token_value)) return AuthorizationRequest().from_dict(self.access_tokens[access_token_value][self.KEY_AUTHORIZATION_REQUEST]) def get_subject_identifier_for_code(self, authorization_code): - # type: (str) -> oic.oic.message.AuthorizationRequest + # type: (str) -> AuthorizationRequest if authorization_code not in self.authorization_codes: raise InvalidAuthorizationCode('{} unknown'.format(authorization_code)) diff --git a/src/pyop/message.py b/src/pyop/message.py new file mode 100644 index 0000000..2e016b4 --- /dev/null +++ b/src/pyop/message.py @@ -0,0 +1,29 @@ +from oic.oauth2.message import SINGLE_OPTIONAL_STRING +from oic.oic import message + +class AccessTokenRequest(message.AccessTokenRequest): + c_param = message.AccessTokenRequest.c_param.copy() + c_param.update( + { + 'code_verifier': SINGLE_OPTIONAL_STRING + } + ) + +class AuthorizationRequest(message.AuthorizationRequest): + c_param = message.AuthorizationRequest.c_param.copy() + c_param.update( + { + 'code_challenge': SINGLE_OPTIONAL_STRING, + 'code_challenge_method': SINGLE_OPTIONAL_STRING + } + ) + + c_allowed_values = message.AuthorizationRequest.c_allowed_values.copy() + c_allowed_values.update( + { + "code_challenge_method": [ + "plain", + "S256" + ] + } + ) diff --git a/src/pyop/provider.py b/src/pyop/provider.py index 1ff0c8e..e69140c 100644 --- a/src/pyop/provider.py +++ b/src/pyop/provider.py @@ -11,9 +11,7 @@ from oic.exception import MessageException from oic.oic import PREFERENCE2PROVIDER from oic.oic import scope2claims -from oic.oic.message import AccessTokenRequest from oic.oic.message import AccessTokenResponse -from oic.oic.message import AuthorizationRequest from oic.oic.message import AuthorizationResponse from oic.oic.message import EndSessionRequest from oic.oic.message import EndSessionResponse @@ -23,7 +21,10 @@ from oic.oic.message import RefreshAccessTokenRequest from oic.oic.message import RegistrationRequest from oic.oic.message import RegistrationResponse +from oic.extension.provider import Provider as OICProviderExtensions +from .message import AuthorizationRequest +from .message import AccessTokenRequest from .access_token import extract_bearer_token_from_http_request from .client_authentication import verify_client_authentication from .exceptions import AuthorizationError @@ -81,7 +82,7 @@ def __init__(self, signing_key, configuration_information, authz_state, clients, self.userinfo = userinfo self.id_token_lifetime = id_token_lifetime - self.authentication_request_validators = [] # type: List[Callable[[oic.oic.message.AuthorizationRequest], Boolean]] + self.authentication_request_validators = [] # type: List[Callable[[AuthorizationRequest], Boolean]] self.authentication_request_validators.append(authorization_request_verify) self.authentication_request_validators.append( functools.partial(client_id_is_known, self)) @@ -114,7 +115,7 @@ def jwks(self): return {'keys': keys} def parse_authentication_request(self, request_body, http_headers=None): - # type: (str, Optional[Mapping[str, str]]) -> oic.oic.message.AuthorizationRequest + # type: (str, Optional[Mapping[str, str]]) -> AuthorizationRequest """ Parses and verifies an authentication request. @@ -130,7 +131,7 @@ def parse_authentication_request(self, request_body, http_headers=None): logger.debug('parsed authentication_request: %s', auth_req) return auth_req - def authorize(self, authentication_request, # type: oic.oic.message.AuthorizationRequest + def authorize(self, authentication_request, # type: AuthorizationRequest user_id, # type: str extra_id_token_claims=None # type: Optional[Union[Mapping[str, Union[str, List[str]]], Callable[[str, str], Mapping[str, Union[str, List[str]]]]] @@ -216,7 +217,7 @@ def _create_subject_identifier(self, user_id, client_id, redirect_uri): return self.authz_state.get_subject_identifier(subject_type, user_id, sector_identifier) def _get_requested_claims_in(self, authentication_request, response_method): - # type (oic.oic.message.AuthorizationRequest, str) -> Mapping[str, Optional[Mapping[str, Union[str, List[str]]]] + # type (AuthorizationRequest, str) -> Mapping[str, Optional[Mapping[str, Union[str, List[str]]]] """ Parses any claims requested using the 'claims' request parameter, see @@ -284,7 +285,7 @@ def _create_signed_id_token(self, return id_token.to_jwt([self.signing_key], alg) def _check_subject_identifier_matches_requested(self, authentication_request, sub): - # type (oic.message.AuthorizationRequest, str) -> None + # type (AuthorizationRequest, str) -> None """ Verifies the subject identifier against any requested subject identifier using the claims request parameter. :param authentication_request: authentication request @@ -328,6 +329,58 @@ def handle_token_request(self, request_body, # type: str raise InvalidTokenRequest('grant_type \'{}\' unknown'.format(token_request['grant_type']), token_request, oauth_error='unsupported_grant_type') + def _PKCE_verify(self, + token_request, # type: AccessTokenRequest + authentication_request # type: AuthorizationRequest + ): + # type: (...) -> bool + """ + Verify that the given code_verifier complies with the initially supplied code_challenge. + + Only supports the SHA256 code challenge method, plaintext is regarded as unsafe. + + :param token_request: the token request containing the initially supplied code challenge and code_challenge method. + :param authentication_request: the code_verfier to check against the code challenge. + :returns: whether the code_verifier is what was expected given the cc_cm + """ + if not 'code_verifier' in token_request: + return False + + if not 'code_challenge_method' in authentication_request: + raise InvalidTokenRequest("A code_challenge and code_verifier have been supplied" + "but missing code_challenge_method in authentication_request", token_request) + + # OIC Provider extension returns either a boolean or Response object containing an error. To support + # stricter typing guidelines, return if True. Error handling support should be in encapsulating function. + return OICProviderExtensions.verify_code_challenge(token_request['code_verifier'], + authentication_request['code_challenge'], authentication_request['code_challenge_method']) == True + + def _verify_code_exchange_req(self, + token_request, # type: AccessTokenRequest + authentication_request # type: AuthorizationRequest + ): + # type: (...) -> None + """ + Verify that the code exchange request is valid. In order to be valid we validate + the expected client and redirect_uri. Finally, if requested by the client, perform a + PKCE check. + + :param token_request: The request asking for a token given a code, and optionally a code_verifier + :param authentication_request: The authentication request belonging to the provided code. + :raises InvalidTokenRequest, InvalidAuthorizationCode: If request is invalid, throw a representing exception. + """ + if token_request['client_id'] != authentication_request['client_id']: + logger.info('Authorization code \'%s\' belonging to \'%s\' was used by \'%s\'', + token_request['code'], authentication_request['client_id'], token_request['client_id']) + raise InvalidAuthorizationCode('{} unknown'.format(token_request['code'])) + if token_request['redirect_uri'] != authentication_request['redirect_uri']: + raise InvalidTokenRequest('Invalid redirect_uri: {} != {}'.format(token_request['redirect_uri'], + authentication_request['redirect_uri']), + token_request) + if 'code_challenge' in authentication_request and not self._PKCE_verify(token_request, authentication_request): + raise InvalidTokenRequest('Unexpected Code Verifier: {}'.format(authentication_request['code_challenge']), + token_request) + def _do_code_exchange(self, request, # type: Dict[str, str] extra_id_token_claims=None # type: Optional[Union[Mapping[str, Union[str, List[str]]], Callable[[str, str], Mapping[str, Union[str, List[str]]]]] @@ -351,14 +404,7 @@ def _do_code_exchange(self, request, # type: Dict[str, str] authentication_request = self.authz_state.get_authorization_request_for_code(token_request['code']) - if token_request['client_id'] != authentication_request['client_id']: - logger.info('Authorization code \'%s\' belonging to \'%s\' was used by \'%s\'', - token_request['code'], authentication_request['client_id'], token_request['client_id']) - raise InvalidAuthorizationCode('{} unknown'.format(token_request['code'])) - if token_request['redirect_uri'] != authentication_request['redirect_uri']: - raise InvalidTokenRequest('Invalid redirect_uri: {} != {}'.format(token_request['redirect_uri'], - authentication_request['redirect_uri']), - token_request) + self._verify_code_exchange_req(token_request, authentication_request) sub = self.authz_state.get_subject_identifier_for_code(token_request['code']) user_id = self.authz_state.get_user_id_for_subject_identifier(sub) diff --git a/tests/pyop/test_authz_state.py b/tests/pyop/test_authz_state.py index 2eb9ade..2812973 100644 --- a/tests/pyop/test_authz_state.py +++ b/tests/pyop/test_authz_state.py @@ -4,8 +4,8 @@ from unittest.mock import patch, Mock import pytest -from oic.oic.message import AuthorizationRequest +from pyop.message import AuthorizationRequest from pyop.authz_state import AccessToken, InvalidScope from pyop.authz_state import AuthorizationState from pyop.exceptions import InvalidSubjectIdentifier, InvalidAccessToken, InvalidAuthorizationCode, InvalidRefreshToken diff --git a/tests/pyop/test_exceptions.py b/tests/pyop/test_exceptions.py index 64fb050..7ec5112 100644 --- a/tests/pyop/test_exceptions.py +++ b/tests/pyop/test_exceptions.py @@ -1,7 +1,6 @@ from urllib.parse import urlparse, parse_qsl -from oic.oic.message import AuthorizationRequest - +from pyop.message import AuthorizationRequest from pyop.exceptions import InvalidAuthenticationRequest diff --git a/tests/pyop/test_provider.py b/tests/pyop/test_provider.py index 5f85482..9c20fa0 100644 --- a/tests/pyop/test_provider.py +++ b/tests/pyop/test_provider.py @@ -12,8 +12,9 @@ from oic import rndstr from oic.oauth2.message import MissingRequiredValue, MissingRequiredAttribute from oic.oic import PREFERENCE2PROVIDER -from oic.oic.message import IdToken, AuthorizationRequest, ClaimsRequest, Claims, EndSessionRequest, EndSessionResponse +from oic.oic.message import IdToken, ClaimsRequest, Claims, EndSessionRequest, EndSessionResponse +from pyop.message import AuthorizationRequest from pyop.access_token import BearerTokenError from pyop.authz_state import AuthorizationState from pyop.client_authentication import InvalidClientAuthentication @@ -318,6 +319,20 @@ def test_code_exchange_request(self): assert_id_token_base_claims(response['id_token'], self.provider.signing_key, self.provider, self.authn_request_args) + @patch('time.time', MOCK_TIME) + def test_pkce_code_exchange_request(self): + self.authorization_code_exchange_request_args['code'] = self.create_authz_code( + { + "code_challenge": "_1f8tFjAtu6D1Df-GOyDPoMjCJdEvaSWsnqR6SLpzsw", + "code_challenge_method": "S256" + } + ) + self.authorization_code_exchange_request_args['code_verifier'] = "SoOEDN-mZKNhw7Mc52VXxyiqTvFB3mod36MwPru253c" + response = self.provider._do_code_exchange(self.authorization_code_exchange_request_args, None) + assert response['access_token'] in self.provider.authz_state.access_tokens + assert_id_token_base_claims(response['id_token'], self.provider.signing_key, self.provider, + self.authn_request_args) + @patch('time.time', MOCK_TIME) def test_code_exchange_request_with_claims_requested_in_id_token(self): claims_req = {'claims': ClaimsRequest(id_token=Claims(email=None))} @@ -374,6 +389,36 @@ def test_handle_token_request_reject_missing_grant_type(self): with pytest.raises(InvalidTokenRequest): self.provider.handle_token_request(urlencode(self.authorization_code_exchange_request_args)) + def test_handle_token_request_reject_invalid_code_verifier(self): + self.authorization_code_exchange_request_args['code'] = self.create_authz_code( + { + "code_challenge": "_1f8tFjAtu6D1Df-GOyDPoMjCJdEvaSWsnqR6SLpzsw=", + "code_challenge_method": "S256" + } + ) + self.authorization_code_exchange_request_args['code_verifier'] = "ThiS Cer_tainly Ain't Valid" + with pytest.raises(InvalidTokenRequest): + self.provider.handle_token_request(urlencode(self.authorization_code_exchange_request_args)) + + def test_handle_token_request_reject_unsynced_requests(self): + self.authorization_code_exchange_request_args['code'] = self.create_authz_code( + { + "code_challenge": "_1f8tFjAtu6D1Df-GOyDPoMjCJdEvaSWsnqR6SLpzsw=", + "code_challenge_method": "S256" + } + ) + with pytest.raises(InvalidTokenRequest): + self.provider.handle_token_request(urlencode(self.authorization_code_exchange_request_args)) + + def test_handle_token_request_reject_missing_code_challenge_method(self): + self.authorization_code_exchange_request_args['code'] = self.create_authz_code( + { + "code_challenge": "_1f8tFjAtu6D1Df-GOyDPoMjCJdEvaSWsnqR6SLpzsw=", + } + ) + with pytest.raises(InvalidTokenRequest): + self.provider.handle_token_request(urlencode(self.authorization_code_exchange_request_args)) + def test_refresh_request(self): self.provider.authz_state = AuthorizationState(HashBasedSubjectIdentifierFactory('salt'), refresh_token_lifetime=600) diff --git a/tests/pyop/test_util.py b/tests/pyop/test_util.py index be22121..5251838 100644 --- a/tests/pyop/test_util.py +++ b/tests/pyop/test_util.py @@ -1,6 +1,6 @@ import pytest -from oic.oic.message import AuthorizationRequest +from pyop.message import AuthorizationRequest from pyop.util import should_fragment_encode