Patch: optional external rate-limit service for BinderHub.

Adds BinderHub.rate_limit_url / rate_limit_token, moves the per-repo quota
check out of launch() and into the build request handler, splits the
in-process limiter into RequestRateLimiter / RepoRateLimiter, and adds
binderhub/ratelimitapp.py, a small tornado application that serves those
limiters over HTTP so several BinderHub instances can share limit state.

Everything before the first "diff --git" line is commentary and is skipped
by git-apply and patch. Indentation and blank lines were reconstructed from
a whitespace-flattened copy of the patch; hunk line counts are consistent
with the bodies below, but regenerate the patch against the real tree if a
hunk fails to apply. The post-image hashes on the "index" lines predate the
review fixes.

Fixes applied during review:
  * app.py: import RepoRateLimiter, which initialize() instantiates.
  * base.py: send the per-repo quota as {"limit": ...}; that is the key the
    ratelimitapp POST handler reads, so {"quota": ...} was silently dropped.
  * base.py: await _check_rate_limit() inside the async
    check_request_rate_limit(). Returning the bare coroutine meant the
    caller's await produced an un-run coroutine and no check happened.
  * builder.py: pass **self.repo_metric_labels for the "request_quota"
    LAUNCH_COUNT sample, matching the other LAUNCH_COUNT.labels() call.
  * ratelimit.py: rename _defaul_limit to _default_limit.
  * ratelimitapp.py: get_current_user() now rejects tokens that are not in
    the configured token set; previously any well-formed header passed.

NOTE(review): the first launch() hunk deletes the image_no_tag assignment.
The pod-counting loop between the two launch() hunks is not visible in this
patch; confirm it no longer references that name.
NOTE(review): RepoRateLimiter.limit defaults to 0 and both the in-process
and the remote path return early on limit == 0, so a per-repo "quota" has
no effect until RepoRateLimiter.limit is configured. Confirm this is
intended.

diff --git a/binderhub/app.py b/binderhub/app.py
index ed2dbe5a1..0ba0298a2 100644
--- a/binderhub/app.py
+++ b/binderhub/app.py
@@ -47,7 +47,7 @@
 from .health import HealthHandler
 from .launcher import Launcher
 from .log import log_request
-from .ratelimit import RateLimiter
+from .ratelimit import RepoRateLimiter, RequestRateLimiter
 from .repoproviders import RepoProvider
 from .registry import DockerRegistry
 from .main import MainHandler, ParameterizedMainHandler, LegacyRedirectHandler
@@ -329,6 +329,18 @@ def _valid_badge_base_url(self, proposal):
         config=True,
     )
 
