Skip to content

Commit

Permalink
Merge pull request #1 from adriendelsalle/user-credentials
Browse files Browse the repository at this point in the history
Pass user model to RepoProvider
  • Loading branch information
rprimet authored Feb 23, 2021
2 parents 2c09d72 + 44ef87f commit 807a300
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 6 deletions.
43 changes: 39 additions & 4 deletions binderhub/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""Base classes for request handlers"""

import json
import os

from http.client import responses
from tornado import web
from tornado.log import app_log
from tornado.httpclient import AsyncHTTPClient, HTTPRequest, HTTPError

from jupyterhub.services.auth import HubOAuthenticated, HubOAuth

from . import __version__ as binder_version
Expand All @@ -16,6 +20,7 @@ def initialize(self):
super().initialize()
if self.settings['auth_enabled']:
self.hub_auth = HubOAuth.instance(config=self.settings['traitlets_config'])
self.current_user_model = None

def get_current_user(self):
if not self.settings['auth_enabled']:
Expand Down Expand Up @@ -43,16 +48,46 @@ def get_spec_from_request(self, prefix):
spec = self.request.path[idx + len(prefix) + 1:]
return spec

def get_provider(self, provider_prefix, spec):
async def get_provider(self, provider_prefix, spec):
"""Construct a provider object"""
providers = self.settings['repo_providers']
if provider_prefix not in providers:
raise web.HTTPError(404, "No provider found for prefix %s" % provider_prefix)

async def api_request(url, *args, **kwargs):
headers = kwargs.setdefault('headers', {})
headers.update({'Authorization': 'token %s' % self.hub_auth.api_token})
hub_api_url = os.getenv('JUPYTERHUB_API_URL', '') or self.hub_auth.api_url
request_url = hub_api_url + url
req = HTTPRequest(request_url, *args, **kwargs)

try:
return await AsyncHTTPClient().fetch(req)
except HTTPError as e:
app_log.error("Error accessing Hub API (using %s): %s", request_url, e)

async def get_current_user_model():
"""Get the current user model.
The user auth_state is only accessible to admin users.
"""
if not self.settings['auth_enabled']:
return None

if self.current_user_model is None:
username = self.get_current_user()['name']
resp = await api_request(
f'/users/{username}',
method='GET',
)
self.current_user_model = json.loads(resp.body.decode('utf-8'))

return self.current_user_model

return providers[provider_prefix](
config=self.settings['traitlets_config'],
spec=spec,
handler=self)
config=self.settings['traitlets_config'],
spec=spec,
user_model=await get_current_user_model()
)

def get_badge_base_url(self):
badge_base_url = self.settings['badge_base_url']
Expand Down
2 changes: 1 addition & 1 deletion binderhub/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ async def get(self, provider_prefix, _unescaped_spec):

# get a provider object that encapsulates the provider and the spec
try:
provider = self.get_provider(provider_prefix, spec=spec)
provider = await self.get_provider(provider_prefix, spec=spec)
except Exception as e:
app_log.exception("Failed to get provider for %s", key)
await self.fail(str(e))
Expand Down
2 changes: 1 addition & 1 deletion binderhub/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def get(self, provider_prefix, _unescaped_spec):
spec = self.get_spec_from_request(prefix)
spec = spec.rstrip("/")
try:
self.get_provider(provider_prefix, spec=spec)
await self.get_provider(provider_prefix, spec=spec)
except HTTPError:
raise
except Exception as e:
Expand Down

0 comments on commit 807a300

Please sign in to comment.