Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement oauthenticator.OAuthenticator.refresh_user method #579

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
82 changes: 78 additions & 4 deletions oauthenticator/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import base64
import json
import os
import time
import uuid
from functools import reduce
from inspect import isawaitable
Expand Down Expand Up @@ -695,6 +696,16 @@ def _client_secret_default(self):
return client_secret
return os.getenv("OAUTH_CLIENT_SECRET", "")

access_token_expiration = Unicode(
config=True,
default="3600",
help="""
If the `expires_in` field is omitted in the OAuth 2.0 token response
then this value will be the default expiration, in seconds, of the
access token.
""",
)

validate_server_cert_env = "OAUTH_TLS_VERIFY"
validate_server_cert = Bool(
config=True,
Expand Down Expand Up @@ -992,12 +1003,35 @@ def build_access_tokens_request_params(self, handler, data=None):

return params

def build_refresh_token_request_params(self, refresh_token):
"""
Builds the parameters that should be passed to the URL request
to renew the Access Token based on the Refresh Token

Called by the :meth:`oauthenticator.OAuthenticator.refresh_user`.
"""
params = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to include 'scope' here too - the same one used on initial login, to ensure you get back an equivalent id_token/access_token

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

apologies, looks like this is only required on the original authorization code request. the scopes for the token_info and refresh_token requests do not affect the responses

}

# the client_id and client_secret should not be included in the access token request params
# when basic authentication is used
# ref: https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1
if not self.basic_auth:
params.update(
[("client_id", self.client_id), ("client_secret", self.client_secret)]
)

return params

async def get_token_info(self, handler, params):
"""
Makes a "POST" request to `self.token_url`, with the parameters received as argument.

Returns:
the JSON response to the `token_url` the request.
the JSON response to the `token_url` the request as described in
https://www.rfc-editor.org/rfc/rfc6749#section-5.1

Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
"""
Expand Down Expand Up @@ -1084,7 +1118,7 @@ async def token_to_user(self, token_info):

def build_auth_state_dict(self, token_info, user_info):
"""
Builds the `auth_state` dict that will be returned by a succesfull `authenticate` method call.
Builds the `auth_state` dict that will be returned by a successful `authenticate` method call.
May be async (requires oauthenticator >= 17.0).

Args:
Expand All @@ -1094,6 +1128,8 @@ def build_auth_state_dict(self, token_info, user_info):
Returns:
auth_state: a dictionary of auth state that should be persisted with the following keys:
- "access_token": the access_token
- "created_at": creation date, in seconds, of the access_token
- "expires_in": expiration date, in seconds, of the access_token
- "refresh_token": the refresh_token, if available
- "id_token": the id_token, if available
- "scope": the scopes, if available
Expand All @@ -1106,8 +1142,10 @@ def build_auth_state_dict(self, token_info, user_info):
This method be async.
"""

# We know for sure the `access_token` key exists, oterwise we would have errored out already
# We know for sure the `access_token` key exists, otherwise we would have errored out already
access_token = token_info["access_token"]
created_at = token_info.get("created_at", time.time())
expires_in = token_info.get("expires_in", self.access_token_expiration)

refresh_token = token_info.get("refresh_token", None)
id_token = token_info.get("id_token", None)
Expand All @@ -1118,6 +1156,8 @@ def build_auth_state_dict(self, token_info, user_info):

return {
"access_token": access_token,
"created_at": created_at,
"expires_in": expires_in,
"refresh_token": refresh_token,
"id_token": id_token,
"scope": scope,
Expand Down Expand Up @@ -1212,8 +1252,42 @@ async def authenticate(self, handler, data=None, **kwargs):
"""
# build the parameters to be used in the request exchanging the oauth code for the access token
access_token_params = self.build_access_tokens_request_params(handler, data)
# call the oauth endpoints
return await self._oauth_call(handler, access_token_params, **kwargs)

async def refresh_user(self, user, handler=None, **kwargs):
"""
Renew the Access Token with a valid Refresh Token
"""

auth_state = await user.get_auth_state()
if not auth_state:
self.log.info(
f"No auth_state found for user {user} refresh, need full authentication",
)
return False

created_at = auth_state.get('created_at', 0)
expires_in = auth_state.get('expires_in', 0)
is_expired = created_at + expires_in - time.time() < 0
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May need to refresh before delta reaches 0, to avoid disruption while working.

if not is_expired:
self.log.info(
f"The access_token is still valid for user {user}, skipping refresh",
)
return True

refresh_token_params = self.build_refresh_token_request_params(
auth_state['refresh_token']
)
return await self._oauth_call(handler, refresh_token_params, **kwargs)

async def _oauth_call(self, handler, params, **kwargs):
"""
Common logic shared by authenticate() and refresh_user()
"""

# exchange the oauth code for an access token and get the JSON with info about it
token_info = await self.get_token_info(handler, access_token_params)
token_info = await self.get_token_info(handler, params)
# use the access_token to get userdata info
user_info = await self.token_to_user(token_info)
# extract the username out of the user_info dict and normalize it
Expand Down