Skip to content

Commit

Permalink
Merge pull request jupyterhub#4608 from minrk/oauth-state-cookie
Browse files Browse the repository at this point in the history
move service oauth state from cookies to memory
  • Loading branch information
consideRatio authored Jan 25, 2024
2 parents 41a2e29 + 2c7fe93 commit 041acbc
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 70 deletions.
192 changes: 143 additions & 49 deletions jupyterhub/services/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,15 @@
"""
import asyncio
import base64
import hashlib
import json
import os
import random
import re
import secrets
import socket
import string
import time
import uuid
import warnings
from http import HTTPStatus
from unittest import mock
Expand Down Expand Up @@ -106,14 +105,24 @@ class _ExpiringDict(dict):
"""

max_age = 0
purge_interval = 0

def __init__(self, max_age=0):
def __init__(self, max_age=0, purge_interval="max_age"):
self.max_age = max_age
if purge_interval == "max_age":
# default behavior: use max_age
purge_interval = max_age
self.purge_interval = purge_interval
self.timestamps = {}
self.values = {}
self._last_purge = time.monotonic()

def __len__(self):
return len(self.values)

def __setitem__(self, key, value):
"""Store key and record timestamp"""
self._maybe_purge()
self.timestamps[key] = time.monotonic()
self.values[key] = value

Expand All @@ -139,6 +148,7 @@ def _check_age(self, key):
if self.max_age > 0 and timestamp + self.max_age < now:
self.values.pop(key)
self.timestamps.pop(key)
self._maybe_purge()

def __contains__(self, key):
"""dict check for `key in dict`"""
Expand All @@ -150,17 +160,57 @@ def __getitem__(self, key):
self._check_age(key)
return self.values[key]

def __delitem__(self, key):
del self.values[key]
del self.timestamps[key]

def get(self, key, default=None):
"""dict-like get:"""
"""dict-like get"""
try:
return self[key]
except KeyError:
return default

def pop(self, key, default="_raise"):
"""Remove and return an item"""
if key in self:
value = self.values.pop(key)
del self.timestamps[key]
return value
else:
if default == "_raise":
raise KeyError(key)
else:
return default

def clear(self):
"""Clear the cache"""
self.values.clear()
self.timestamps.clear()
self._last_purge = time.monotonic()

# extended methods
def _maybe_purge(self):
"""purge expired values _if_ it's been purge_interval since the last purge
Called on every get/set, to keep the expired values clear.
"""
if not self.purge_interval > 0:
return
now = time.monotonic()
if self._last_purge < (now - self.purge_interval):
self.purge_expired()

def purge_expired(self):
"""Purge all expired values"""
if not self.max_age > 0:
return
now = self._last_purge = time.monotonic()
cutoff = now - self.max_age
for key in list(self.timestamps):
timestamp = self.timestamps[key]
if timestamp < cutoff:
del self[key]


class HubAuth(SingletonConfigurable):
Expand Down Expand Up @@ -854,37 +904,32 @@ async def _token_for_code(self, code):

return token_reply['access_token']

def _encode_state(self, state):
"""Encode a state dict as url-safe base64"""
# trim trailing `=` because = is itself not url-safe!
json_state = json.dumps(state)
return (
base64.urlsafe_b64encode(json_state.encode('utf8'))
.decode('ascii')
.rstrip('=')
)
# state-related

def _decode_state(self, b64_state):
"""Decode a base64 state
oauth_state_max_age = Integer(
600,
config=True,
help="""Max age (seconds) of oauth state.
Governs both oauth state cookie Max-Age,
as well as the in-memory _oauth_states cache.
""",
)
_oauth_states = Instance(
_ExpiringDict,
allow_none=False,
help="""
Store oauth state info for each oauth request, such as next_url
The oauth state field only contains the oauth state _id_ of each pending request,
while other information such as the cookie name, next_url
are stored in this dictionary.
""",
)

Always returns a dict.
The dict will be empty if the state is invalid.
"""
if isinstance(b64_state, str):
b64_state = b64_state.encode('ascii')
if len(b64_state) != 4:
# restore padding
b64_state = b64_state + (b'=' * (4 - len(b64_state) % 4))
try:
json_state = base64.urlsafe_b64decode(b64_state).decode('utf8')
except ValueError:
app_log.error("Failed to b64-decode state: %r", b64_state)
return {}
try:
return json.loads(json_state)
except ValueError:
app_log.error("Failed to json-decode state: %r", json_state)
return {}
@default('_oauth_states')
def _default_oauth_states(self):
return _ExpiringDict(max_age=self.oauth_state_max_age)

