From 87a3caa3a20fdb946976223d063e20f2d673d8f5 Mon Sep 17 00:00:00 2001
From: Roman Glushko <roman.glushko.m@gmail.com>
Date: Tue, 7 Nov 2023 17:15:14 +0200
Subject: [PATCH] #68 Events in Rate Limiters

---
 hyx/events.py                      |  8 ++++++
 hyx/ratelimit/api.py               | 26 ++++++++++++++++---
 hyx/ratelimit/buckets.py           |  0
 hyx/ratelimit/events.py            | 21 +++++++++++++++
 hyx/ratelimit/managers.py          | 10 +++++++-
 hyx/retry/api.py                   |  1 +
 tests/test_ratelimiter/test_api.py | 41 +++++++++++++++++++++++++++---
 7 files changed, 99 insertions(+), 8 deletions(-)
 create mode 100644 hyx/ratelimit/buckets.py
 create mode 100644 hyx/ratelimit/events.py

diff --git a/hyx/events.py b/hyx/events.py
index 96900fe..1f3f7bf 100644
--- a/hyx/events.py
+++ b/hyx/events.py
@@ -172,6 +172,14 @@ async def _get_or_init_listeners(self) -> List[ListenerT]:
         return self._inited_listeners
 
 
+class NoOpEventDispatcher(EventDispatcher):
+    def __getattr__(self, event_handler_name: str) -> Callable:
+        async def handle_event(*args, **kwargs) -> None:
+            pass
+
+        return handle_event
+
+
 def get_default_name(func: Optional[Callable] = None) -> str:
     """
     Get the default name of the component based on code context where it's being used
diff --git a/hyx/ratelimit/api.py b/hyx/ratelimit/api.py
index 9ece3b7..9d5c60c 100644
--- a/hyx/ratelimit/api.py
+++ b/hyx/ratelimit/api.py
@@ -1,7 +1,9 @@
 import functools
 from types import TracebackType
-from typing import Any, Optional, Type, cast
+from typing import Any, Optional, Sequence, Type, cast
 
+from hyx.events import EventDispatcher, EventManager, get_default_name
+from hyx.ratelimit.events import _RATELIMITER_LISTENERS, RateLimiterListener
 from hyx.ratelimit.managers import RateLimiter, TokenBucketLimiter
 from hyx.typing import FuncT
 
@@ -49,7 +51,7 @@ class tokenbucket:
     **Parameters**
 
     * **max_executions** *(float)* - How many executions are permitted?
-    * **per_time_secs** *(float)* - Per what time span? (in seconds)
+    * **per_time_secs** *(float)* - Per what time period? (in seconds)
     * **bucket_size** *(None | float)* - The token bucket size. Defines the max number of executions
         that are permitted to happen during bursts.
         The burst is when no executions have happened for a long time, and then you are receiving a
@@ -58,13 +60,31 @@ class tokenbucket:
 
     __slots__ = ("_limiter",)
 
-    def __init__(self, max_executions: float, per_time_secs: float, bucket_size: Optional[float] = None) -> None:
+    def __init__(
+        self,
+        max_executions: float,
+        per_time_secs: float,
+        bucket_size: Optional[float] = None,
+        name: Optional[str] = None,
+        listeners: Optional[Sequence[RateLimiterListener]] = None,
+        event_manager: Optional[EventManager] = None,
+    ) -> None:
+        event_dispatcher = EventDispatcher[RateLimiter, RateLimiterListener](
+            listeners,
+            _RATELIMITER_LISTENERS,
+            event_manager=event_manager,
+        )
+
         self._limiter = TokenBucketLimiter(
+            name=name or get_default_name(),
             max_executions=max_executions,
             per_time_secs=per_time_secs,
             bucket_size=bucket_size,
+            event_dispatcher=event_dispatcher.as_listener,
         )
 
+        event_dispatcher.set_component(self._limiter)
+
     async def __aenter__(self) -> "tokenbucket":
         await self._limiter.acquire()
 
diff --git a/hyx/ratelimit/buckets.py b/hyx/ratelimit/buckets.py
new file mode 100644
index 0000000..e69de29
diff --git a/hyx/ratelimit/events.py b/hyx/ratelimit/events.py
new file mode 100644
index 0000000..bcff079
--- /dev/null
+++ b/hyx/ratelimit/events.py
@@ -0,0 +1,21 @@
+from typing import TYPE_CHECKING, Union
+
+from hyx.events import ListenerFactoryT, ListenerRegistry
+
+if TYPE_CHECKING:
+    from hyx.ratelimit.managers import RateLimiter
+
+_RATELIMITER_LISTENERS: ListenerRegistry["RateLimiter", "RateLimiterListener"] = ListenerRegistry()
+
+
+class RateLimiterListener:
+    ...
+
+
+def register_ratelimiter_listener(listener: Union[RateLimiterListener, ListenerFactoryT]) -> None:
+    """
+    Register a listener that will listen to all rate limiter components in the system
+    """
+    global _RATELIMITER_LISTENERS
+
+    _RATELIMITER_LISTENERS.register(listener)
diff --git a/hyx/ratelimit/managers.py b/hyx/ratelimit/managers.py
index b6294da..b22b98a 100644
--- a/hyx/ratelimit/managers.py
+++ b/hyx/ratelimit/managers.py
@@ -1,6 +1,7 @@
 import asyncio
 from typing import Optional
 
+from hyx.ratelimit.events import RateLimiterListener
 from hyx.ratelimit.exceptions import RateLimitExceeded
 
 
@@ -26,7 +27,14 @@ class TokenBucketLimiter(RateLimiter):
         "_next_replenish_at",
     )
 
-    def __init__(self, max_executions: float, per_time_secs: float, bucket_size: Optional[float] = None) -> None:
+    def __init__(
+        self,
+        name: str,
+        max_executions: float,
+        per_time_secs: float,
+        event_dispatcher: RateLimiterListener,
+        bucket_size: Optional[float] = None,
+    ) -> None:
         self._max_executions = max_executions
         self._per_time_secs = per_time_secs
 
diff --git a/hyx/retry/api.py b/hyx/retry/api.py
index 019c881..2582a4e 100644
--- a/hyx/retry/api.py
+++ b/hyx/retry/api.py
@@ -78,6 +78,7 @@ def bucket_retry(
 
     def _decorator(func: FuncT) -> FuncT:
         limiter = TokenBucketLimiter(attempts, per_time_secs, bucket_size) if attempts and per_time_secs else None
+
         event_dispatcher = EventDispatcher[RetryManager, RetryListener](
             listeners,
             _RETRY_LISTENERS,
diff --git a/tests/test_ratelimiter/test_api.py b/tests/test_ratelimiter/test_api.py
index 9aa039e..5a48100 100644
--- a/tests/test_ratelimiter/test_api.py
+++ b/tests/test_ratelimiter/test_api.py
@@ -2,12 +2,21 @@
 
 import pytest
 
+from hyx.events import NoOpEventDispatcher
 from hyx.ratelimit import TokenBucketLimiter, ratelimiter, tokenbucket
 from hyx.ratelimit.exceptions import RateLimitExceeded
 
 
 async def test__ratelimiter__decorator() -> None:
-    @ratelimiter(limiter=TokenBucketLimiter(max_executions=4, per_time_secs=1, bucket_size=4))
+    limiter = TokenBucketLimiter(
+        name="hyx.tests.decorator",
+        max_executions=4,
+        per_time_secs=1,
+        bucket_size=4,
+        event_dispatcher=NoOpEventDispatcher().as_listener,
+    )
+
+    @ratelimiter(limiter=limiter)
     async def calc() -> float:
         return 42
 
@@ -25,7 +34,15 @@ async def calc() -> float:
 
 
 async def test__ratelimiter__context_manager() -> None:
-    limiter = ratelimiter(limiter=TokenBucketLimiter(max_executions=4, per_time_secs=1, bucket_size=4))
+    limiter = ratelimiter(
+        limiter=TokenBucketLimiter(
+            name="hyx.tests.ctxmgr",
+            max_executions=4,
+            per_time_secs=1,
+            bucket_size=4,
+            event_dispatcher=NoOpEventDispatcher().as_listener,
+        )
+    )
 
     for _ in range(4):
         async with limiter:
@@ -41,7 +58,15 @@ async def test__ratelimiter__token_bucket_context_manager() -> None:
 
 
 async def test__ratelimiter__limit_exceeded() -> None:
-    @ratelimiter(limiter=TokenBucketLimiter(max_executions=3, per_time_secs=1, bucket_size=3))
+    limiter = TokenBucketLimiter(
+        name="hyx.tests.limiter",
+        max_executions=3,
+        per_time_secs=1,
+        bucket_size=3,
+        event_dispatcher=NoOpEventDispatcher().as_listener,
+    )
+
+    @ratelimiter(limiter=limiter)
     async def calc() -> float:
         return 42
 
@@ -51,7 +76,15 @@ async def calc() -> float:
 
 
 async def test__ratelimiter__replenish_after_full_bucket() -> None:
-    @ratelimiter(limiter=TokenBucketLimiter(max_executions=3, per_time_secs=1, bucket_size=3))
+    limiter = TokenBucketLimiter(
+        name="hyx.tests.limiter",
+        max_executions=3,
+        per_time_secs=1,
+        bucket_size=3,
+        event_dispatcher=NoOpEventDispatcher().as_listener,
+    )
+
+    @ratelimiter(limiter=limiter)
     async def calc() -> float:
         return 42