diff --git a/threadopener/cooldown.py b/threadopener/cooldown.py new file mode 100644 index 00000000..63aae42a --- /dev/null +++ b/threadopener/cooldown.py @@ -0,0 +1,32 @@ +from typing import Dict, Optional, Callable, Any, final, Tuple, Union + +import discord +from redbot.core import commands + + +@final +class ThreadCooldown(commands.CooldownMapping[commands.Context]): + def __init__( + self, + original: Optional[commands.Cooldown], + type: Callable[[commands.Context], Any], + ) -> None: + super().__init__(original, type) + self._cache: Dict[Any, commands.Cooldown] = {} + self._cooldown: Optional[commands.Cooldown] = original + self._type: Callable[[commands.Context], Any] = type + + def __call__(self) -> "ThreadCooldown": + return self + + def get_bucket( + self, message: discord.Message, current: Optional[float] = None + ) -> Optional[commands.Cooldown]: + return super().get_bucket(message, current) # type: ignore + + def _bucket_key(self, tup: Tuple[int, Union[int, str]]) -> Tuple[int, Union[int, str]]: + return tup + + def is_rate_limited(self, message: discord.Message) -> bool: + bucket = self.get_bucket(message) + return bucket.update_rate_limit() is not None # type: ignore diff --git a/threadopener/core.py b/threadopener/core.py index 3a7c977d..9b191683 100644 --- a/threadopener/core.py +++ b/threadopener/core.py @@ -32,6 +32,7 @@ from .abc import CompositeMetaClass from .commands import Commands +from .cooldown import ThreadCooldown log: logging.Logger = logging.getLogger("red.seina.threadopener") @@ -66,9 +67,7 @@ def __init__(self, bot: Red) -> None: cooldown: Tuple[int, int, commands.BucketType] = (3, 10, commands.BucketType.guild) - self.spam_control: commands.CooldownMapping[ - commands.Context - ] = commands.CooldownMapping.from_cooldown(*cooldown) + self.spam_control: ThreadCooldown = ThreadCooldown.from_cooldown(*cooldown) def format_help_for_context(self, ctx: commands.Context) -> str: pre_processed = super().format_help_for_context(ctx) @@ -111,10 +110,10 @@ async def on_message(self, message: discord.Message) -> None: ) return - bucket = self.spam_control.get_bucket(message) # type: ignore + bucket = self.spam_control.get_bucket(message) current = message.created_at.timestamp() retry_after = bucket and bucket.update_rate_limit(current) - if retry_after and message.author.id in self.bot.owner_ids: # type: ignore + if retry_after and message.author.id not in self.bot.owner_ids: # type: ignore log.debug(f"{message.channel} ratelimit exhausted, retry after: {retry_after}.") return