diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index 793643a0..f3f55551 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -483,6 +483,36 @@ def _logout_redirect_url_default(self): """, ) + allowed_scopes = List( + Unicode(), + config=True, + help=""" + Allow users who have been granted *all* these scopes to log in. + + We request all the scopes listed in the 'scope' config, but only a + subset of these may be granted by the authorization server. This may + happen if the user does not have permissions to access a requested + scope, or has chosen to not give consent for a particular scope. If the + scopes listed in this config are not granted, the user will not be + allowed to log in. + + The granted scopes will be part of the access token (fetched from self.token_url). + See https://datatracker.ietf.org/doc/html/rfc6749#section-3.3 for more + information. + + See the OAuth documentation of your OAuth provider for various options. + """, + ) + + @validate('allowed_scopes') + def _allowed_scopes_validation(self, proposal): + # allowed scopes must be a subset of requested scopes + if set(proposal.value) - set(self.scope): + raise ValueError( + f"Allowed scopes must be a subset of requested scopes. {self.scope} is requested but {proposal.value} is allowed" + ) + return proposal.value + extra_authorize_params = Dict( config=True, help=""" @@ -1060,6 +1090,8 @@ async def check_allowed(self, username, auth_model): """ Returns True for users allowed to be authorized + If a user must be *disallowed*, raises a 403 exception. + Overrides Authenticator.check_allowed that is called from `Authenticator.get_authenticated_user` after `OAuthenticator.authenticate` has been called, and therefore also after @@ -1074,6 +1106,15 @@ async def check_allowed(self, username, auth_model): if auth_model is None: return True + # Allow users who have been granted specific scopes that grant them entry + if self.allowed_scopes: + granted_scopes = auth_model.get('auth_state', {}).get('scope', []) + missing_scopes = set(self.allowed_scopes) - set(granted_scopes) + if not missing_scopes: + message = f"Granting access to user {username}, as they had {self.allowed_scopes}" + self.log.info(message) + return True + if self.allow_all: return True diff --git a/oauthenticator/tests/mocks.py b/oauthenticator/tests/mocks.py index 15d2bd3f..efeac575 100644 --- a/oauthenticator/tests/mocks.py +++ b/oauthenticator/tests/mocks.py @@ -104,6 +104,8 @@ def setup_oauth_mock( access_token_path, user_path=None, token_type='Bearer', + token_request_style='post', + scope="", ): """setup the mock client for OAuth @@ -125,6 +127,7 @@ def setup_oauth_mock( access_token_path (str): The path for the access token request (e.g. /access_token) user_path (str): The path for requesting (e.g. /user) token_type (str): the token_type field for the provider + scope (str): The scope field returned by the provider """ client.oauth_codes = oauth_codes = {} @@ -161,6 +164,8 @@ def access_token(request): 'access_token': token, 'token_type': token_type, } + if scope: + model['scope'] = scope if 'id_token' in user: model['id_token'] = user['id_token'] return model @@ -172,6 +177,7 @@ def get_user(request): token = auth_header.split(None, 1)[1] else: query = parse_qs(urlparse(request.url).query) + if 'access_token' in query: token = query['access_token'][0] else: diff --git a/oauthenticator/tests/test_generic.py b/oauthenticator/tests/test_generic.py index a23a4723..07e6752a 100644 --- a/oauthenticator/tests/test_generic.py +++ b/oauthenticator/tests/test_generic.py @@ -1,8 +1,9 @@ import json +import re from functools import partial import jwt -from pytest import fixture, mark +from pytest import fixture, mark, raises from traitlets.config import Config from ..generic import GenericOAuthenticator @@ -35,6 +36,7 @@ def generic_client(client): host='generic.horse', access_token_path='/oauth/access_token', user_path='/oauth/userinfo', + scope='basic', ) return client @@ -293,6 +295,37 @@ async def test_generic_data(get_authenticator, generic_client): assert auth_model +@mark.parametrize( + ["allowed_scopes", "allowed"], [(["advanced"], False), (["basic"], True)] +) +async def test_allowed_scopes( + get_authenticator, generic_client, allowed_scopes, allowed +): + c = Config() + c.GenericOAuthenticator.allowed_scopes = allowed_scopes + c.GenericOAuthenticator.scope = list(allowed_scopes) + authenticator = get_authenticator(config=c) + + handled_user_model = user_model("user1") + handler = generic_client.handler_for_user(handled_user_model) + auth_model = await authenticator.authenticate(handler) + assert allowed == await authenticator.check_allowed(auth_model["name"], auth_model) + + +async def test_allowed_scopes_validation_scope_subset(get_authenticator): + c = Config() + # Test that if we require more scopes than we request, validation fails + c.GenericOAuthenticator.allowed_scopes = ["a", "b"] + c.GenericOAuthenticator.scope = ["a"] + with raises( + ValueError, + match=re.escape( + "Allowed scopes must be a subset of requested scopes. ['a'] is requested but ['a', 'b'] is allowed" + ), + ): + get_authenticator(config=c) + + async def test_generic_callable_username_key(get_authenticator, generic_client): c = Config() c.GenericOAuthenticator.allow_all = True