Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

General improvements of misc commands and the extended cog #148

Merged
merged 3 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/cogs/restore.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import logging
import re
from typing import TYPE_CHECKING

from core import ExtendedCog, misc_command
from core import ExtendedCog, MiscCommandContext, misc_command
from core.checkers.misc import bot_required_permissions, is_activated, is_user_authorized, misc_check

if TYPE_CHECKING:
Expand All @@ -16,11 +17,19 @@


class Restore(ExtendedCog):
@misc_command("restore", description="Send a message back in chat if a link is send.", extras={"soon": True})
def contains_message_link(self, message: Message) -> bool:
return bool(re.search(r"<?https://(?:.+\.)?discord(?:app)?\.com/channels/(\d+)/(\d+)/(\d+)", message.content))

@misc_command(
"restore",
description="Send a message back in chat if a link is send.",
extras={"soon": True},
trigger_condition=contains_message_link,
)
@bot_required_permissions(manage_webhooks=True)
@misc_check(is_activated)
@misc_check(is_user_authorized)
async def on_message(self, message: Message) -> None:
async def on_message(self, ctx: MiscCommandContext[MyBot], message: Message) -> None:
raise NotImplementedError("Restore is not implemented.")


Expand Down
12 changes: 4 additions & 8 deletions src/cogs/translate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from discord import Embed, Message, app_commands, ui
from discord.app_commands import locale_str as __

from core import ExtendedCog, ResponseType, TemporaryCache, db, misc_command, response_constructor
from core import ExtendedCog, MiscCommandContext, ResponseType, TemporaryCache, db, misc_command, response_constructor
from core.checkers.misc import bot_required_permissions, is_activated, is_user_authorized, misc_check
from core.constants import EmbedsCharLimits
from core.errors import BadArgument, NonSpecificError
Expand Down Expand Up @@ -282,11 +282,7 @@ async def translate_misc_condition(self, payload: RawReactionActionEvent) -> boo
@bot_required_permissions(send_messages=True, embed_links=True)
@misc_check(is_activated)
@misc_check(is_user_authorized)
async def translate_misc_command(self, payload: RawReactionActionEvent):
user = await self.bot.getch_user(payload.user_id)
if not user or user.bot: # TODO(airo.pi_): automatically ignore bots
return

async def translate_misc_command(self, ctx: MiscCommandContext[MyBot], payload: RawReactionActionEvent):
channel = await self.bot.getch_channel(payload.channel_id)
if channel is None:
return
Expand All @@ -304,12 +300,12 @@ async def public_pre_strategy():
await channel.typing()

async def private_pre_strategy():
await user.typing()
await ctx.user.typing()

if await self.public_translations(payload.guild_id):
strategies = Strategies(pre=public_pre_strategy, send=partial(channel.send, reference=message))
else:
strategies = Strategies(pre=private_pre_strategy, send=user.send)
strategies = Strategies(pre=private_pre_strategy, send=ctx.user.send)