+    rate_limit_url = Unicode(
+        config=True,
+        help="""Use external rate-limiter service
+
+        Allows shared rate-limit state across a federation of BinderHub instances
+        """,
+    )
+
+    rate_limit_token = Unicode(
+        config=True, help="""Token used to access external rate limit service"""
+    )
+
     log_tail_lines = Integer(
         100,
         help="""
@@ -788,7 +800,8 @@ def initialize(self, *args, **kwargs):
                 "per_repo_quota": self.per_repo_quota,
                 "per_repo_quota_higher": self.per_repo_quota_higher,
                 "repo_providers": self.repo_providers,
-                "rate_limiter": RateLimiter(parent=self),
+                "rate_limit_url": self.rate_limit_url,
+                "rate_limit_token": self.rate_limit_token,
                 "use_registry": self.use_registry,
                 "build_class": self.build_class,
                 "registry": registry,
@@ -814,6 +827,12 @@ def initialize(self, *args, **kwargs):
                 "normalized_origin": self.normalized_origin,
             }
         )
+        if not self.rate_limit_url:
+            self.tornado_settings["rate_limiters"] = {
+                "request": RequestRateLimiter(parent=self),
+                "repo": RepoRateLimiter(parent=self),
+            }
+
         if self.auth_enabled:
             self.tornado_settings['cookie_secret'] = os.urandom(32)
         if self.cors_allow_origin:
diff --git a/binderhub/base.py b/binderhub/base.py
index 1412ccc75..d3ff61530 100644
--- a/binderhub/base.py
+++ b/binderhub/base.py
@@ -6,12 +6,13 @@
 import jwt
 from http.client import responses
 from tornado import web
+from tornado.httpclient import AsyncHTTPClient, HTTPClientError
 from tornado.log import app_log
 from jupyterhub.services.auth import HubOAuthenticated, HubOAuth
 
 from . import __version__ as binder_version
 from .ratelimit import RateLimitExceeded
-from .utils import ip_in_networks
+from .utils import ip_in_networks, url_path_join
 
 
 class BaseHandler(HubOAuthenticated, web.RequestHandler):
@@ -98,12 +99,51 @@ def check_build_token(self, build_token, provider_spec):
         self._have_build_token = True
         return decoded
 
-    def check_rate_limit(self):
-        rate_limiter = self.settings["rate_limiter"]
-        if rate_limiter.limit == 0:
-            # no limit enabled
-            return
+    async def _check_rate_limit(self, which, key, quota=None):
+
+        if self.settings["rate_limit_url"]:
+            # check with remote rate-limiter service
+            # defined in binderhub/ratelimitapp
+            if quota:
+                body = json.dumps({"limit": quota})
+            else:
+                body = ""
+            try:
+                response = await AsyncHTTPClient().fetch(
+                    url_path_join(self.settings["rate_limit_url"], which, key),
+                    method="POST",
+                    body=body,
+                    headers={
+                        "Authorization": f"Bearer {self.settings['rate_limit_token']}",
+                    },
+                )
+                limit = json.loads(response.body)["limit"]
+            except HTTPClientError as e:
+                if e.code == 429:
+                    # turn remote 429 back into RateLimitExceeded
+                    response = json.loads(e.response.body)
+                    raise RateLimitExceeded(response["message"])
+                else:
+                    app_log.warning(f"Failed to check external rate limit: {e}")
+                    return
+        else:
+            # check with internal rate limiter
+            rate_limiter = self.settings["rate_limiters"][which]
+            if rate_limiter.limit == 0:
+                # no limit enabled
+                return
+
+            limit = rate_limiter.increment(key, quota)
+
+        app_log.debug(f"Rate limit for {which}/{key}: {limit}")
+
+        self.set_header("x-ratelimit-remaining", str(limit["remaining"]))
+        self.set_header("x-ratelimit-reset", str(limit["reset"]))
+        self.set_header("x-ratelimit-limit", str(limit["limit"]))
+        return limit
+
+    async def check_request_rate_limit(self):
 
         if self.settings['auth_enabled'] and self.current_user:
             # authenticated, no limit
             # TODO: separate authenticated limit
@@ -116,19 +156,12 @@ def check_rate_limit(self):
 
         # rate limit is applied per-ip
         request_ip = self.request.remote_ip
-        try:
-            limit = rate_limiter.increment(request_ip)
-        except RateLimitExceeded:
-            raise web.HTTPError(
-                429,
-                f"Rate limit exceeded. Try again in {rate_limiter.period_seconds} seconds.",
-            )
-        else:
-            app_log.debug(f"Rate limit for {request_ip}: {limit}")
+        return await self._check_rate_limit("request", request_ip)
 
-        self.set_header("x-ratelimit-remaining", str(limit["remaining"]))
-        self.set_header("x-ratelimit-reset", str(limit["reset"]))
-        self.set_header("x-ratelimit-limit", str(rate_limiter.limit))
+    def check_repo_rate_limit(self, repo_url, quota):
+        return self._check_rate_limit(
+            "repo", urllib.parse.quote(repo_url, safe=""), quota
+        )
 
     def get_current_user(self):
         if not self.settings['auth_enabled']:
diff --git a/binderhub/builder.py b/binderhub/builder.py
index fb3d24a4f..28207ab7c 100644
--- a/binderhub/builder.py
+++ b/binderhub/builder.py
@@ -14,7 +14,7 @@
 import docker
 from tornado import gen
 from tornado.httpclient import HTTPClientError
-from tornado.web import Finish, authenticated
+from tornado.web import Finish, HTTPError, authenticated
 from tornado.queues import Queue
 from tornado.iostream import StreamClosedError
 from tornado.ioloop import IOLoop
@@ -23,6 +23,7 @@
 
 from .base import BaseHandler
 from .build import ProgressEvent
+from .ratelimit import RateLimitExceeded
 from .utils import KUBE_REQUEST_TIMEOUT
 
 # Separate buckets for builds and launches.
@@ -242,7 +243,6 @@ async def get(self, provider_prefix, _unescaped_spec):
         # verify the build token and rate limit
         build_token = self.get_argument("build_token", None)
         self.check_build_token(build_token, f"{provider_prefix}/{spec}")
-        self.check_rate_limit()
 
         # Verify if the provider is valid for EventSource.
         # EventSource cannot handle HTTP errors, so we must validate and send
@@ -280,6 +280,31 @@ async def get(self, provider_prefix, _unescaped_spec):
             'repo': repo_url,
         }
 
+        # check request (client ip) rate limit
+        try:
+            await self.check_request_rate_limit()
+        except RateLimitExceeded as e:
+            LAUNCH_COUNT.labels(
+                status="request_quota",
+                **self.repo_metric_labels,
+            ).inc()
+            raise HTTPError(429, str(e))
+
+        # check repo rate limit
+        repo_config = provider.repo_config(self.settings)
+        # TODO: put busy users in a queue rather than fail?
+        # That would be hard to do without in-memory state.
+        try:
+            await self.check_repo_rate_limit(repo_url, quota=repo_config.get("quota"))
+        except RateLimitExceeded as e:
+            LAUNCH_COUNT.labels(
+                status="repo_quota",
+                **self.repo_metric_labels,
+            ).inc()
+            app_log.error(str(e))
+            await self.fail(f"Too many users running {self.repo_url}! Try again soon.")
+            return
+
         try:
             ref = await provider.get_resolved_ref()
         except Exception as e:
@@ -511,20 +536,8 @@ async def get(self, provider_prefix, _unescaped_spec):
 
     async def launch(self, provider):
         """Ask JupyterHub to launch the image."""
-        # Load the spec-specific configuration if it has been overridden
-        repo_config = provider.repo_config(self.settings)
-
-        # the image name (without tag) is unique per repo
-        # use this to count the number of pods running with a given repo
-        # if we added annotations/labels with the repo name via KubeSpawner
-        # we could do this better
-        image_no_tag = self.image_name.rsplit(':', 1)[0]
-
-        # TODO: put busy users in a queue rather than fail?
-        # That would be hard to do without in-memory state.
-        repo_quota = repo_config.get("quota")
         pod_quota = self.settings["pod_quota"]
-        if pod_quota is not None or repo_quota:
+        if pod_quota is not None:
             # Fetch info on currently running users *only* if quotas are set
             matching_pods = 0
 
@@ -559,25 +572,12 @@ async def launch(self, provider):
                         matching_pods += 1
                         break
 
-            if repo_quota and matching_pods >= repo_quota:
-                LAUNCH_COUNT.labels(
-                    status="repo_quota",
-                    **self.repo_metric_labels,
-                ).inc()
-                app_log.error(
-                    f"{self.repo_url} has exceeded quota: {matching_pods}/{repo_quota} ({total_pods} total)"
-                )
-                await self.fail(
-                    f"Too many users running {self.repo_url}! Try again soon."
-                )
-                return
-
-            if matching_pods >= 0.5 * repo_quota:
-                log = app_log.warning
-            else:
-                log = app_log.info
-            log("Launching pod for %s: %s other pods running this repo (%s total)",
-                self.repo_url, matching_pods, total_pods)
+            app_log.info(
+                "Launching pod for %s: %s other pods running this repo (%s total)",
+                self.repo_url,
+                matching_pods,
+                total_pods,
+            )
 
         await self.emit({
             'phase': 'launching',
diff --git a/binderhub/ratelimit.py b/binderhub/ratelimit.py
index 5b79a492e..25b087267 100644
--- a/binderhub/ratelimit.py
+++ b/binderhub/ratelimit.py
@@ -4,7 +4,6 @@
 from traitlets import Integer, Dict, Float, default
 from traitlets.config import LoggingConfigurable
 
-
 class RateLimitExceeded(Exception):
     """Exception raised when rate limit is exceeded"""
 
@@ -66,11 +65,17 @@ def time():
         """Mostly here to enable override in tests"""
         return time.time()
 
-    def increment(self, key):
+    def increment(self, key, quota=None):
         """Check rate limit for a key
 
         key: key for recording rate limit. Each key tracks a different rate limit.
-        Returns: {"remaining": int_remaining, "reset": int_timestamp}
+        Returns:
+            {
+                "remaining": int_remaining,
+                "limit": int_total_limit,
+                "reset": int_timestamp,
+                "reset_in": int_seconds_remaining,
+            }
         Raises: RateLimitExceeded if the request would exceed the rate limit.
         """
         now = int(self.time())
@@ -79,17 +84,53 @@ def increment(self, key):
 
         if key not in self._limits or self._limits[key]["reset"] < now:
             # no limit recorded, or reset expired
+            max_limit = quota or self.limit
             self._limits[key] = {
-                "remaining": self.limit,
+                "remaining": max_limit,
+                "limit": max_limit,
                 "reset": now + self.period_seconds,
             }
         limit = self._limits[key]
         # keep decrementing, so we have a track of excess requests
         # which indicate abuse
         limit["remaining"] -= 1
+        reset_in = int(limit["reset"] - now)
         if limit["remaining"] < 0:
-            seconds_until_reset = int(limit["reset"] - now)
             raise RateLimitExceeded(
-                f"Rate limit exceeded (by {-limit['remaining']}) for {key!r}, reset in {seconds_until_reset}s."
+                f"Rate limit exceeded (by {-limit['remaining']}) for {key!r}, reset in {reset_in}s."
             )
-        return limit
+        limit_copy = limit.copy()
+        limit_copy["reset_in"] = reset_in
+        return limit_copy
+
+
+# classes for storing separate config
+
+
+class RequestRateLimiter(RateLimiter):
+    """RateLimiter subclass for client requests
+
+    Rate limit is applied to launch requests by client ip
+    """
+
+
+class RepoRateLimiter(RateLimiter):
+    """RateLimiter subclass for repo launches
+
+    Rate limit is applied to launch requests by repo
+    """
+
+    @default("limit")
+    def _default_limit(self):
+        # default: no limit
+        return 0
+
+
+def main():
+    from .ratelimitapp import main
+
+    main()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/binderhub/ratelimitapp.py b/binderhub/ratelimitapp.py
new file mode 100644
index 000000000..d06de647c
--- /dev/null
+++ b/binderhub/ratelimitapp.py
@@ -0,0 +1,177 @@
+"""Application for external managing of rate limits
+
+Use in combination with BinderHub.rate_limit_url
+and BinderHub.rate_limit_token
+"""
+
+import json
+import logging
+import os
+import re
+from http.client import responses
+
+import tornado.log
+import tornado.options
+from tornado import ioloop, web
+from traitlets import Integer, Set, Unicode, default
+from traitlets.config import Application
+
+from .base import BaseHandler
+from .ratelimit import RateLimitExceeded, RepoRateLimiter, RequestRateLimiter
+
+_auth_header_pat = re.compile(r"^(?:token|bearer)\s+([^\s]+)$", flags=re.IGNORECASE)
+
+
+class RateLimitHandler(BaseHandler):
+    """API endpoint for external storage of rate limits"""
+
+    def initialize(self, tokens, rate_limiters):
+        self.rate_limit_tokens = tokens
+        self.rate_limiters = rate_limiters
+        self.log = tornado.log.app_log
+
+    def get_current_user(self):
+        """Authenticate rate limit requests with tokens"""
+        auth_header = self.request.headers.get("Authorization")
+        if not auth_header:
+            return
+        match = _auth_header_pat.match(auth_header)
+        if not match:
+            return None
+        token = match.group(1)
+        if token not in self.rate_limit_tokens:
+            return None
+        return f"token-{token[:3]}..."
+
+    def set_default_headers(self):
+        super().set_default_headers()
+        self.set_header("Content-Type", "application/json")
+
+    def write_error(self, status_code, **kwargs):
+        exc_info = kwargs.get("exc_info")
+        message = ""
+        status_message = responses.get(status_code, "Unknown HTTP Error")
+        if exc_info:
+            message = self.extract_message(exc_info)
+        if not message:
+            message = status_message
+        self.write(json.dumps({"message": message}))
+
+    @web.authenticated
+    def post(self, which, key):
+        """Increment rate limit of kind `which` for key `key`"""
+        initial_limit = None
+        if self.request.body:
+            try:
+                body = json.loads(self.request.body)
+                initial_limit = body.get("limit", None)
+            except Exception:
+                raise web.HTTPError(
+                    400,
+                    f"Rate limit body must be a dict with a 'limit' key, got {self.request.body}",
+                )
+            if not (initial_limit is None or isinstance(initial_limit, int)):
+                raise web.HTTPError(
+                    400, f"limit must be null or a number, not {initial_limit}"
+                )
+
+        try:
+            rate_limiter = self.rate_limiters[which]
+        except KeyError:
+            raise web.HTTPError(404, f"No such rate limit: {which}")
+
+        if rate_limiter.limit == 0:
+            # no limit
+            self.write(
+                json.dumps(
+                    {
+                        "limit": {
+                            "limit": 0,
+                            "remaining": 0,
+                            "reset": 0,
+                            "reset_in": 0,
+                        }
+                    }
+                )
+            )
+            return
+
+        try:
+            limit = rate_limiter.increment(key, initial_limit)
+        except RateLimitExceeded as e:
+            raise web.HTTPError(429, str(e))
+
+        self.log.debug(f"Rate limit for {which}/{key}: {limit}")
+
+        self.write(json.dumps({"limit": limit}))
+
+
+aliases = {}
+aliases.update(Application.aliases)
+aliases.update(
+    {
+        "ip": "RateLimitApp.ip",
+        "port": "RateLimitApp.port",
+    }
+)
+
+
+class RateLimitApp(Application):
+    # load the same config files as binderhub itself
+    name = "binderhub"
+
+    aliases = aliases
+    classes = [RepoRateLimiter, RequestRateLimiter]
+
+    ip = Unicode("", config=True)
+    port = Integer(8888, config=True)
+
+    tokens = Set(config=True)
+
+    @default("tokens")
+    def _default_tokens(self):
+        tokens = set(os.environ.get("RATE_LIMIT_TOKENS", "").strip().split(";"))
+        tokens.discard("")
+        return tokens
+
+    def initialize(self, argv=None):
+        super().initialize(argv)
+        # hook up tornado logging
+        tornado.options.options.logging = logging.getLevelName(self.log_level)
+        tornado.log.enable_pretty_logging()
+        self.log = tornado.log.app_log
+
+        self.rate_limiters = {
+            "repo": RepoRateLimiter(parent=self),
+            "request": RequestRateLimiter(parent=self),
+        }
+
+    def start(self, run_loop=True):
+        if not self.tokens:
+            self.exit("Need to set one of $RATE_LIMIT_TOKENS or c.RateLimitApp.tokens.")
+        web_app = web.Application(
+            [
+                (
+                    "/([^/]+)/(.+)",
+                    RateLimitHandler,
+                    {
+                        "tokens": self.tokens,
+                        "rate_limiters": self.rate_limiters,
+                    },
+                )
+            ]
+        )
+        self.http_server = web.HTTPServer(
+            web_app,
+            xheaders=True,
+        )
+        self.log.info(f"Rate limiter listening on {self.ip}:{self.port}")
+        self.http_server.listen(self.port, self.ip)
+        if run_loop:
+            ioloop.IOLoop.current().start()
+
+
+main = RateLimitApp.launch_instance
+
+if __name__ == "__main__":
+    main()