Skip to content

Commit

Permalink
Refresh token (#12)
Browse files Browse the repository at this point in the history
* Add "expires_at" to token data if applicable

If the expires_in key exists the expires_at will be added. This can be
useful to determine if the token needs refreshing without a failing call
to the api.

The expires_at is an absolute time when the token expires.

fixup

* Add support for refreshing oauth token

The "oauth.refresh" message is now handled and will trigger a token
refresh if possible. The Message data must contain valid "skill_id" and
"app_id" fields.

* gitignore: Add vim cache files
  • Loading branch information
forslund authored Feb 9, 2024
1 parent 4c67985 commit d2eb383
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 2 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,6 @@ dmypy.json

# Pyre type checker
.pyre/

# vim cache files
.*.sw?
66 changes: 64 additions & 2 deletions ovos_PHAL_plugin_oauth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import tempfile
import time
import uuid

import qrcode
Expand Down Expand Up @@ -58,6 +59,11 @@ def oauth_callback(munged_id):
).json()

with OAuthTokenDatabase() as db:
# Make sure expires_at entry exists
if 'expires_at' not in token_response:
token_response['expires_at'] = (
time.time() + token_response['expires_in']
)
db.add_token(munged_id, token_response)

# Allow any registered app / skill to handle the token response urgently, if needed
Expand Down Expand Up @@ -92,6 +98,7 @@ def __init__(self, bus=None, config=None):
self.bus.on("oauth.register", self.handle_oauth_register)
self.bus.on("oauth.start", self.handle_start_oauth)
self.bus.on("oauth.get", self.handle_get_auth_url)
self.bus.on("oauth.refresh", self.handle_oauth_refresh_token)
self.bus.on("ovos.shell.oauth.register.credentials",
self.handle_client_secret)

Expand Down Expand Up @@ -147,7 +154,6 @@ def handle_oauth_register(self, message):
# these fields are app specific and provided by skills
auth_endpoint = message.data.get("auth_endpoint")
token_endpoint = message.data.get("token_endpoint")
refresh_endpoint = message.data.get("refresh_endpoint")
cb_endpoint = f"http://0.0.0.0:{self.port}/auth/callback/{munged_id}"
scope = message.data.get("scope")

Expand All @@ -169,7 +175,7 @@ def handle_oauth_register(self, message):
client_secret=client_secret,
auth_endpoint=auth_endpoint,
token_endpoint=token_endpoint,
refresh_endpoint=refresh_endpoint,
refresh_endpoint=None,
callback_endpoint=cb_endpoint,
scope=scope,
shell_integration=shell_display)
Expand Down Expand Up @@ -198,6 +204,62 @@ def handle_oauth_register(self, message):
"error": e})
self.bus.emit(response)

def handle_oauth_refresh_token(self, message):
"""Refresh oauth token.
See:
https://www.oauth.com/oauth2-servers/making-authenticated-requests/refreshing-an-access-token/
for details on the procedure.
"""
response_data = {}
oauth_id = f"{message.data['skill_id']}_{message.data['app_id']}"
# Load all needed data for refresh
with self.oauth_db as db:
app_data = db.get(oauth_id)
with OAuthTokenDatabase() as db:
token_data = db.get(oauth_id)

if (app_data is None or
token_data is None or 'refresh_token' not in token_data):
LOG.warning("Token data doesn't contain a refresh token and "
"cannot be refreshed.")
response_data["result"] = "Error"
else:
refresh_token = token_data["refresh_token"]

# Fall back to token endpoint if no specific refresh endpoint
# has been set
token_endpoint = app_data["token_endpoint"]

client_id = app_data["client_id"]
client_secret = app_data["client_secret"]

# Perform refresh
client = WebApplicationClient(client_id, refresh_token=refresh_token)
uri, headers, body = client.prepare_refresh_token_request(token_endpoint)
refresh_result = requests.post(uri, headers=headers, data=body,
auth=(client_id, client_secret))

if refresh_result.ok:
new_token_data = refresh_result.json()
# Make sure 'expires_at' entry exists in token
if 'expires_at' not in new_token_data:
new_token_data['expires_at'] = time.time() + token_data['expires_in']
# Store token
with OAuthTokenDatabase() as db:
token_data.update(new_token_data)
db.update_token(oauth_id, token_data)
response_data = {"result": "Ok",
"client_id": client_id,
"token": token_data}
else:
LOG.error("Token refresh failed :(")
response_data["result"] = "Error"

response = message.response(response_data)
self.bus.emit(response)

def get_oauth_url(self, skill_id, app_id):
munged_id = f"{skill_id}_{app_id}" # key for oauth db

Expand Down

0 comments on commit d2eb383

Please sign in to comment.