await self.process(
payload.user_id,
Expand Down
1 change: 1 addition & 0 deletions src/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .extended_commands import (
ExtendedCog as ExtendedCog,
ExtendedGroupCog as ExtendedGroupCog,
MiscCommandContext as MiscCommandContext,
cog_property as cog_property,
misc_command as misc_command,
)
Expand Down
4 changes: 2 additions & 2 deletions src/core/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
T = TypeVar("T")

CogT = TypeVar("CogT", bound="ExtendedCog")
BotT = TypeVar("BotT", bound="commands.Bot | commands.AutoShardedBot")

type Bot = commands.Bot | commands.AutoShardedBot
BotT = TypeVar("BotT", bound=Bot)

Snowflake = int

Expand Down
84 changes: 39 additions & 45 deletions src/core/extended_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
TYPE_CHECKING,
Any,
ClassVar,
Concatenate,
Generic,
Literal,
ParamSpec,
Protocol,
Self,
TypeVar,
Expand All @@ -23,23 +20,20 @@
from discord.ext import commands
from discord.utils import maybe_coroutine

from ._types import BotT, CogT
from .errors import MiscCheckFailure, MiscCommandError, MiscNoPrivateMessage, UnexpectedError

if TYPE_CHECKING:
from discord.abc import MessageableChannel, Snowflake
from discord.ext.commands.bot import AutoShardedBot, Bot, BotBase # pyright: ignore[reportMissingTypeStubs]
from discord.ext.commands.bot import BotBase # pyright: ignore[reportMissingTypeStubs]

from mybot import MyBot

from ._types import CoroT, UnresolvedContext, UnresolvedContextT
from ._types import Bot, CogT, CoroT, UnresolvedContext, UnresolvedContextT

ConditionCallback = Callable[Concatenate["CogT", UnresolvedContextT, "P"], CoroT[bool] | bool]
Callback = Callable[Concatenate["CogT", UnresolvedContextT, "P"], CoroT["T"]]
type ConditionCallback[CogT, UnresolvedContextT] = Callable[[CogT, UnresolvedContextT], CoroT[bool] | bool]
type Callback[CogT, UnresolvedContext, R] = Callable[[CogT, "MiscCommandContext[Any]", UnresolvedContext], CoroT[R]]

P = ParamSpec("P")
T = TypeVar("T")
C = TypeVar("C", bound="commands.Cog")
R = TypeVar("R")


LiteralNames = Literal["raw_reaction_add", "message"]
Expand All @@ -57,7 +51,7 @@ class MiscCommandsType(Enum):


class ExtendedCog(commands.Cog):
__cog_misc_commands__: list[MiscCommand[Any, ..., Any]]
__cog_misc_commands__: list[MiscCommand[Any, Any]]
bot: MyBot

def __new__(cls, *args: Any, **kwargs: Any) -> Self:
Expand All @@ -74,7 +68,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Self:
def __init__(self, bot: MyBot) -> None:
self.bot = bot

def get_misc_commands(self) -> list[MiscCommand[Any, ..., Any]]:
def get_misc_commands(self) -> list[MiscCommand[Any, Any]]:
"""Return all the misc commands in this cog."""
return list(self.__cog_misc_commands__)

Expand All @@ -93,18 +87,18 @@ class ExtendedGroupCog(ExtendedCog):
__cog_is_app_commands_group__: ClassVar[bool] = True


class MiscCommand(Generic[CogT, P, T]):
bot: Bot | AutoShardedBot
class MiscCommand[CogT: ExtendedCog, R]:
bot: Bot

def __init__(
self,
name: str,
callback: Callback[CogT, UnresolvedContextT, P, T],
callback: Callback[CogT, UnresolvedContextT, R],
description: str,
nsfw: bool,
type: MiscCommandsType,
extras: dict[Any, Any],
trigger_condition: Callable[Concatenate[CogT, UnresolvedContext, P], bool | CoroT[bool]] | None,
trigger_condition: Callable[[CogT, UnresolvedContextT], bool | CoroT[bool]] | None,
) -> None:
self.name = name
self.type = type
Expand All @@ -123,14 +117,12 @@ def __init__(
)
self._callback = callback

async def do_call(self, cog: CogT, context: UnresolvedContext, *args: P.args, **kwargs: P.kwargs) -> T:
async def do_call(self, cog: CogT, context: UnresolvedContext) -> R:
if self.trigger_condition:
trigger_condition = await discord.utils.maybe_coroutine(
self.trigger_condition,
self.trigger_condition, # type: ignore
cog,
context,
*args,
**kwargs, # type: ignore
context, # type: ignore
)
if not trigger_condition:
return None # type: ignore
Expand All @@ -143,12 +135,12 @@ async def do_call(self, cog: CogT, context: UnresolvedContext, *args: P.args, **
self.bot.dispatch("misc_command_error", resolved_context, e)
return None # type: ignore

return await self._callback(cog, context, *args, **kwargs) # type: ignore
return await self._callback(cog, resolved_context, context) # type: ignore

def add_check(self, predicate: Callable[[MiscCommandContext[Any]], CoroT[bool] | bool]) -> None:
self.checks.append(predicate)

async def condition(self, func: ConditionCallback[CogT, UnresolvedContextT, P]) -> None:
async def condition(self, func: ConditionCallback[CogT, UnresolvedContextT]) -> None:
self.trigger_condition = func


Expand All @@ -159,8 +151,8 @@ def misc_command(
nsfw: bool = False,
listener_name: LiteralNames | None = None,
extras: dict[Any, Any] | None = None,
trigger_condition: ConditionCallback[CogT, UnresolvedContextT, P] | None = None,
) -> Callable[[Callback[CogT, UnresolvedContextT, P, T]], Callback[CogT, UnresolvedContextT, P, T]]:
trigger_condition: ConditionCallback[CogT, UnresolvedContextT] | None = None,
) -> Callable[[Callback[CogT, UnresolvedContextT, R]], Callable[[CogT, UnresolvedContext], CoroT[R]]]:
"""Register an event listener as a "command" that can be retrieved from the feature exporter.
Checkers will be called within the second argument of the function (right after the Cog (self))

Expand All @@ -179,10 +171,12 @@ def misc_command(
A wrapped function, bound with a MiscCommand.
"""

def inner(func: Callback[CogT, UnresolvedContextT, P, T]) -> Callback[CogT, UnresolvedContextT, P, T]:
def inner(
func: Callback[CogT, UnresolvedContextT, R],
) -> Callable[[CogT, UnresolvedContext], CoroT[R]]:
true_listener_name = "on_" + listener_name if listener_name else func.__name__

misc_command = MiscCommand[CogT, P, T](
misc_command = MiscCommand["CogT", R](
name=name,
callback=func,
description=description,
Expand All @@ -193,8 +187,8 @@ def inner(func: Callback[CogT, UnresolvedContextT, P, T]) -> Callback[CogT, Unre
)

@wraps(func)
async def inner(cog: CogT, context: UnresolvedContext, *args: P.args, **kwargs: P.kwargs) -> T:
return await misc_command.do_call(cog, context, *args, **kwargs)
async def inner(cog: CogT, context: UnresolvedContext) -> R:
return await misc_command.do_call(cog, context)

setattr(inner, "__listener_as_command__", misc_command)

Expand All @@ -218,21 +212,21 @@ class MiscCommandContextFilled(Protocol):
user: discord.User


class MiscCommandContext(Generic[BotT]):
class MiscCommandContext[B: Bot]:
def __init__(
self,
bot: BotT,
bot: B,
channel: MessageableChannel,
user: User | Member,
command: MiscCommand[Any, ..., Any],
command: MiscCommand[Any, Any],
) -> None:
self.channel: MessageableChannel = channel
self.user: User | Member = user
self.bot: BotT = bot
self.command: MiscCommand[Any, ..., Any] = command
self.bot: B = bot
self.command: MiscCommand[Any, Any] = command

@classmethod
async def resolve(cls, bot: BotT, context: UnresolvedContext, command: MiscCommand[Any, ..., Any]) -> Self:
async def resolve(cls, bot: B, context: UnresolvedContext, command: MiscCommand[Any, Any]) -> Self:
channel: MessageableChannel
user: User | Member

Expand Down Expand Up @@ -268,15 +262,15 @@ def bot_permissions(self) -> Permissions:
return channel.permissions_for(me)


def misc_guild_only() -> Callable[[T], T]:
def misc_guild_only() -> Callable[[R], R]:
def predicate(ctx: MiscCommandContext[Any]) -> bool:
if ctx.channel.guild is None:
raise MiscNoPrivateMessage
return True

def decorator(func: T) -> T:
def decorator(func: R) -> R:
if hasattr(func, "__listener_as_command__"):
misc_command: MiscCommand[Any, ..., Any] = getattr(func, "__listener_as_command__")
misc_command: MiscCommand[Any, Any] = getattr(func, "__listener_as_command__")
misc_command.add_check(predicate)
misc_command.guild_only = True
else:
Expand All @@ -291,10 +285,10 @@ def decorator(func: T) -> T:
return decorator


def misc_check(predicate: Callable[[MiscCommandContext[Any]], CoroT[bool] | bool]) -> Callable[[T], T]:
def decorator(func: T) -> T:
def misc_check(predicate: Callable[[MiscCommandContext[Any]], CoroT[bool] | bool]) -> Callable[[R], R]:
def decorator(func: R) -> R:
if hasattr(func, "__listener_as_command__"):
misc_command: MiscCommand[Any, ..., Any] = getattr(func, "__listener_as_command__")
misc_command: MiscCommand[Any, Any] = getattr(func, "__listener_as_command__")
misc_command.add_check(predicate)
else:
if not hasattr(func, "__misc_commands_checks__"):
Expand All @@ -314,10 +308,10 @@ def cog_property(cog_name: str):
cog_name: the cog name to return
"""

def inner(_: Callable[..., C]) -> C:
def inner(_: Callable[..., CogT]) -> CogT:
@property
def cog_getter(self: Any) -> C: # self is a cog within the .bot attribute (because every Cog should have it)
cog: C | None = self.bot.get_cog(cog_name)
def cog_getter(self: Any) -> CogT: # self is a cog within the .bot attribute (because every Cog should have it)
cog: CogT | None = self.bot.get_cog(cog_name)
if cog is None:
raise UnexpectedError(f"Cog named {cog_name} is not loaded.")
return cog
Expand Down
2 changes: 1 addition & 1 deletion src/features_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


FeatureCodebaseTypes = (
app_commands.Command[Any, ..., Any] | app_commands.Group | app_commands.ContextMenu | MiscCommand[Any, ..., Any]
app_commands.Command[Any, ..., Any] | app_commands.Group | app_commands.ContextMenu | MiscCommand[Any, Any]
)


Expand Down
3 changes: 2 additions & 1 deletion src/mybot.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(self, startup_sync: bool = False) -> None:
intents.reactions = True
intents.guilds = True
intents.messages = True
intents.message_content = True
logger.debug("Intents : %s", ", ".join(flag[0] for flag in intents if flag[1]))

super().__init__(
Expand Down Expand Up @@ -318,7 +319,7 @@ def misc_commands(self):
Returns:
_type_: the list of misc commands
"""
misc_commands: list[MiscCommand[Any, ..., Any]] = []
misc_commands: list[MiscCommand[Any, Any]] = []
for cog in self.cogs.values():
if isinstance(cog, ExtendedCog):
misc_commands.extend(cog.get_misc_commands())
Expand Down