Skip to content

Commit

Permalink
Merge pull request #38 from maxxiefjv/add-PKCE-pyop
Browse files Browse the repository at this point in the history
Add PKCE support

Note that plaintext support is lacking, because it is considered unsafe.
  • Loading branch information
c00kiemon5ter authored Sep 3, 2021
2 parents 0ba83aa + 1c642b8 commit 2f110dc
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 27 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ user data and OpenID Connect claim names. Hence the underlying data source must
same names as the [standard claims of OpenID Connect](http://openid.net/specs/openid-connect-core-1_0.html#StandardClaims).

```python
from oic.oic.message import AuthorizationRequest
from pyop.message import AuthorizationRequest

from pyop.util import should_fragment_encode

Expand Down
12 changes: 6 additions & 6 deletions src/pyop/authz_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import uuid

from oic.extension.message import TokenIntrospectionResponse
from oic.oic.message import AuthorizationRequest

from .message import AuthorizationRequest
from .access_token import AccessToken
from .exceptions import InvalidAccessToken
from .exceptions import InvalidAuthorizationCode
Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(self, subject_identifier_factory, authorization_code_db=None, acces
self.subject_identifiers = subject_identifier_db if subject_identifier_db is not None else {}

def create_authorization_code(self, authorization_request, subject_identifier, scope=None):
# type: (oic.oic.message.AuthorizationRequest, str, Optional[List[str]]) -> str
# type: (AuthorizationRequest, str, Optional[List[str]]) -> str
"""
Creates an authorization code bound to the authorization request and the authenticated user identified
by the subject identifier.
Expand All @@ -106,7 +106,7 @@ def create_authorization_code(self, authorization_request, subject_identifier, s
return authorization_code

def create_access_token(self, authorization_request, subject_identifier, scope=None):
# type: (oic.oic.message.AuthorizationRequest, str, Optional[List[str]]) -> se_leg_op.access_token.AccessToken
# type: (AuthorizationRequest, str, Optional[List[str]]) -> se_leg_op.access_token.AccessToken
"""
Creates an access token bound to the authentication request and the authenticated user identified by the
subject identifier.
Expand Down Expand Up @@ -315,22 +315,22 @@ def get_user_id_for_subject_identifier(self, subject_identifier):
raise InvalidSubjectIdentifier('{} unknown'.format(subject_identifier))

def get_authorization_request_for_code(self, authorization_code):
# type: (str) -> oic.oic.message.AuthorizationRequest
# type: (str) -> AuthorizationRequest
if authorization_code not in self.authorization_codes:
raise InvalidAuthorizationCode('{} unknown'.format(authorization_code))

return AuthorizationRequest().from_dict(
self.authorization_codes[authorization_code][self.KEY_AUTHORIZATION_REQUEST])

def get_authorization_request_for_access_token(self, access_token_value):
# type: (str) -> oic.oic.message.AuthorizationRequest
# type: (str) ->
if access_token_value not in self.access_tokens:
raise InvalidAccessToken('{} unknown'.format(access_token_value))

return AuthorizationRequest().from_dict(self.access_tokens[access_token_value][self.KEY_AUTHORIZATION_REQUEST])

def get_subject_identifier_for_code(self, authorization_code):
# type: (str) -> oic.oic.message.AuthorizationRequest
# type: (str) -> AuthorizationRequest
if authorization_code not in self.authorization_codes:
raise InvalidAuthorizationCode('{} unknown'.format(authorization_code))

Expand Down
29 changes: 29 additions & 0 deletions src/pyop/message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from oic.oauth2.message import SINGLE_OPTIONAL_STRING
from oic.oic import message

class AccessTokenRequest(message.AccessTokenRequest):
c_param = message.AccessTokenRequest.c_param.copy()
c_param.update(
{
'code_verifier': SINGLE_OPTIONAL_STRING
}
)

class AuthorizationRequest(message.AuthorizationRequest):
c_param = message.AuthorizationRequest.c_param.copy()
c_param.update(
{
'code_challenge': SINGLE_OPTIONAL_STRING,
'code_challenge_method': SINGLE_OPTIONAL_STRING
}
)

c_allowed_values = message.AuthorizationRequest.c_allowed_values.copy()
c_allowed_values.update(
{
"code_challenge_method": [
"plain",
"S256"
]
}
)
76 changes: 61 additions & 15 deletions src/pyop/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
from oic.exception import MessageException
from oic.oic import PREFERENCE2PROVIDER
from oic.oic import scope2claims
from oic.oic.message import AccessTokenRequest
from oic.oic.message import AccessTokenResponse
from oic.oic.message import AuthorizationRequest
from oic.oic.message import AuthorizationResponse
from oic.oic.message import EndSessionRequest
from oic.oic.message import EndSessionResponse
Expand All @@ -23,7 +21,10 @@
from oic.oic.message import RefreshAccessTokenRequest
from oic.oic.message import RegistrationRequest
from oic.oic.message import RegistrationResponse
from oic.extension.provider import Provider as OICProviderExtensions

from .message import AuthorizationRequest
from .message import AccessTokenRequest
from .access_token import extract_bearer_token_from_http_request
from .client_authentication import verify_client_authentication
from .exceptions import AuthorizationError
Expand Down Expand Up @@ -81,7 +82,7 @@ def __init__(self, signing_key, configuration_information, authz_state, clients,
self.userinfo = userinfo
self.id_token_lifetime = id_token_lifetime

self.authentication_request_validators = [] # type: List[Callable[[oic.oic.message.AuthorizationRequest], Boolean]]
self.authentication_request_validators = [] # type: List[Callable[[AuthorizationRequest], Boolean]]
self.authentication_request_validators.append(authorization_request_verify)
self.authentication_request_validators.append(
functools.partial(client_id_is_known, self))
Expand Down Expand Up @@ -114,7 +115,7 @@ def jwks(self):
return {'keys': keys}

def parse_authentication_request(self, request_body, http_headers=None):
# type: (str, Optional[Mapping[str, str]]) -> oic.oic.message.AuthorizationRequest
# type: (str, Optional[Mapping[str, str]]) -> AuthorizationRequest
"""
Parses and verifies an authentication request.
Expand All @@ -130,7 +131,7 @@ def parse_authentication_request(self, request_body, http_headers=None):
logger.debug('parsed authentication_request: %s', auth_req)
return auth_req

def authorize(self, authentication_request, # type: oic.oic.message.AuthorizationRequest
def authorize(self, authentication_request, # type: AuthorizationRequest
user_id, # type: str
extra_id_token_claims=None
# type: Optional[Union[Mapping[str, Union[str, List[str]]], Callable[[str, str], Mapping[str, Union[str, List[str]]]]]
Expand Down Expand Up @@ -216,7 +217,7 @@ def _create_subject_identifier(self, user_id, client_id, redirect_uri):
return self.authz_state.get_subject_identifier(subject_type, user_id, sector_identifier)

def _get_requested_claims_in(self, authentication_request, response_method):
# type (oic.oic.message.AuthorizationRequest, str) -> Mapping[str, Optional[Mapping[str, Union[str, List[str]]]]
# type (AuthorizationRequest, str) -> Mapping[str, Optional[Mapping[str, Union[str, List[str]]]]
"""
Parses any claims requested using the 'claims' request parameter, see
<a href="http://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter">
Expand Down Expand Up @@ -284,7 +285,7 @@ def _create_signed_id_token(self,
return id_token.to_jwt([self.signing_key], alg)

def _check_subject_identifier_matches_requested(self, authentication_request, sub):
# type (oic.message.AuthorizationRequest, str) -> None
# type (AuthorizationRequest, str) -> None
"""
Verifies the subject identifier against any requested subject identifier using the claims request parameter.
:param authentication_request: authentication request
Expand Down Expand Up @@ -328,6 +329,58 @@ def handle_token_request(self, request_body, # type: str
raise InvalidTokenRequest('grant_type \'{}\' unknown'.format(token_request['grant_type']), token_request,
oauth_error='unsupported_grant_type')

def _PKCE_verify(self,
token_request, # type: AccessTokenRequest
authentication_request # type: AuthorizationRequest
):
# type: (...) -> bool
"""
Verify that the given code_verifier complies with the initially supplied code_challenge.
Only supports the SHA256 code challenge method, plaintext is regarded as unsafe.
:param token_request: the token request containing the initially supplied code challenge and code_challenge method.
:param authentication_request: the code_verfier to check against the code challenge.
:returns: whether the code_verifier is what was expected given the cc_cm
"""
if not 'code_verifier' in token_request:
return False

if not 'code_challenge_method' in authentication_request:
raise InvalidTokenRequest("A code_challenge and code_verifier have been supplied"
"but missing code_challenge_method in authentication_request", token_request)

# OIC Provider extension returns either a boolean or Response object containing an error. To support
# stricter typing guidelines, return if True. Error handling support should be in encapsulating function.
return OICProviderExtensions.verify_code_challenge(token_request['code_verifier'],
authentication_request['code_challenge'], authentication_request['code_challenge_method']) == True

def _verify_code_exchange_req(self,
token_request, # type: AccessTokenRequest
authentication_request # type: AuthorizationRequest
):
# type: (...) -> None
"""
Verify that the code exchange request is valid. In order to be valid we validate
the expected client and redirect_uri. Finally, if requested by the client, perform a
PKCE check.
:param token_request: The request asking for a token given a code, and optionally a code_verifier
:param authentication_request: The authentication request belonging to the provided code.
:raises InvalidTokenRequest, InvalidAuthorizationCode: If request is invalid, throw a representing exception.
"""
if token_request['client_id'] != authentication_request['client_id']:
logger.info('Authorization code \'%s\' belonging to \'%s\' was used by \'%s\'',
token_request['code'], authentication_request['client_id'], token_request['client_id'])
raise InvalidAuthorizationCode('{} unknown'.format(token_request['code']))
if token_request['redirect_uri'] != authentication_request['redirect_uri']:
raise InvalidTokenRequest('Invalid redirect_uri: {} != {}'.format(token_request['redirect_uri'],
authentication_request['redirect_uri']),
token_request)
if 'code_challenge' in authentication_request and not self._PKCE_verify(token_request, authentication_request):
raise InvalidTokenRequest('Unexpected Code Verifier: {}'.format(authentication_request['code_challenge']),
token_request)

def _do_code_exchange(self, request, # type: Dict[str, str]
extra_id_token_claims=None
# type: Optional[Union[Mapping[str, Union[str, List[str]]], Callable[[str, str], Mapping[str, Union[str, List[str]]]]]
Expand All @@ -351,14 +404,7 @@ def _do_code_exchange(self, request, # type: Dict[str, str]

authentication_request = self.authz_state.get_authorization_request_for_code(token_request['code'])

if token_request['client_id'] != authentication_request['client_id']:
logger.info('Authorization code \'%s\' belonging to \'%s\' was used by \'%s\'',
token_request['code'], authentication_request['client_id'], token_request['client_id'])
raise InvalidAuthorizationCode('{} unknown'.format(token_request['code']))
if token_request['redirect_uri'] != authentication_request['redirect_uri']:
raise InvalidTokenRequest('Invalid redirect_uri: {} != {}'.format(token_request['redirect_uri'],
authentication_request['redirect_uri']),
token_request)
self._verify_code_exchange_req(token_request, authentication_request)

sub = self.authz_state.get_subject_identifier_for_code(token_request['code'])
user_id = self.authz_state.get_user_id_for_subject_identifier(sub)
Expand Down
2 changes: 1 addition & 1 deletion tests/pyop/test_authz_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from unittest.mock import patch, Mock

import pytest
from oic.oic.message import AuthorizationRequest

from pyop.message import AuthorizationRequest
from pyop.authz_state import AccessToken, InvalidScope
from pyop.authz_state import AuthorizationState
from pyop.exceptions import InvalidSubjectIdentifier, InvalidAccessToken, InvalidAuthorizationCode, InvalidRefreshToken
Expand Down
3 changes: 1 addition & 2 deletions tests/pyop/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from urllib.parse import urlparse, parse_qsl

from oic.oic.message import AuthorizationRequest

from pyop.message import AuthorizationRequest
from pyop.exceptions import InvalidAuthenticationRequest


Expand Down
47 changes: 46 additions & 1 deletion tests/pyop/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from oic import rndstr
from oic.oauth2.message import MissingRequiredValue, MissingRequiredAttribute
from oic.oic import PREFERENCE2PROVIDER
from oic.oic.message import IdToken, AuthorizationRequest, ClaimsRequest, Claims, EndSessionRequest, EndSessionResponse
from oic.oic.message import IdToken, ClaimsRequest, Claims, EndSessionRequest, EndSessionResponse

from pyop.message import AuthorizationRequest
from pyop.access_token import BearerTokenError
from pyop.authz_state import AuthorizationState
from pyop.client_authentication import InvalidClientAuthentication
Expand Down Expand Up @@ -318,6 +319,20 @@ def test_code_exchange_request(self):
assert_id_token_base_claims(response['id_token'], self.provider.signing_key, self.provider,
self.authn_request_args)

@patch('time.time', MOCK_TIME)
def test_pkce_code_exchange_request(self):
self.authorization_code_exchange_request_args['code'] = self.create_authz_code(
{
"code_challenge": "_1f8tFjAtu6D1Df-GOyDPoMjCJdEvaSWsnqR6SLpzsw",
"code_challenge_method": "S256"
}
)
self.authorization_code_exchange_request_args['code_verifier'] = "SoOEDN-mZKNhw7Mc52VXxyiqTvFB3mod36MwPru253c"
response = self.provider._do_code_exchange(self.authorization_code_exchange_request_args, None)
assert response['access_token'] in self.provider.authz_state.access_tokens
assert_id_token_base_claims(response['id_token'], self.provider.signing_key, self.provider,
self.authn_request_args)

@patch('time.time', MOCK_TIME)
def test_code_exchange_request_with_claims_requested_in_id_token(self):
claims_req = {'claims': ClaimsRequest(id_token=Claims(email=None))}
Expand Down Expand Up @@ -374,6 +389,36 @@ def test_handle_token_request_reject_missing_grant_type(self):
with pytest.raises(InvalidTokenRequest):
self.provider.handle_token_request(urlencode(self.authorization_code_exchange_request_args))

def test_handle_token_request_reject_invalid_code_verifier(self):
self.authorization_code_exchange_request_args['code'] = self.create_authz_code(
{
"code_challenge": "_1f8tFjAtu6D1Df-GOyDPoMjCJdEvaSWsnqR6SLpzsw=",
"code_challenge_method": "S256"
}
)
self.authorization_code_exchange_request_args['code_verifier'] = "ThiS Cer_tainly Ain't Valid"
with pytest.raises(InvalidTokenRequest):
self.provider.handle_token_request(urlencode(self.authorization_code_exchange_request_args))

def test_handle_token_request_reject_unsynced_requests(self):
self.authorization_code_exchange_request_args['code'] = self.create_authz_code(
{
"code_challenge": "_1f8tFjAtu6D1Df-GOyDPoMjCJdEvaSWsnqR6SLpzsw=",
"code_challenge_method": "S256"
}
)
with pytest.raises(InvalidTokenRequest):
self.provider.handle_token_request(urlencode(self.authorization_code_exchange_request_args))

def test_handle_token_request_reject_missing_code_challenge_method(self):
self.authorization_code_exchange_request_args['code'] = self.create_authz_code(
{
"code_challenge": "_1f8tFjAtu6D1Df-GOyDPoMjCJdEvaSWsnqR6SLpzsw=",
}
)
with pytest.raises(InvalidTokenRequest):
self.provider.handle_token_request(urlencode(self.authorization_code_exchange_request_args))

def test_refresh_request(self):
self.provider.authz_state = AuthorizationState(HashBasedSubjectIdentifierFactory('salt'),
refresh_token_lifetime=600)
Expand Down
2 changes: 1 addition & 1 deletion tests/pyop/test_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from oic.oic.message import AuthorizationRequest

from pyop.message import AuthorizationRequest
from pyop.util import should_fragment_encode


Expand Down

0 comments on commit 2f110dc

Please sign in to comment.