def set_state_cookie(self, handler, next_url=None):
"""Generate an OAuth state and store it in a cookie
Expand Down Expand Up @@ -914,7 +959,7 @@ def set_state_cookie(self, handler, next_url=None):
extra_state['cookie_name'] = cookie_name
else:
cookie_name = self.state_cookie_name
b64_state = self.generate_state(next_url, **extra_state)
state_id = self.generate_state(next_url, **extra_state)
kwargs = {
'path': self.base_url,
'httponly': True,
Expand All @@ -931,39 +976,69 @@ def set_state_cookie(self, handler, next_url=None):
else:
if get_browser_protocol(handler.request) == 'https':
kwargs['secure'] = True

# don't allow overriding some fields
no_override_keys = set(kwargs.keys()) | {"expires_days", "expires"}

# load user cookie overrides
kwargs.update(self.cookie_options)
handler.set_secure_cookie(cookie_name, b64_state, **kwargs)
return b64_state
for key, value in self.cookie_options.items():
# don't include overrides
if key.lower() not in no_override_keys:
kwargs[key] = value
handler.set_secure_cookie(cookie_name, state_id, **kwargs)
return state_id

def generate_state(self, next_url=None, **extra_state):
"""Generate a state string, given a next_url redirect target
The state info is stored locally in self._oauth_states,
and only the state id is returned for use in the oauth state field (cookie, redirect param)
Parameters
----------
next_url : str
The URL of the page to redirect to on successful login.
Returns
-------
state (str): The base64-encoded state string.
state_id (str): The state string to be used as a cookie value.
"""
state = {'uuid': uuid.uuid4().hex, 'next_url': next_url}
state.update(extra_state)
return self._encode_state(state)
state_id = secrets.token_urlsafe(16)
state = {'next_url': next_url}
if extra_state:
state.update(extra_state)
self._oauth_states[state_id] = state
return state_id

def clear_oauth_state(self, state_id):
"""Clear persisted oauth state"""
self._oauth_states.pop(state_id, None)
self._oauth_states.purge_expired()

def clear_oauth_state_cookies(self, handler):
"""Clear persisted oauth state"""
for cookie_name, cookie in handler.request.cookies.items():
if cookie_name.startswith(self.state_cookie_name):
handler.clear_cookie(
cookie_name,
path=self.base_url,
)

def get_next_url(self, b64_state=''):
def _decode_state(self, state_id, /):
return self._oauth_states.get(state_id, {})

def get_next_url(self, state_id='', /):
"""Get the next_url for redirection, given an encoded OAuth state"""
state = self._decode_state(b64_state)
state = self._decode_state(state_id)
return state.get('next_url') or self.base_url

def get_state_cookie_name(self, b64_state=''):
def get_state_cookie_name(self, state_id='', /):
"""Get the cookie name for oauth state, given an encoded OAuth state
Cookie name is stored in the state itself because the cookie name
is randomized to deal with races between concurrent oauth sequences.
"""
state = self._decode_state(b64_state)
state = self._decode_state(state_id)
return state.get('cookie_name') or self.state_cookie_name

def set_cookie(self, handler, access_token):
Expand Down Expand Up @@ -1246,23 +1321,42 @@ async def get(self):

code = self.get_argument("code", False)
if not code:
raise HTTPError(400, "oauth callback made without a token")
raise HTTPError(400, "OAuth callback made without a token")

# validate OAuth state
arg_state = self.get_argument("state", None)
if arg_state is None:
raise HTTPError(500, "oauth state is missing. Try logging in again.")
raise HTTPError(400, "OAuth state is missing. Try logging in again.")
cookie_name = self.hub_auth.get_state_cookie_name(arg_state)
cookie_state = self.get_secure_cookie(cookie_name)
# clear cookie state now that we've consumed it
self.clear_cookie(cookie_name, path=self.hub_auth.base_url)
if cookie_state:
self.clear_cookie(cookie_name, path=self.hub_auth.base_url)
else:
# completing oauth with stale state, but already logged in.
# stop here and redirect to default URL
# don't complete oauth (no new token), but do complete redirecting to the destination
if self.current_user:
app_log.warning("Attempting oauth completion after already logging in.")
self.hub_auth.clear_oauth_state_cookies(self)
next_url = self.hub_auth.get_next_url(arg_state)
self.redirect(next_url)
return

if isinstance(cookie_state, bytes):
cookie_state = cookie_state.decode('ascii', 'replace')

# check that state matches
if arg_state != cookie_state:
app_log.warning("oauth state %r != %r", arg_state, cookie_state)
raise HTTPError(403, "oauth state does not match. Try logging in again.")
raise HTTPError(403, "OAuth state does not match. Try logging in again.")
next_url = self.hub_auth.get_next_url(cookie_state)
# clear consumed state from _oauth_states cache now that we're done with it
self.hub_auth.clear_oauth_state(cookie_state)
# clear _all_ oauth state cookies on success
# This prevents multiple concurrent logins in the same browser,
# which is probably okay.
self.hub_auth.clear_oauth_state_cookies(self)

token = await self.hub_auth.token_for_code(code, sync=False)
session_id = self.hub_auth.get_session_id(self)
Expand Down
Loading

0 comments on commit 041acbc

Please sign in to comment.