diff --git a/HISTORY.md b/HISTORY.md index 7c87807..2dc65be 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,4 +1,8 @@ # History + +## 0.5.0 (Unreleased) +* Add support for floating point values for rate limits + ## 0.4.2 (2023-09-27) * Update conda-forge package to restrict pyrate-limiter to <3.0 diff --git a/pyproject.toml b/pyproject.toml index 751c3c0..95b51e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,5 +72,5 @@ ignore_missing_imports = true [tool.ruff] output-format = 'grouped' -line-length = 120 +line-length = 110 select = ['B', 'C4','C90', 'E', 'F'] diff --git a/requests_ratelimiter/requests_ratelimiter.py b/requests_ratelimiter/requests_ratelimiter.py index ca09274..0e0c10b 100644 --- a/requests_ratelimiter/requests_ratelimiter.py +++ b/requests_ratelimiter/requests_ratelimiter.py @@ -1,3 +1,4 @@ +from fractions import Fraction from inspect import signature from logging import getLogger from time import time @@ -42,7 +43,7 @@ def __init__( ): # Translate request rate values into RequestRate objects rates = [ - RequestRate(limit, interval) + _convert_rate(limit, interval) for interval, limit in { Duration.SECOND * burst: per_second * burst, Duration.MINUTE: per_minute, @@ -52,6 +53,11 @@ def __init__( }.items() if limit ] + if rates and not limiter: + logger.debug( + "Creating Limiter with rates:\n%s", + "\n".join([f"{r.limit}/{r.interval}s" for r in rates]), + ) # If using a persistent backend, we don't want to use monotonic time (the default) if bucket_class not in (MemoryListBucket, MemoryQueueBucket) and not time_function: @@ -69,7 +75,7 @@ def __init__( self._default_bucket = str(uuid4()) # If the superclass is an adapter or custom Session, pass along any valid keyword arguments - session_kwargs = get_valid_kwargs(super().__init__, kwargs) + session_kwargs = _get_valid_kwargs(super().__init__, kwargs) super().__init__(**session_kwargs) # type: ignore # Base Session doesn't take any kwargs # Conveniently, both Session.send() and HTTPAdapter.send() have a mostly consistent signature @@ -108,7 +114,7 @@ def _fill_bucket(self, request: PreparedRequest): If the server also has an hourly limit, we don't have enough information to know if we've exceeded that limit or how long to delay, so we'll keep delaying in 1-minute intervals. """ - logger.info(f'Rate limit exceeded for {request.url}; filling limiter bucket') + logger.info(f"Rate limit exceeded for {request.url}; filling limiter bucket") bucket = self.limiter.bucket_group[self._bucket_name(request)] # Determine how many requests we've made within the smallest defined time interval @@ -166,7 +172,16 @@ class LimiterAdapter(LimiterMixin, HTTPAdapter): # type: ignore # send signatu """ -def get_valid_kwargs(func: Callable, kwargs: Dict) -> Dict: +def _convert_rate(limit: float, interval: float) -> RequestRate: + """Handle fractional rate limits by converting to a whole number of requests per interval""" + # Convert both limit and interval to fractions, and adjust for floating point weirdness + f1 = Fraction(limit).limit_denominator(1000) + f2 = Fraction(interval).limit_denominator(1000) + rate_fraction = f1 / f2 + return RequestRate(rate_fraction.numerator, rate_fraction.denominator) + + +def _get_valid_kwargs(func: Callable, kwargs: Dict) -> Dict: """Get the subset of non-None ``kwargs`` that are valid params for ``func``""" sig_params = list(signature(func).parameters) return {k: v for k, v in kwargs.items() if k in sig_params and v is not None} diff --git a/test/test_requests_ratelimiter.py b/test/test_requests_ratelimiter.py index 9163beb..91e8565 100644 --- a/test/test_requests_ratelimiter.py +++ b/test/test_requests_ratelimiter.py @@ -14,13 +14,15 @@ from time import sleep from unittest.mock import patch +import pytest from pyrate_limiter import Duration, Limiter, RequestRate from requests import Response, Session from requests.adapters import HTTPAdapter from requests_ratelimiter import LimiterAdapter, LimiterMixin, LimiterSession +from requests_ratelimiter.requests_ratelimiter import _convert_rate -patch_sleep = patch('pyrate_limiter.limit_context_decorator.sleep', side_effect=sleep) +patch_sleep = patch("pyrate_limiter.limit_context_decorator.sleep", side_effect=sleep) rate = RequestRate(5, Duration.SECOND) @@ -38,7 +40,7 @@ def test_limiter_session(mock_sleep): @patch_sleep -@patch.object(HTTPAdapter, 'send') +@patch.object(HTTPAdapter, "send") def test_limiter_adapter(mock_send, mock_sleep): # To allow mounting a mock:// URL, we need to patch HTTPAdapter.send() # so it doesn't validate the protocol @@ -49,7 +51,7 @@ def test_limiter_adapter(mock_send, mock_sleep): session = Session() adapter = LimiterAdapter(per_second=5) - session.mount('http+mock://', adapter) + session.mount("http+mock://", adapter) for _ in range(5): session.get(MOCKED_URL) @@ -138,3 +140,19 @@ def test_limit_status_disabled(mock_sleep): session.get(MOCKED_URL_429) session.get(MOCKED_URL_429) assert mock_sleep.called is False + + +@pytest.mark.parametrize( + "limit, interval, expected_limit, expected_interval", + [ + (5, 1, 5, 1), + (0.5, 1, 1, 2), + (1, 0.5, 2, 1), + (0.1, 0.5, 1, 5), + (0.001, 0.05, 1, 50), + ], +) +def test_convert_rate(limit, interval, expected_limit, expected_interval): + rate = _convert_rate(limit, interval) + assert rate.limit == expected_limit + assert rate.interval == expected_interval