Skip to content

Commit

Permalink
Merge pull request #160 from colin99d/master
Browse files Browse the repository at this point in the history
feat: allow Requests to be sent to exempt_when
  • Loading branch information
laurentS authored Jun 27, 2024
2 parents 7769a13 + 42330fc commit a72bcc6
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 9 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Change Log

## [0.1.10] - 2024-06-04

### Changed

- Breaking change: allow usage of the request object in the except_when function (thanks @colin99d)

## [0.1.9] - 2024-02-05

### Added
Expand Down
13 changes: 7 additions & 6 deletions slowapi/extension.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
The starlette extension to rate-limit requests
"""

import asyncio
import functools
import inspect
Expand Down Expand Up @@ -486,7 +487,7 @@ def __evaluate_limits(
limit_for_header = None
for lim in limits:
limit_scope = lim.scope or endpoint
if lim.is_exempt:
if lim.is_exempt(request):
continue
if lim.methods is not None and request.method.lower() not in lim.methods:
continue
Expand Down Expand Up @@ -703,11 +704,9 @@ def decorator(func: Callable[..., Response]):
else:
self._route_limits.setdefault(name, []).extend(static_limits)

connection_type: Optional[str] = None
sig = inspect.signature(func)
for idx, parameter in enumerate(sig.parameters.values()):
if parameter.name == "request" or parameter.name == "websocket":
connection_type = parameter.name
break
else:
raise Exception(
Expand Down Expand Up @@ -736,7 +735,8 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Response:
if not isinstance(response, Response):
# get the response object from the decorated endpoint function
self._inject_headers(
kwargs.get("response"), request.state.view_rate_limit # type: ignore
kwargs.get("response"), # type: ignore
request.state.view_rate_limit,
)
else:
self._inject_headers(
Expand Down Expand Up @@ -768,7 +768,8 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Response:
if not isinstance(response, Response):
# get the response object from the decorated endpoint function
self._inject_headers(
kwargs.get("response"), request.state.view_rate_limit # type: ignore
kwargs.get("response"),
request.state.view_rate_limit, # type: ignore
)
else:
self._inject_headers(
Expand Down Expand Up @@ -805,7 +806,7 @@ def limit(
* **error_message**: string (or callable that returns one) to override the
error message used in the response.
* **exempt_when**: function returning a boolean indicating whether to exempt
the route from the limit
the route from the limit. This function can optionally use a Request object.
* **cost**: integer (or callable that returns one) which is the cost of a hit
* **override_defaults**: whether to override the default limits (default: True)
"""
Expand Down
18 changes: 15 additions & 3 deletions slowapi/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Callable, Iterator, List, Optional, Union

from limits import RateLimitItem, parse_many # type: ignore
from starlette.requests import Request


class Limit(object):
Expand All @@ -28,16 +29,27 @@ def __init__(
self.methods = methods
self.error_message = error_message
self.exempt_when = exempt_when
self._exempt_when_takes_request = (
self.exempt_when
and len(inspect.signature(self.exempt_when).parameters) == 1
)
self.cost = cost
self.override_defaults = override_defaults

@property
def is_exempt(self) -> bool:
def is_exempt(self, request: Optional[Request] = None) -> bool:
"""
Check if the limit is exempt.
** parameter **
* **request**: the request object
Return True to exempt the route from the limit.
"""
return self.exempt_when() if self.exempt_when is not None else False
if self.exempt_when is None:
return False
if self._exempt_when_takes_request and request:
return self.exempt_when(request)
return self.exempt_when()

@property
def scope(self) -> str:
Expand Down
55 changes: 55 additions & 0 deletions tests/test_starlette_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,61 @@ def t1(request: Request):
if i < 5:
assert response.text == "test"

def test_exempt_when_argument(self, build_starlette_app):
app, limiter = build_starlette_app(key_func=get_ipaddr)

def return_true():
return True

def return_false():
return False

def dynamic(request: Request):
user_agent = request.headers.get("User-Agent")
if user_agent is None:
return False
return user_agent == "exempt"

@limiter.limit("1/minute", exempt_when=return_true)
def always_true(request: Request):
return PlainTextResponse("test")

@limiter.limit("1/minute", exempt_when=return_false)
def always_false(request: Request):
return PlainTextResponse("test")

@limiter.limit("1/minute", exempt_when=dynamic)
def always_dynamic(request: Request):
return PlainTextResponse("test")

app.add_route("/true", always_true)
app.add_route("/false", always_false)
app.add_route("/dynamic", always_dynamic)

client = TestClient(app)
# Test always true always exempting
for i in range(0, 2):
response = client.get("/true")
assert response.status_code == 200
assert response.text == "test"
# Test always false hitting the limit after one hit
for i in range(0, 2):
response = client.get("/false")
assert response.status_code == 200 if i < 1 else 429
if i < 1:
assert response.text == "test"
# Test dynamic not exempting with the correct header
for i in range(0, 2):
response = client.get("/dynamic", headers={"User-Agent": "exempt"})
assert response.status_code == 200
assert response.text == "test"
# Test dynamic exempting with the incorrect header
for i in range(0, 2):
response = client.get("/dynamic")
assert response.status_code == 200 if i < 1 else 429
if i < 1:
assert response.text == "test"

def test_shared_decorator(self, build_starlette_app):
app, limiter = build_starlette_app(key_func=get_ipaddr)

Expand Down

0 comments on commit a72bcc6

Please sign in to comment.