From 4b00ebd4aa0bf9afcb212fe2f854707f0f5e442e Mon Sep 17 00:00:00 2001 From: Georgy Pliev Date: Wed, 19 Jun 2024 18:14:29 +0300 Subject: [PATCH] feat: implement UserAdapter & ChatAdapter, close #83 --- examples/upload.py | 10 +- telegrinder/__init__.py | 6 + telegrinder/bot/__init__.py | 4 +- telegrinder/bot/cute_types/update.py | 12 +- telegrinder/bot/dispatch/handler/func.py | 2 +- .../bot/dispatch/handler/message_reply.py | 3 + telegrinder/bot/dispatch/process.py | 13 +- telegrinder/bot/dispatch/view/abc.py | 6 +- telegrinder/bot/dispatch/view/message.py | 2 +- telegrinder/bot/rules/abc.py | 62 +++++--- telegrinder/bot/rules/adapter/__init__.py | 2 + telegrinder/bot/rules/adapter/abc.py | 8 +- telegrinder/bot/rules/adapter/chat.py | 38 +++++ telegrinder/bot/rules/adapter/errors.py | 3 +- telegrinder/bot/rules/adapter/event.py | 23 +-- telegrinder/bot/rules/adapter/user.py | 43 ++++++ telegrinder/bot/rules/adapter/utils.py | 74 +++++++++ telegrinder/bot/rules/callback_data.py | 3 +- telegrinder/bot/rules/chat_join.py | 5 +- telegrinder/bot/rules/func.py | 12 +- telegrinder/bot/rules/inline.py | 4 +- telegrinder/bot/rules/is_from.py | 144 +++++++++--------- telegrinder/bot/rules/message.py | 2 +- telegrinder/bot/rules/rule_enum.py | 24 +-- telegrinder/bot/rules/update.py | 6 +- telegrinder/types/objects.py | 23 ++- 26 files changed, 359 insertions(+), 175 deletions(-) create mode 100644 telegrinder/bot/rules/adapter/chat.py create mode 100644 telegrinder/bot/rules/adapter/user.py create mode 100644 telegrinder/bot/rules/adapter/utils.py diff --git a/examples/upload.py b/examples/upload.py index 2e2f8c8..c376d7a 100755 --- a/examples/upload.py +++ b/examples/upload.py @@ -1,19 +1,19 @@ import pathlib -from telegrinder import API, Button, Keyboard, Message, Telegrinder, Token +from telegrinder import API, Message, Telegrinder, Token +from telegrinder.bot.rules.is_from import IsPrivate from telegrinder.modules import logger -from telegrinder.rules import Text +from telegrinder.rules import IsUser, Text from telegrinder.types import InputFile api = API(token=Token.from_env()) bot = Telegrinder(api) -kb = (Keyboard().add(Button("Button 1")).add(Button("Button 2"))).get_markup() -cool_bytes = pathlib.Path("assets/satie.jpeg").read_bytes() +cool_bytes = pathlib.Path("examples/assets/satie.jpeg").read_bytes() logger.set_level("INFO") -@bot.on.message(Text("/photo")) +@bot.on.message(Text("/photo"), IsPrivate() & IsUser()) async def start(message: Message): await message.answer_photo( InputFile("satie.jpeg", cool_bytes), diff --git a/telegrinder/__init__.py b/telegrinder/__init__.py index c6cb4a7..736b76f 100755 --- a/telegrinder/__init__.py +++ b/telegrinder/__init__.py @@ -51,8 +51,10 @@ async def start(message: Message): BaseView, CallbackQueryCute, CallbackQueryReturnManager, + CallbackQueryRule, CallbackQueryView, ChatJoinRequestCute, + ChatJoinRequestRule, ChatJoinRequestView, ChatMemberUpdatedCute, ChatMemberView, @@ -63,6 +65,7 @@ async def start(message: Message): FuncHandler, InlineQueryCute, InlineQueryReturnManager, + InlineQueryRule, MessageCute, MessageReplyHandler, MessageReturnManager, @@ -153,6 +156,9 @@ async def start(message: Message): "CallbackQueryView", "ChatJoinRequest", "ChatJoinRequestCute", + "CallbackQueryRule", + "ChatJoinRequestRule", + "InlineQueryRule", "ChatJoinRequestView", "ChatMemberUpdated", "ChatMemberUpdatedCute", diff --git a/telegrinder/bot/__init__.py b/telegrinder/bot/__init__.py index 1788247..e395902 100755 --- a/telegrinder/bot/__init__.py +++ b/telegrinder/bot/__init__.py @@ -38,7 +38,7 @@ register_manager, ) from .polling import ABCPolling, Polling -from .rules import ABCRule, CallbackQueryRule, MessageRule +from .rules import ABCRule, CallbackQueryRule, ChatJoinRequestRule, InlineQueryRule, MessageRule from .scenario import ABCScenario, Checkbox, Choice __all__ = ( @@ -59,6 +59,8 @@ "CallbackQueryReturnManager", "CallbackQueryRule", "CallbackQueryView", + "ChatJoinRequestRule", + "InlineQueryRule", "ChatJoinRequestCute", "ChatJoinRequestView", "ChatMemberUpdatedCute", diff --git a/telegrinder/bot/cute_types/update.py b/telegrinder/bot/cute_types/update.py index 4a50c72..fb8d35c 100755 --- a/telegrinder/bot/cute_types/update.py +++ b/telegrinder/bot/cute_types/update.py @@ -15,16 +15,12 @@ class UpdateCute(BaseCute[Update], Update, kw_only=True): api: ABCAPI @property - def incoming_update(self) -> Option[Model]: - return getattr( - self, - self.update_type.expect("Update object has no incoming update.").value, - ) + def incoming_update(self) -> Model: + return getattr(self, self.update_type.value).unwrap() def get_event(self, event_model: type[ModelT]) -> Option[ModelT]: - match self.incoming_update: - case Some(event) if isinstance(event, event_model): - return Some(event) + if isinstance(self.incoming_update, event_model): + return Some(self.incoming_update) return Nothing() diff --git a/telegrinder/bot/dispatch/handler/func.py b/telegrinder/bot/dispatch/handler/func.py index 490b952..f7b5863 100755 --- a/telegrinder/bot/dispatch/handler/func.py +++ b/telegrinder/bot/dispatch/handler/func.py @@ -48,7 +48,7 @@ def __repr__(self) -> str: ) async def check(self, api: ABCAPI, event: Update, ctx: Context | None = None) -> bool: - if self.update_type is not None and self.update_type != event.update_type.unwrap_or_none(): + if self.update_type is not None and self.update_type != event.update_type: return False ctx = ctx or Context() temp_ctx = ctx.copy() diff --git a/telegrinder/bot/dispatch/handler/message_reply.py b/telegrinder/bot/dispatch/handler/message_reply.py index 3972358..852b327 100755 --- a/telegrinder/bot/dispatch/handler/message_reply.py +++ b/telegrinder/bot/dispatch/handler/message_reply.py @@ -19,11 +19,13 @@ def __init__( is_blocking: bool = True, as_reply: bool = False, preset_context: Context | None = None, + **default_params: typing.Any, ) -> None: self.text = text self.rules = list(rules) self.as_reply = as_reply self.is_blocking = is_blocking + self.default_params = default_params self.preset_context = preset_context or Context() def __repr__(self) -> str: @@ -51,6 +53,7 @@ async def run(self, event: MessageCute, _: Context) -> typing.Any: await event.answer( text=self.text, reply_parameters=ReplyParameters(event.message_id) if self.as_reply else None, + **self.default_params, ) diff --git a/telegrinder/bot/dispatch/process.py b/telegrinder/bot/dispatch/process.py index d44d3c0..5311fb1 100755 --- a/telegrinder/bot/dispatch/process.py +++ b/telegrinder/bot/dispatch/process.py @@ -16,16 +16,17 @@ from telegrinder.bot.dispatch.handler.abc import ABCHandler from telegrinder.bot.rules.abc import ABCRule -T = typing.TypeVar("T", bound=BaseCute) +AdaptTo = typing.TypeVar("AdaptTo") +Event = typing.TypeVar("Event", bound=BaseCute) _: typing.TypeAlias = typing.Any async def process_inner( - event: T, + event: Event, raw_event: Update, - middlewares: list[ABCMiddleware[T]], - handlers: list["ABCHandler[T]"], - return_manager: ABCReturnManager[T] | None = None, + middlewares: list[ABCMiddleware[Event]], + handlers: list["ABCHandler[Event]"], + return_manager: ABCReturnManager[Event] | None = None, ) -> bool: logger.debug("Processing {!r}...", event.__class__.__name__) ctx = Context(raw_update=raw_event) @@ -58,7 +59,7 @@ async def process_inner( async def check_rule( api: ABCAPI, - rule: "ABCRule[T]", + rule: "ABCRule[Event, AdaptTo]", update: Update, ctx: Context, ) -> bool: diff --git a/telegrinder/bot/dispatch/view/abc.py b/telegrinder/bot/dispatch/view/abc.py index 6d8bf8e..e5abf99 100755 --- a/telegrinder/bot/dispatch/view/abc.py +++ b/telegrinder/bot/dispatch/view/abc.py @@ -69,11 +69,7 @@ def get_event_type(cls) -> Option[type[EventType]]: @staticmethod def get_raw_event(update: Update) -> Option[Model]: - match update.update_type: - case Some(update_type): - return getattr(update, update_type.value) - case _: - return Nothing() + return getattr(update, update.update_type.value) @typing.overload @classmethod diff --git a/telegrinder/bot/dispatch/view/message.py b/telegrinder/bot/dispatch/view/message.py index f056f83..b8a952c 100755 --- a/telegrinder/bot/dispatch/view/message.py +++ b/telegrinder/bot/dispatch/view/message.py @@ -33,7 +33,7 @@ async def check(self, event: Update) -> bool: return ( True if self.update_type is None - else self.update_type == event.update_type.unwrap_or_none() + else self.update_type == event.update_type ) diff --git a/telegrinder/bot/rules/abc.py b/telegrinder/bot/rules/abc.py index 0f6ac4a..8eb7792 100755 --- a/telegrinder/bot/rules/abc.py +++ b/telegrinder/bot/rules/abc.py @@ -1,7 +1,8 @@ import inspect -import typing from abc import ABC, abstractmethod +import typing_extensions as typing + from telegrinder.bot.cute_types import BaseCute, MessageCute, UpdateCute from telegrinder.bot.dispatch.context import Context from telegrinder.bot.dispatch.process import check_rule @@ -10,13 +11,14 @@ from telegrinder.tools.magic import cache_translation, get_cached_translation from telegrinder.types.objects import Update as UpdateObject -T = typing.TypeVar("T", bound=BaseCute) +AdaptTo = typing.TypeVar("AdaptTo", default=UpdateCute) +EventCute = typing.TypeVar("EventCute", bound=BaseCute, default=UpdateCute) Message: typing.TypeAlias = MessageCute Update: typing.TypeAlias = UpdateCute -def with_caching_translations(func): +def with_caching_translations(func: typing.Callable[..., typing.Any]): """Should be used as decorator for .translate method. Caches rule translations.""" async def wrapper(self: "ABCRule[typing.Any]", translator: ABCTranslator): @@ -29,15 +31,15 @@ async def wrapper(self: "ABCRule[typing.Any]", translator: ABCTranslator): return wrapper -class ABCRule(ABC, typing.Generic[T]): - adapter: ABCAdapter[UpdateObject, T] = RawUpdateAdapter() # type: ignore - requires: list["ABCRule[T]"] = [] +class ABCRule(ABC, typing.Generic[EventCute, AdaptTo]): + adapter: ABCAdapter[UpdateObject, AdaptTo] = RawUpdateAdapter() # type: ignore + requires: list["ABCRule[EventCute]"] = [] @abstractmethod - async def check(self, event: T, ctx: Context) -> bool: + async def check(self, event: AdaptTo, ctx: Context) -> bool: pass - def __init_subclass__(cls, requires: list["ABCRule[T]"] | None = None): + def __init_subclass__(cls, requires: list["ABCRule[EventCute, AdaptTo]"] | None = None) -> None: """Merges requirements from inherited classes and rule-specific requirements.""" requirements = [] @@ -48,17 +50,41 @@ def __init_subclass__(cls, requires: list["ABCRule[T]"] | None = None): requirements.extend(requires or ()) cls.requires = list(dict.fromkeys(requirements)) - def __and__(self, other: "ABCRule[T]"): + def __and__(self, other: "ABCRule[EventCute, AdaptTo]") -> "AndRule[EventCute, AdaptTo]": + """And Rule. + + ```python + rule = HasText() & HasCaption() + rule #> AndRule(HasText(), HasCaption()) -> True if all rules in an AndRule are True, otherwise False. + ``` + """ + return AndRule(self, other) - def __or__(self, other: "ABCRule[T]"): + def __or__(self, other: "ABCRule[EventCute, AdaptTo]") -> "OrRule[EventCute, AdaptTo]": + """Or Rule. + + ```python + rule = HasText() | HasCaption() + rule #> OrRule(HasText(), HasCaption()) -> True if any rule in an OrRule are True, otherwise False. + ``` + """ + return OrRule(self, other) - def __neg__(self) -> "ABCRule[T]": + def __invert__(self) -> "NotRule[EventCute, AdaptTo]": + """Not Rule. + + ```python + rule = ~HasText() + rule # NotRule(HasText()) -> True if rule returned False, otherwise False. + ``` + """ + return NotRule(self) def __repr__(self) -> str: - return "".format( + return "<{}: adapter={!r}>".format( self.__class__.__name__, self.adapter, ) @@ -67,8 +93,8 @@ async def translate(self, translator: ABCTranslator) -> typing.Self: return self -class AndRule(ABCRule[T]): - def __init__(self, *rules: ABCRule[T]): +class AndRule(ABCRule[EventCute, AdaptTo]): + def __init__(self, *rules: ABCRule[EventCute, AdaptTo]) -> None: self.rules = rules async def check(self, event: Update, ctx: Context) -> bool: @@ -80,8 +106,8 @@ async def check(self, event: Update, ctx: Context) -> bool: return True -class OrRule(ABCRule[T]): - def __init__(self, *rules: ABCRule[T]): +class OrRule(ABCRule[EventCute, AdaptTo]): + def __init__(self, *rules: ABCRule[EventCute, AdaptTo]) -> None: self.rules = rules async def check(self, event: Update, ctx: Context) -> bool: @@ -93,8 +119,8 @@ async def check(self, event: Update, ctx: Context) -> bool: return False -class NotRule(ABCRule[T]): - def __init__(self, rule: ABCRule[T]): +class NotRule(ABCRule[EventCute, AdaptTo]): + def __init__(self, rule: ABCRule[EventCute, AdaptTo]) -> None: self.rule = rule async def check(self, event: Update, ctx: Context) -> bool: diff --git a/telegrinder/bot/rules/adapter/__init__.py b/telegrinder/bot/rules/adapter/__init__.py index 2d8b5e6..ef9c4fc 100755 --- a/telegrinder/bot/rules/adapter/__init__.py +++ b/telegrinder/bot/rules/adapter/__init__.py @@ -2,10 +2,12 @@ from .errors import AdapterError from .event import EventAdapter from .raw_update import RawUpdateAdapter +from .user import UserAdapter __all__ = ( "ABCAdapter", "AdapterError", "EventAdapter", "RawUpdateAdapter", + "UserAdapter", ) diff --git a/telegrinder/bot/rules/adapter/abc.py b/telegrinder/bot/rules/adapter/abc.py index f4d52a5..debfe11 100755 --- a/telegrinder/bot/rules/adapter/abc.py +++ b/telegrinder/bot/rules/adapter/abc.py @@ -8,13 +8,13 @@ from telegrinder.bot.rules.adapter.errors import AdapterError from telegrinder.model import Model -UpdateT = typing.TypeVar("UpdateT", bound=Model) -CuteT = typing.TypeVar("CuteT", bound=BaseCute) +From = typing.TypeVar("From", bound=Model) +To = typing.TypeVar("To") -class ABCAdapter(abc.ABC, typing.Generic[UpdateT, CuteT]): +class ABCAdapter(abc.ABC, typing.Generic[From, To]): @abc.abstractmethod - async def adapt(self, api: ABCAPI, update: UpdateT) -> Result[CuteT, AdapterError]: + async def adapt(self, api: ABCAPI, update: From) -> Result[To, AdapterError]: pass diff --git a/telegrinder/bot/rules/adapter/chat.py b/telegrinder/bot/rules/adapter/chat.py new file mode 100644 index 0000000..10e7360 --- /dev/null +++ b/telegrinder/bot/rules/adapter/chat.py @@ -0,0 +1,38 @@ +import typing + +from fntypes.result import Error, Ok, Result + +from telegrinder.api.abc import ABCAPI +from telegrinder.bot.cute_types.base import BaseCute +from telegrinder.bot.rules.adapter.abc import ABCAdapter +from telegrinder.bot.rules.adapter.errors import AdapterError +from telegrinder.bot.rules.adapter.raw_update import RawUpdateAdapter +from telegrinder.bot.rules.adapter.utils import Source, get_by_sources +from telegrinder.types.objects import Chat, Update + +ToCute = typing.TypeVar("ToCute", bound=BaseCute) + + +@typing.runtime_checkable +class HasChat(Source, typing.Protocol): + chat: Chat + + +class ChatAdapter(ABCAdapter[Update, Chat]): + def __init__(self) -> None: + self.raw_adapter = RawUpdateAdapter() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}: Update -> UpdateCute -> Chat>" + + async def adapt(self, api: ABCAPI, update: Update) -> Result[Chat, AdapterError]: + match await self.raw_adapter.adapt(api, update): + case Ok(event): + if (source := get_by_sources(event.incoming_update, HasChat)): + return Ok(source) + return Error(AdapterError(f"{event.incoming_update.__class__.__name__!r} has no chat.")) + case Error(_) as error: + return error + + +__all__ = ("ChatAdapter",) diff --git a/telegrinder/bot/rules/adapter/errors.py b/telegrinder/bot/rules/adapter/errors.py index 05063ae..fcbfd47 100755 --- a/telegrinder/bot/rules/adapter/errors.py +++ b/telegrinder/bot/rules/adapter/errors.py @@ -1,5 +1,4 @@ -class AdapterError(RuntimeError): - pass +class AdapterError(RuntimeError): ... __all__ = ("AdapterError",) diff --git a/telegrinder/bot/rules/adapter/event.py b/telegrinder/bot/rules/adapter/event.py index 7735fae..f29a47b 100755 --- a/telegrinder/bot/rules/adapter/event.py +++ b/telegrinder/bot/rules/adapter/event.py @@ -7,13 +7,14 @@ from telegrinder.bot.rules.adapter.abc import ABCAdapter from telegrinder.bot.rules.adapter.errors import AdapterError from telegrinder.msgspec_utils import Nothing +from telegrinder.types.enums import UpdateType from telegrinder.types.objects import Model, Update -CuteT = typing.TypeVar("CuteT", bound=BaseCute) +ToCute = typing.TypeVar("ToCute", bound=BaseCute) -class EventAdapter(ABCAdapter[Update, CuteT]): - def __init__(self, event: str | type[Model], cute_model: type[CuteT]) -> None: +class EventAdapter(ABCAdapter[Update, ToCute]): + def __init__(self, event: UpdateType | type[Model], cute_model: type[ToCute]) -> None: self.event = event self.cute_model = cute_model @@ -27,27 +28,31 @@ def __repr__(self) -> str: ) else: raw_update_type = self.event.__name__ + return "<{}: adapt {} -> {}>".format( self.__class__.__name__, raw_update_type, self.cute_model.__name__, ) - async def adapt(self, api: ABCAPI, update: Update) -> Result[CuteT, AdapterError]: + async def adapt(self, api: ABCAPI, update: Update) -> Result[ToCute, AdapterError]: update_dct = update.to_dict() - if isinstance(self.event, str): - if self.event not in update_dct: + if isinstance(self.event, UpdateType): + if update.update_type != self.event: return Error( AdapterError(f"Update is not of event type {self.event!r}."), ) - if update_dct[self.event] is Nothing: + + if update_dct[self.event.value] is Nothing: return Error( AdapterError(f"Update is not an {self.event!r}."), ) + return Ok( - self.cute_model.from_update(update_dct[self.event].unwrap(), bound_api=api), + self.cute_model.from_update(update_dct[self.event.value].unwrap(), bound_api=api), ) - event = update_dct[update.update_type.unwrap()].unwrap() + + event = update_dct[update.update_type.value].unwrap() if not update.update_type or not issubclass(event.__class__, self.event): return Error(AdapterError(f"Update is not an {self.event.__name__!r}.")) return Ok(self.cute_model.from_update(event, bound_api=api)) diff --git a/telegrinder/bot/rules/adapter/user.py b/telegrinder/bot/rules/adapter/user.py new file mode 100644 index 0000000..9f11e66 --- /dev/null +++ b/telegrinder/bot/rules/adapter/user.py @@ -0,0 +1,43 @@ +import typing + +from fntypes.result import Error, Ok, Result + +from telegrinder.api.abc import ABCAPI +from telegrinder.bot.cute_types.base import BaseCute +from telegrinder.bot.rules.adapter.abc import ABCAdapter +from telegrinder.bot.rules.adapter.errors import AdapterError +from telegrinder.bot.rules.adapter.raw_update import RawUpdateAdapter +from telegrinder.bot.rules.adapter.utils import Source, get_by_sources +from telegrinder.types.objects import Update, User + +ToCute = typing.TypeVar("ToCute", bound=BaseCute) + + +@typing.runtime_checkable +class HasFrom(Source, typing.Protocol): + from_: User + + +@typing.runtime_checkable +class HasUser(Source, typing.Protocol): + user: User + + +class UserAdapter(ABCAdapter[Update, User]): + def __init__(self) -> None: + self.raw_adapter = RawUpdateAdapter() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}: Update -> UpdateCute -> User>" + + async def adapt(self, api: ABCAPI, update: Update) -> Result[User, AdapterError]: + match await self.raw_adapter.adapt(api, update): + case Ok(event): + if (source := get_by_sources(event.incoming_update, [HasFrom, HasUser])): + return Ok(source) + return Error(AdapterError(f"{event.incoming_update.__class__.__name__!r} has no user.")) + case Error(_) as error: + return error + + +__all__ = ("UserAdapter",) diff --git a/telegrinder/bot/rules/adapter/utils.py b/telegrinder/bot/rules/adapter/utils.py new file mode 100644 index 0000000..e7f6194 --- /dev/null +++ b/telegrinder/bot/rules/adapter/utils.py @@ -0,0 +1,74 @@ +import typing + +from fntypes.option import Some +from fntypes.variative import Variative + +from telegrinder.model import Model +from telegrinder.msgspec_utils import Nothing + +T = typing.TypeVar("T", bound="Source") + + +@typing.runtime_checkable +class Source(typing.Protocol): ... + + +def unwrap_value(value: typing.Any) -> typing.Any | None: + if value in (Nothing, None): + return None + if isinstance(value, Some): + return value.unwrap() + if isinstance(value, Variative): + return unwrap_value(value.v) + return value + + +def get_by_sources(model: Model, sources: type[T] | list[type[T]]) -> typing.Any | None: + """ + For example: + + ```python + @typing.runtime_checkable + class HasFrom(Source, typing.Protocol): + from_: User + + + @typing.runtime_checkable + class HasUser(Source, typing.Protocol): + user: User + + + class Message(Model): + from_: User + + + class MessageReactionUpdated(Model): + user: User + + + get_by_sources(Message(...), [HasFrom, HasUser]) # User(...) + get_by_sources(Message(...), HasUser) # None + + get_by_sources(MessageReactionUpdated(), [HasFrom, HasUser]) # User(...) + get_by_sources(Message(...), HasFrom) # None + ``` + """ + + sources = [sources] if not isinstance(sources, list) else sources + for source in sources: + if not isinstance(model, source): + continue + + values = filter(None, [ + getattr(model, field, None) + for field in typing.get_type_hints(source) + ]) + for value in values: + value = unwrap_value(value) + if value is not None: + return value + + return None + + +__all__ = ("Source", "unwrap_value", "get_by_sources") diff --git a/telegrinder/bot/rules/callback_data.py b/telegrinder/bot/rules/callback_data.py index 790102e..4000894 100755 --- a/telegrinder/bot/rules/callback_data.py +++ b/telegrinder/bot/rules/callback_data.py @@ -10,6 +10,7 @@ from telegrinder.bot.rules.adapter import EventAdapter from telegrinder.model import decoder from telegrinder.tools.buttons import DataclassInstance +from telegrinder.types.enums import UpdateType from .abc import ABCRule from .markup import Markup, PatternLike, check_string @@ -22,7 +23,7 @@ class CallbackQueryRule(ABCRule[CallbackQuery], abc.ABC): - adapter = EventAdapter("callback_query", CallbackQuery) + adapter: EventAdapter[CallbackQuery] = EventAdapter(UpdateType.CALLBACK_QUERY, CallbackQuery) @abc.abstractmethod async def check(self, event: CallbackQuery, ctx: Context) -> bool: diff --git a/telegrinder/bot/rules/chat_join.py b/telegrinder/bot/rules/chat_join.py index ad6c84a..ea419a7 100644 --- a/telegrinder/bot/rules/chat_join.py +++ b/telegrinder/bot/rules/chat_join.py @@ -4,14 +4,15 @@ from telegrinder.bot.cute_types import ChatJoinRequestCute from telegrinder.bot.dispatch.context import Context from telegrinder.bot.rules.adapter import EventAdapter +from telegrinder.types.enums import UpdateType from .abc import ABCRule ChatJoinRequest: typing.TypeAlias = ChatJoinRequestCute -class ChatJoinRequestRule(ABCRule[ChatJoinRequestCute], requires=[]): - adapter = EventAdapter("chat_join_request", ChatJoinRequest) +class ChatJoinRequestRule(ABCRule[ChatJoinRequest], requires=[]): + adapter: EventAdapter[ChatJoinRequest] = EventAdapter(UpdateType.CHAT_JOIN_REQUEST, ChatJoinRequest) @abc.abstractmethod async def check(self, event: ChatJoinRequest, ctx: Context) -> bool: diff --git a/telegrinder/bot/rules/func.py b/telegrinder/bot/rules/func.py index f6c06e2..6035c1f 100755 --- a/telegrinder/bot/rules/func.py +++ b/telegrinder/bot/rules/func.py @@ -4,19 +4,19 @@ from telegrinder.bot.dispatch.context import Context from telegrinder.types import Update -from .abc import ABCAdapter, ABCRule, RawUpdateAdapter, T +from .abc import ABCAdapter, ABCRule, AdaptTo, EventCute, RawUpdateAdapter -class FuncRule(ABCRule, typing.Generic[T]): +class FuncRule(ABCRule[EventCute, AdaptTo], typing.Generic[AdaptTo, EventCute]): def __init__( self, - func: typing.Callable[[T, Context], typing.Awaitable[bool] | bool], - adapter: ABCAdapter[Update, T] | None = None, + func: typing.Callable[[AdaptTo, Context], typing.Awaitable[bool] | bool], + adapter: ABCAdapter[Update, AdaptTo] | None = None, ): self.func = func - self.adapter = adapter or RawUpdateAdapter() + self.adapter = adapter or RawUpdateAdapter() # type: ignore - async def check(self, event: T, ctx: Context) -> bool: + async def check(self, event: AdaptTo, ctx: Context) -> bool: result = self.func(event, ctx) if inspect.isawaitable(result): return await result diff --git a/telegrinder/bot/rules/inline.py b/telegrinder/bot/rules/inline.py index 64b6785..98a5df9 100755 --- a/telegrinder/bot/rules/inline.py +++ b/telegrinder/bot/rules/inline.py @@ -5,7 +5,7 @@ from telegrinder.bot.dispatch.context import Context from telegrinder.bot.rules.abc import ABCRule from telegrinder.bot.rules.adapter import EventAdapter -from telegrinder.types.enums import ChatType +from telegrinder.types.enums import ChatType, UpdateType from .markup import Markup, PatternLike, check_string @@ -13,7 +13,7 @@ class InlineQueryRule(ABCRule[InlineQuery], abc.ABC): - adapter = EventAdapter("inline_query", InlineQuery) + adapter: EventAdapter[InlineQuery] = EventAdapter(UpdateType.INLINE_QUERY, InlineQuery) @abc.abstractmethod async def check(self, query: InlineQuery, ctx: Context) -> bool: diff --git a/telegrinder/bot/rules/is_from.py b/telegrinder/bot/rules/is_from.py index a77e618..f8d254f 100755 --- a/telegrinder/bot/rules/is_from.py +++ b/telegrinder/bot/rules/is_from.py @@ -1,34 +1,31 @@ +import abc import typing -from fntypes.co import Nothing, Some - from telegrinder.bot.cute_types.base import BaseCute -from telegrinder.bot.cute_types.update import UpdateCute from telegrinder.bot.dispatch.context import Context -from telegrinder.msgspec_utils import Option +from telegrinder.bot.rules.adapter.chat import ChatAdapter +from telegrinder.bot.rules.adapter.user import UserAdapter from telegrinder.types.enums import ChatType, DiceEmoji -from telegrinder.types.objects import User +from telegrinder.types.objects import Chat, User from .abc import ABCRule, Message from .message import MessageRule -T = typing.TypeVar("T", bound=BaseCute) +EventCute = typing.TypeVar("EventCute", bound=BaseCute) -def get_from_user(obj: typing.Any) -> User: - assert isinstance(obj, FromUserProto) - return obj.from_.unwrap() if isinstance(obj.from_, Some | Nothing) else obj.from_ +class UserRule(ABCRule[EventCute]): + adapter: UserAdapter = UserAdapter() + @abc.abstractmethod + async def check(self, user: User, ctx: Context) -> bool: ... -@typing.runtime_checkable -class FromUserProto(typing.Protocol): - from_: User | Option[User] +class ChatRule(ABCRule[EventCute]): + adapter: ChatAdapter = ChatAdapter() -class HasFrom(ABCRule[T]): - async def check(self, event: UpdateCute, ctx: Context) -> bool: - event_model = event.incoming_update.unwrap() - return isinstance(event_model, FromUserProto) and bool(event_model.from_) + @abc.abstractmethod + async def check(self, chat: Chat, ctx: Context) -> bool: ... class HasDice(MessageRule): @@ -36,104 +33,101 @@ async def check(self, message: Message, ctx: Context) -> bool: return bool(message.dice) -class IsForward(MessageRule): - async def check(self, message: Message, ctx: Context) -> bool: - return bool(message.forward_origin) +class IsBot(UserRule[EventCute]): + async def check(self, user: User, ctx: Context) -> bool: + return user.is_bot -class IsForwardType(MessageRule, requires=[IsForward()]): - def __init__( - self, fwd_type: typing.Literal["user", "hidden_user", "chat", "channel"], / - ) -> None: - self.fwd_type = fwd_type +class IsUser(UserRule[EventCute]): + async def check(self, user: User, ctx: Context) -> bool: + return not user.is_bot - async def check(self, message: Message, ctx: Context) -> bool: - return message.forward_origin.unwrap().v.type == self.fwd_type +class IsPremium(UserRule[EventCute]): + async def check(self, user: User, ctx: Context) -> bool: + return not user.is_premium.unwrap_or(False) -class IsReply(MessageRule): - async def check(self, message: Message, ctx: Context) -> bool: - return bool(message.reply_to_message) +class IsLanguageCode(UserRule[EventCute]): + def __init__(self, lang_codes: str | list[str], /) -> None: + self.lang_codes = [lang_codes] if isinstance(lang_codes, str) else lang_codes -class IsSticker(MessageRule): - async def check(self, message: Message, ctx: Context) -> bool: - return bool(message.sticker) + async def check(self, user: User, ctx: Context) -> bool: + return user.language_code.unwrap_or_none() in self.lang_codes -class IsBot(ABCRule[T], requires=[HasFrom()]): - async def check(self, event: UpdateCute, ctx: Context) -> bool: - return get_from_user(event.incoming_update.unwrap()).is_bot +class IsUserId(UserRule[EventCute]): + def __init__(self, user_ids: int | list[int], /) -> None: + self.user_ids = [user_ids] if isinstance(user_ids, int) else user_ids + async def check(self, user: User, ctx: Context) -> bool: + return user.id in self.user_ids -class IsUser(ABCRule[T], requires=[HasFrom()]): - async def check(self, event: UpdateCute, ctx: Context) -> bool: - return not get_from_user(event.incoming_update.unwrap()).is_bot +class IsForum(ChatRule[EventCute]): + async def check(self, chat: Chat, ctx: Context) -> bool: + return chat.is_forum.unwrap_or(False) -class IsPremium(ABCRule[T], requires=[HasFrom()]): - async def check(self, event: UpdateCute, ctx: Context) -> bool: - return get_from_user(event.incoming_update.unwrap()).is_premium.unwrap_or(False) +class IsChatId(ChatRule[EventCute]): + def __init__(self, chat_ids: int | list[int], /) -> None: + self.chat_ids = [chat_ids] if isinstance(chat_ids, int) else chat_ids -class IsLanguageCode(ABCRule[T], requires=[HasFrom()]): - def __init__(self, lang_codes: str | list[str], /) -> None: - self.lang_codes = [lang_codes] if isinstance(lang_codes, str) else lang_codes + async def check(self, chat: Chat, ctx: Context) -> bool: + return chat.id in self.chat_ids - async def check(self, event: UpdateCute, ctx: Context) -> bool: - return ( - get_from_user(event.incoming_update.unwrap()).language_code.unwrap_or_none() - in self.lang_codes - ) +class IsPrivate(ChatRule[EventCute]): + async def check(self, chat: Chat, ctx: Context) -> bool: + return chat.type == ChatType.PRIVATE -class IsForum(MessageRule): - async def check(self, message: Message, ctx: Context) -> bool: - return message.chat.is_forum.unwrap_or(False) +class IsGroup(ChatRule[EventCute]): + async def check(self, chat: Chat, ctx: Context) -> bool: + return chat.type == ChatType.GROUP -class IsUserId(ABCRule[T], requires=[HasFrom()]): - def __init__(self, user_ids: int | list[int], /) -> None: - self.user_ids = [user_ids] if isinstance(user_ids, int) else user_ids - async def check(self, event: UpdateCute, ctx: Context) -> bool: - return get_from_user(event.incoming_update.unwrap()).id in self.user_ids +class IsSuperGroup(ChatRule[EventCute]): + async def check(self, chat: Chat, ctx: Context) -> bool: + return chat.type == ChatType.SUPERGROUP -class IsChatId(MessageRule): - def __init__(self, chat_ids: int | list[int], /) -> None: - self.chat_ids = [chat_ids] if isinstance(chat_ids, int) else chat_ids +class IsChat(ChatRule[EventCute]): + async def check(self, chat: Chat, ctx: Context) -> bool: + return chat.type in (ChatType.GROUP, ChatType.SUPERGROUP) - async def check(self, message: Message, ctx: Context) -> bool: - return message.chat.id in self.chat_ids +class IsDiceEmoji(MessageRule, requires=[HasDice()]): + def __init__(self, dice_emoji: DiceEmoji, /) -> None: + self.dice_emoji = dice_emoji -class IsPrivate(MessageRule): async def check(self, message: Message, ctx: Context) -> bool: - return message.chat.type == ChatType.PRIVATE + return message.dice.unwrap().emoji == self.dice_emoji -class IsGroup(MessageRule): +class IsForward(MessageRule): async def check(self, message: Message, ctx: Context) -> bool: - return message.chat.type == ChatType.GROUP + return bool(message.forward_origin) -class IsSuperGroup(MessageRule): +class IsForwardType(MessageRule, requires=[IsForward()]): + def __init__( + self, fwd_type: typing.Literal["user", "hidden_user", "chat", "channel"], / + ) -> None: + self.fwd_type = fwd_type + async def check(self, message: Message, ctx: Context) -> bool: - return message.chat.type == ChatType.SUPERGROUP + return message.forward_origin.unwrap().v.type == self.fwd_type -class IsChat(MessageRule): +class IsReply(MessageRule): async def check(self, message: Message, ctx: Context) -> bool: - return message.chat.type in (ChatType.GROUP, ChatType.SUPERGROUP) + return bool(message.reply_to_message) -class IsDiceEmoji(MessageRule, requires=[HasDice()]): - def __init__(self, dice_emoji: DiceEmoji, /) -> None: - self.dice_emoji = dice_emoji - +class IsSticker(MessageRule): async def check(self, message: Message, ctx: Context) -> bool: - return message.dice.unwrap().emoji == self.dice_emoji + return bool(message.sticker) __all__ = ( diff --git a/telegrinder/bot/rules/message.py b/telegrinder/bot/rules/message.py index b01bc5e..80111a8 100644 --- a/telegrinder/bot/rules/message.py +++ b/telegrinder/bot/rules/message.py @@ -8,7 +8,7 @@ class MessageRule(ABCRule[Message], abc.ABC): - adapter = EventAdapter(MessageEvent, Message) + adapter: EventAdapter[Message] = EventAdapter(MessageEvent, Message) @abc.abstractmethod async def check(self, message: Message, ctx: Context) -> bool: ... diff --git a/telegrinder/bot/rules/rule_enum.py b/telegrinder/bot/rules/rule_enum.py index 4c96a40..d44659f 100755 --- a/telegrinder/bot/rules/rule_enum.py +++ b/telegrinder/bot/rules/rule_enum.py @@ -3,26 +3,26 @@ from telegrinder.bot.dispatch.context import Context -from .abc import ABCRule, T, Update, check_rule +from .abc import ABCRule, AdaptTo, EventCute, Update, check_rule from .func import FuncRule @dataclasses.dataclass -class RuleEnumState: +class RuleEnumState(typing.Generic[EventCute, AdaptTo]): name: str rule: ABCRule - cls: type["RuleEnum"] + cls: type["RuleEnum[EventCute, AdaptTo]"] def __eq__(self, other: typing.Self) -> bool: return self.cls == other.cls and self.name == other.name -class RuleEnum(ABCRule[T]): - __enum__: list[RuleEnumState] +class RuleEnum(ABCRule[EventCute, AdaptTo]): + __enum__: list[RuleEnumState[EventCute, AdaptTo]] - def __init_subclass__(cls, *args, **kwargs): + def __init_subclass__(cls, *args: typing.Any, **kwargs: typing.Any) -> None: new_attributes = set(cls.__dict__) - set(RuleEnum.__dict__) - {"__enum__", "__init__"} - enum_lst: list[RuleEnumState] = [] + enum_lst: list[RuleEnumState[EventCute, AdaptTo]] = [] self = cls.__new__(cls) self.__init__() @@ -30,26 +30,26 @@ def __init_subclass__(cls, *args, **kwargs): for attribute_name in new_attributes: rules = getattr(cls, attribute_name) attribute = RuleEnumState(attribute_name, rules, cls) - + setattr( self, attribute.name, - self & FuncRule(lambda _, ctx: self.must_be_state(ctx, attribute)), + self & FuncRule(lambda _, ctx: self.must_be_state(ctx, attribute)), # type: ignore ) enum_lst.append(attribute) setattr(cls, "__enum__", enum_lst) @classmethod - def save_state(cls, ctx: Context, enum: RuleEnumState) -> None: + def save_state(cls, ctx: Context, enum: RuleEnumState[EventCute, AdaptTo]) -> None: ctx.update({cls.__class__.__name__ + "_state": enum}) @classmethod - def check_state(cls, ctx: Context) -> RuleEnumState | None: + def check_state(cls, ctx: Context) -> RuleEnumState[EventCute, AdaptTo] | None: return ctx.get(cls.__class__.__name__ + "_state") @classmethod - def must_be_state(cls, ctx: Context, state: RuleEnumState) -> bool: + def must_be_state(cls, ctx: Context, state: RuleEnumState[EventCute, AdaptTo]) -> bool: real_state = cls.check_state(ctx) if not real_state: return False diff --git a/telegrinder/bot/rules/update.py b/telegrinder/bot/rules/update.py index 38f4ef2..405bc75 100644 --- a/telegrinder/bot/rules/update.py +++ b/telegrinder/bot/rules/update.py @@ -2,15 +2,15 @@ from telegrinder.bot.dispatch.context import Context from telegrinder.types.enums import UpdateType -from .abc import ABCRule, T +from .abc import ABCRule, EventCute -class IsUpdate(ABCRule[T]): +class IsUpdate(ABCRule[EventCute]): def __init__(self, update_type: UpdateType, /) -> None: self.update_type = update_type async def check(self, event: UpdateCute, ctx: Context) -> bool: - return event.update_type.unwrap_or_none() == self.update_type + return event.update_type == self.update_type __all__ = ("IsUpdate",) diff --git a/telegrinder/types/objects.py b/telegrinder/types/objects.py index 0662410..6c2272d 100644 --- a/telegrinder/types/objects.py +++ b/telegrinder/types/objects.py @@ -300,23 +300,20 @@ class Update(Model): in the chat to receive these updates.""" def __eq__(self, other: typing.Any) -> bool: - return isinstance(other, self.__class__) and self.update_type.map( - lambda x: x == other.update_type.unwrap_or_none(), - ).unwrap_or(False) + return isinstance(other, self.__class__) and self.update_type == other.update_type @property - def update_type(self) -> Option[UpdateType]: + def update_type(self) -> UpdateType: """Incoming update type.""" - if update := next( - filter( - lambda x: bool(x[1]), - self.to_dict(exclude_fields={"update_id"}).items(), - ), - None, - ): - return Some(UpdateType(update[0])) - return Nothing + return UpdateType( + next( + filter( + lambda x: bool(x[1]), + self.to_dict(exclude_fields={"update_id"}).items(), + ), + )[0], + ) class WebhookInfo(Model):