diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index bdf9f285..626ec092 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -1054,12 +1054,34 @@ def build_access_tokens_request_params(self, handler, data=None): return params + def build_refresh_token_request_params(self, refresh_token): + """ + Builds the parameters that should be passed to the URL request + to renew the Access Token based on the Refresh Token + + Called by the :meth:`oauthenticator.OAuthenticator.refresh_user`. + """ + params = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + } + + # the client_id and client_secret should not be included in the access token request params + # when basic authentication is used + # ref: https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1 + if not self.basic_auth: + params["client_id"] = self.client_id + params["client_secret"] = self.client_secret + + return params + async def get_token_info(self, handler, params): """ Makes a "POST" request to `self.token_url`, with the parameters received as argument. Returns: - the JSON response to the `token_url` the request. + the JSON response to the `token_url` the request as described in + https://www.rfc-editor.org/rfc/rfc6749#section-5.1 Called by the :meth:`oauthenticator.OAuthenticator.authenticate` """ @@ -1146,7 +1168,7 @@ async def token_to_user(self, token_info): def build_auth_state_dict(self, token_info, user_info): """ - Builds the `auth_state` dict that will be returned by a succesfull `authenticate` method call. + Builds the `auth_state` dict that will be returned by a successful `authenticate` method call. May be async (requires oauthenticator >= 17.0). Args: @@ -1168,7 +1190,7 @@ def build_auth_state_dict(self, token_info, user_info): This method may be async. """ - # We know for sure the `access_token` key exists, oterwise we would have errored out already + # We know for sure the `access_token` key exists, otherwise we would have errored out already access_token = token_info["access_token"] refresh_token = token_info.get("refresh_token", None) @@ -1288,24 +1310,84 @@ async def authenticate(self, handler, data=None, **kwargs): """ # build the parameters to be used in the request exchanging the oauth code for the access token access_token_params = self.build_access_tokens_request_params(handler, data) - # exchange the oauth code for an access token and get the JSON with info about it token_info = await self.get_token_info(handler, access_token_params) + # call the oauth endpoints + return await self._token_to_auth_model(token_info) + + async def refresh_user(self, user, handler=None, **kwargs): + """ + Renew the Access Token with a valid Refresh Token + """ + if not self.enable_auth_state: + # auth state not enabled, can't refresh + self.log.debug("auth_state disabled, no auth state to refresh") + return True + auth_state = await user.get_auth_state() + if not auth_state: + self.log.info( + f"No auth_state found for user {user.name} refresh, need full authentication", + ) + return False + + token_info = auth_state.get("token_response") + auth_model = None + try: + auth_model = await self._token_to_auth_model(token_info) + except HTTPClientError as e: + # assume any client error means an expired token + # most likely 401 or 403 for well-behaved providers + if 400 <= e.code < 500: + self.log.info( + f"Error refreshing auth with current access_token for {user.name}: {e}. Will try to refresh, if possible." + ) + else: + raise + refresh_token = auth_state.get("refresh_token", None) + if refresh_token and not auth_model: + self.log.info(f"Refreshing oauth access token for {user.name}") + # access_token expired, try refreshing with refresh_token + refresh_token_params = self.build_refresh_token_request_params( + refresh_token + ) + try: + token_info = await self.get_token_info(handler, refresh_token_params) + except Exception as e: + self.log.info( + f"Error using refresh_token for {user.name}: {e}. Requiring fresh login." + ) + return False + else: + self.log.debug( + f"Received fresh access_token for {user.name} via refresh_token" + ) + # refresh_token may not be returned when refreshing a token + # in which case, keep the current one + if not token_info.get("refresh_token"): + token_info["refresh_token"] = refresh_token + try: + auth_model = await self._token_to_auth_model(token_info) + except Exception as e: + # this means we were issued a fresh access token, + # but it didn't work! Fail harder? + self.log.error( + f"Error refreshing auth with fresh access_token for {user.name}: {e}. Requiring fresh login." + ) + return False + + # return False if auth_model is None for "needs new login" + return auth_model or False + + async def _token_to_auth_model(self, token_info): + """ + Common logic shared by authenticate() and refresh_user() + """ + # use the access_token to get userdata info user_info = await self.token_to_user(token_info) # extract the username out of the user_info dict and normalize it username = self.user_info_to_username(user_info) username = self.normalize_username(username) - # check if there any refresh_token in the token_info dict - refresh_token = token_info.get("refresh_token", None) - if self.enable_auth_state and not refresh_token: - self.log.debug( - "Refresh token was empty, will try to pull refresh_token from previous auth_state" - ) - refresh_token = await self.get_prev_refresh_token(handler, username) - if refresh_token: - token_info["refresh_token"] = refresh_token - auth_state = self.build_auth_state_dict(token_info, user_info) if isawaitable(auth_state): auth_state = await auth_state diff --git a/oauthenticator/tests/mocks.py b/oauthenticator/tests/mocks.py index d9174604..5186f5ef 100644 --- a/oauthenticator/tests/mocks.py +++ b/oauthenticator/tests/mocks.py @@ -107,6 +107,7 @@ def setup_oauth_mock( user_path=None, token_type='Bearer', token_request_style='post', + enable_refresh_tokens=False, scope="", ): """setup the mock client for OAuth @@ -134,6 +135,8 @@ def setup_oauth_mock( client.oauth_codes = oauth_codes = {} client.access_tokens = access_tokens = {} + client.refresh_tokens = refresh_tokens = {} + client.enable_refresh_tokens = enable_refresh_tokens def access_token(request): """Handler for access token endpoint @@ -146,26 +149,53 @@ def access_token(request): if not query: query = request.body.decode('utf8') query = parse_qs(query) - if 'code' not in query: + grant_type = query.get("grant_type", [""])[0] + if grant_type == 'authorization_code': + if 'code' not in query: + return HTTPResponse( + request=request, + code=400, + reason=f"No code in access token request: url={request.url}, body={request.body}", + ) + code = query['code'][0] + if code not in oauth_codes: + return HTTPResponse( + request=request, code=403, reason=f"No such code: {code}" + ) + user = oauth_codes.pop(code) + elif grant_type == 'refresh_token': + if 'refresh_token' not in query: + return HTTPResponse( + request=request, + code=400, + reason=f"No refresh_token in access token request: url={request.url}, body={request.body}", + ) + refresh_token = query['refresh_token'][0] + if refresh_token not in refresh_token: + return HTTPResponse( + request=request, + code=403, + reason=f"No such refresh_toekn: {refresh_token}", + ) + user = refresh_tokens[refresh_token] + else: return HTTPResponse( request=request, code=400, - reason=f"No code in access token request: url={request.url}, body={request.body}", - ) - code = query['code'][0] - if code not in oauth_codes: - return HTTPResponse( - request=request, code=403, reason=f"No such code: {code}" + reason=f"Invalid grant_type={grant_type}: url={request.url}, body={request.body}", ) # consume code, allocate token - token = uuid.uuid4().hex - user = oauth_codes.pop(code) - access_tokens[token] = user + access_token = uuid.uuid4().hex + access_tokens[access_token] = user model = { - 'access_token': token, + 'access_token': access_token, 'token_type': token_type, } + if client.enable_refresh_tokens: + refresh_token = uuid.uuid4().hex + refresh_tokens[refresh_token] = user + model['refresh_token'] = refresh_token if scope: model['scope'] = scope if 'id_token' in user: diff --git a/oauthenticator/tests/test_generic.py b/oauthenticator/tests/test_generic.py index e3225039..e5c99602 100644 --- a/oauthenticator/tests/test_generic.py +++ b/oauthenticator/tests/test_generic.py @@ -15,14 +15,15 @@ def user_model(username, **kwargs): """Return a user model""" - return { + model = { "username": username, "aud": client_id, "sub": "oauth2|cilogon|http://cilogon.org/servera/users/43431", "scope": "basic", "groups": ["group1"], - **kwargs, } + model.update(kwargs) + return model @fixture(params=["id_token", "userdata_url"]) @@ -505,6 +506,90 @@ async def test_check_allowed_no_auth_state(get_authenticator, name, allowed): assert await authenticator.check_allowed(name, None) +class MockUser: + """Mock subset of JupyterHub User API from the `auth_model` dict""" + + name: str + + def __init__(self, auth_model): + self._auth_model = auth_model + self.name = auth_model["name"] + + async def get_auth_state(self): + return self._auth_model["auth_state"] + + +@mark.parametrize("enable_refresh_tokens", [True, False]) +async def test_refresh_user(get_authenticator, generic_client, enable_refresh_tokens): + generic_client.enable_refresh_tokens = enable_refresh_tokens + authenticator = get_authenticator(allowed_users={"user1"}) + authenticator.manage_groups = True + authenticator.auth_state_groups_key = "oauth_user.groups" + oauth_userinfo = user_model("user1", groups=["round1"]) + handler = generic_client.handler_for_user(oauth_userinfo) + auth_model = await authenticator.get_authenticated_user(handler, None) + auth_state = auth_model["auth_state"] + assert auth_model["groups"] == ["round1"] + if enable_refresh_tokens: + assert "refresh_token" in auth_state + assert "refresh_token" in auth_state["token_response"] + assert ( + auth_state["refresh_token"] == auth_state["token_response"]["refresh_token"] + ) + else: + assert "refresh_token" not in auth_state["token_response"] + assert auth_state.get("refresh_token") is None + user = MockUser(auth_model) + # case: auth_state not enabled, nothing to refresh + refreshed = await authenticator.refresh_user(user, handler) + assert refreshed is True + + # from here on, enable auth state required for refresh to do anything + authenticator.enable_auth_state = True + + # case: no auth state, but auth state enabled needs refresh + auth_without_state = auth_model.copy() + auth_without_state["auth_state"] = None + user_without_state = MockUser(auth_without_state) + refreshed = await authenticator.refresh_user(user_without_state, handler) + assert refreshed is False + + # case: actually refresh + oauth_userinfo["groups"] = ["refreshed"] + refreshed = await authenticator.refresh_user(user, handler) + assert refreshed + assert refreshed["name"] == auth_model["name"] + assert refreshed["groups"] == ["refreshed"] + refreshed_state = refreshed["auth_state"] + assert "access_token" in refreshed_state + # refresh with access token succeeds, keeps tokens unchanged + assert refreshed_state.get("refresh_token") == auth_state.get("refresh_token") + assert refreshed_state["access_token"] == auth_state["access_token"] + + # case: access token is no longer valid, triggers refresh + oauth_userinfo["groups"] = ["token_refreshed"] + generic_client.access_tokens.pop(refreshed_state["access_token"]) + refreshed = await authenticator.refresh_user(user, handler) + if enable_refresh_tokens: + # access_token refreshed + assert refreshed + refreshed_state = refreshed["auth_state"] + assert ( + refreshed_state["access_token"] != auth_model["auth_state"]["access_token"] + ) + assert refreshed["groups"] == ["token_refreshed"] + else: + assert refreshed is False + + if enable_refresh_tokens: + # case: token used for refresh is no longer valid + user = MockUser(refreshed) + generic_client.access_tokens.pop(refreshed_state["access_token"]) + generic_client.refresh_tokens.pop(refreshed_state["refresh_token"]) + refreshed = await authenticator.refresh_user(user, handler) + assert refreshed is False + + @mark.parametrize( "test_variation_id,class_config,expect_config,expect_loglevel,expect_message", [ diff --git a/oauthenticator/tests/test_github.py b/oauthenticator/tests/test_github.py index e49fe064..1ca5f590 100644 --- a/oauthenticator/tests/test_github.py +++ b/oauthenticator/tests/test_github.py @@ -141,7 +141,7 @@ async def test_github( assert user_info == handled_user_model assert auth_model["name"] == user_info[authenticator.username_claim] else: - assert auth_model == None + assert auth_model is None def make_link_header(urlinfo, page):