Skip to content

Commit

Permalink
Merge pull request #88 from JWCook/fractional-ratelimits
Browse files Browse the repository at this point in the history
Add support for floating point values for rate limits
  • Loading branch information
JWCook authored Jan 29, 2024
2 parents 3e9a833 + 52d7864 commit 46c8d73
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 8 deletions.
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# History

## 0.5.0 (Unreleased)
* Add support for floating point values for rate limits

## 0.4.2 (2023-09-27)
* Update conda-forge package to restrict pyrate-limiter to <3.0

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,5 @@ ignore_missing_imports = true

[tool.ruff]
output-format = 'grouped'
line-length = 120
line-length = 110
select = ['B', 'C4','C90', 'E', 'F']
23 changes: 19 additions & 4 deletions requests_ratelimiter/requests_ratelimiter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from fractions import Fraction
from inspect import signature
from logging import getLogger
from time import time
Expand Down Expand Up @@ -42,7 +43,7 @@ def __init__(
):
# Translate request rate values into RequestRate objects
rates = [
RequestRate(limit, interval)
_convert_rate(limit, interval)
for interval, limit in {
Duration.SECOND * burst: per_second * burst,
Duration.MINUTE: per_minute,
Expand All @@ -52,6 +53,11 @@ def __init__(
}.items()
if limit
]
if rates and not limiter:
logger.debug(
"Creating Limiter with rates:\n%s",
"\n".join([f"{r.limit}/{r.interval}s" for r in rates]),
)

# If using a persistent backend, we don't want to use monotonic time (the default)
if bucket_class not in (MemoryListBucket, MemoryQueueBucket) and not time_function:
Expand All @@ -69,7 +75,7 @@ def __init__(
self._default_bucket = str(uuid4())

# If the superclass is an adapter or custom Session, pass along any valid keyword arguments
session_kwargs = get_valid_kwargs(super().__init__, kwargs)
session_kwargs = _get_valid_kwargs(super().__init__, kwargs)
super().__init__(**session_kwargs) # type: ignore # Base Session doesn't take any kwargs

# Conveniently, both Session.send() and HTTPAdapter.send() have a mostly consistent signature
Expand Down Expand Up @@ -108,7 +114,7 @@ def _fill_bucket(self, request: PreparedRequest):
If the server also has an hourly limit, we don't have enough information to know if we've
exceeded that limit or how long to delay, so we'll keep delaying in 1-minute intervals.
"""
logger.info(f'Rate limit exceeded for {request.url}; filling limiter bucket')
logger.info(f"Rate limit exceeded for {request.url}; filling limiter bucket")
bucket = self.limiter.bucket_group[self._bucket_name(request)]

# Determine how many requests we've made within the smallest defined time interval
Expand Down Expand Up @@ -166,7 +172,16 @@ class LimiterAdapter(LimiterMixin, HTTPAdapter): # type: ignore # send signatu
"""


def get_valid_kwargs(func: Callable, kwargs: Dict) -> Dict:
def _convert_rate(limit: float, interval: float) -> RequestRate:
"""Handle fractional rate limits by converting to a whole number of requests per interval"""
# Convert both limit and interval to fractions, and adjust for floating point weirdness
f1 = Fraction(limit).limit_denominator(1000)
f2 = Fraction(interval).limit_denominator(1000)
rate_fraction = f1 / f2
return RequestRate(rate_fraction.numerator, rate_fraction.denominator)


def _get_valid_kwargs(func: Callable, kwargs: Dict) -> Dict:
"""Get the subset of non-None ``kwargs`` that are valid params for ``func``"""
sig_params = list(signature(func).parameters)
return {k: v for k, v in kwargs.items() if k in sig_params and v is not None}
24 changes: 21 additions & 3 deletions test/test_requests_ratelimiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
from time import sleep
from unittest.mock import patch

import pytest
from pyrate_limiter import Duration, Limiter, RequestRate
from requests import Response, Session
from requests.adapters import HTTPAdapter

from requests_ratelimiter import LimiterAdapter, LimiterMixin, LimiterSession
from requests_ratelimiter.requests_ratelimiter import _convert_rate

patch_sleep = patch('pyrate_limiter.limit_context_decorator.sleep', side_effect=sleep)
patch_sleep = patch("pyrate_limiter.limit_context_decorator.sleep", side_effect=sleep)
rate = RequestRate(5, Duration.SECOND)


Expand All @@ -38,7 +40,7 @@ def test_limiter_session(mock_sleep):


@patch_sleep
@patch.object(HTTPAdapter, 'send')
@patch.object(HTTPAdapter, "send")
def test_limiter_adapter(mock_send, mock_sleep):
# To allow mounting a mock:// URL, we need to patch HTTPAdapter.send()
# so it doesn't validate the protocol
Expand All @@ -49,7 +51,7 @@ def test_limiter_adapter(mock_send, mock_sleep):

session = Session()
adapter = LimiterAdapter(per_second=5)
session.mount('http+mock://', adapter)
session.mount("http+mock://", adapter)

for _ in range(5):
session.get(MOCKED_URL)
Expand Down Expand Up @@ -138,3 +140,19 @@ def test_limit_status_disabled(mock_sleep):
session.get(MOCKED_URL_429)
session.get(MOCKED_URL_429)
assert mock_sleep.called is False


@pytest.mark.parametrize(
"limit, interval, expected_limit, expected_interval",
[
(5, 1, 5, 1),
(0.5, 1, 1, 2),
(1, 0.5, 2, 1),
(0.1, 0.5, 1, 5),
(0.001, 0.05, 1, 50),
],
)
def test_convert_rate(limit, interval, expected_limit, expected_interval):
rate = _convert_rate(limit, interval)
assert rate.limit == expected_limit
assert rate.interval == expected_interval

0 comments on commit 46c8d73

Please sign in to comment.