diff --git a/examples/composition.py b/examples/composition.py index 6827161..1473981 100755 --- a/examples/composition.py +++ b/examples/composition.py @@ -7,10 +7,11 @@ from telegrinder.bot.dispatch import CompositionDispatch from telegrinder.modules import logger from telegrinder.node import Photo, RuleChain, ScalarNode, Source, Text, generate_node +from telegrinder.types.enums import ChatType api = API(token=Token.from_env()) bot = Telegrinder(api, dispatch=CompositionDispatch()) -logger.set_level("DEBUG") +logger.set_level("INFO") @bot.loop_wrapper.lifespan.on_startup @@ -45,7 +46,7 @@ async def photo_handler(photo: Photo, source: Source, db: DB): # Container generate_noded node examples @bot.on( generate_node((Text,), lambda text: text == "hello"), - generate_node((Source,), lambda src: src.chat.username.unwrap_or_none() == "weirdlashes"), + generate_node((Source,), lambda src: src.chat.type == ChatType.PRIVATE), ) async def hi_handler(source: Source): await source.send("Hi !!") diff --git a/telegrinder/bot/cute_types/callback_query.py b/telegrinder/bot/cute_types/callback_query.py index ee5e9d5..449fe48 100755 --- a/telegrinder/bot/cute_types/callback_query.py +++ b/telegrinder/bot/cute_types/callback_query.py @@ -9,6 +9,7 @@ from telegrinder.msgspec_utils import Option, decoder from telegrinder.types import ( CallbackQuery, + Chat, InlineKeyboardMarkup, InputFile, InputMedia, @@ -44,7 +45,7 @@ def is_topic_message(self) -> Option[bool]: by the bot with the callback button that originated the query.""" return self.message.map( - lambda m: m.only().map(lambda m: m.is_topic_message.unwrap_or(False)).unwrap_or(False) + lambda m: m.only().map(lambda m: m.is_topic_message.unwrap_or(False)).unwrap_or(False), ) @property @@ -68,6 +69,14 @@ def message_id(self) -> Option[int]: """ return self.message.map(lambda m: m.v.message_id) + + @property + def chat(self) -> Option[Chat]: + """Optional. Chat the callback query originated from. This will be present + if the message is sent by the bot with the callback button that originated the query. + """ + + return self.message.map(lambda m: m.v.chat) def decode_callback_data(self) -> Option[dict[str, typing.Any]]: if "cached_callback_data" in self.__dict__: diff --git a/telegrinder/bot/dispatch/composition.py b/telegrinder/bot/dispatch/composition.py index a79788c..7a279c7 100755 --- a/telegrinder/bot/dispatch/composition.py +++ b/telegrinder/bot/dispatch/composition.py @@ -36,7 +36,7 @@ def load(self, external: typing.Self): self.compositions.extend(external.compositions) def __call__(self, *container_nodes: type[Node], is_blocking: bool = True): - def wrapper(func: typing.Callable): + def wrapper(func: typing.Callable[..., typing.Any]): composition = Composition(func, is_blocking) if container_nodes: composition.nodes["container"] = ContainerNode.link_nodes(list(container_nodes)) diff --git a/telegrinder/node/base.py b/telegrinder/node/base.py index 20d3e64..eb7eb6d 100755 --- a/telegrinder/node/base.py +++ b/telegrinder/node/base.py @@ -2,9 +2,7 @@ import inspect import typing -ComposeResult: typing.TypeAlias = ( - typing.Coroutine[typing.Any, typing.Any, typing.Any] | typing.AsyncGenerator[typing.Any, None] -) +ComposeResult: typing.TypeAlias = typing.Coroutine[typing.Any, typing.Any, typing.Any] | typing.AsyncGenerator[typing.Any, None] class ComposeError(BaseException): ... diff --git a/telegrinder/node/composer.py b/telegrinder/node/composer.py index 7fa3e07..ec94f87 100755 --- a/telegrinder/node/composer.py +++ b/telegrinder/node/composer.py @@ -6,6 +6,29 @@ from telegrinder.tools.magic import magic_bundle +async def compose_node(_node: type[Node], update: UpdateCute, ready_context: dict[str, "NodeSession"] | None = None) -> "NodeSession": + node = _node.as_node() + + context = NodeCollection(ready_context.copy() if ready_context else {}) + + for name, subnode in node.get_sub_nodes().items(): + if subnode is UpdateCute: + context.sessions[name] = NodeSession(update, {}) + else: + context.sessions[name] = await compose_node(subnode, update) + + generator: typing.AsyncGenerator | None + + if node.is_generator(): + generator = typing.cast(typing.AsyncGenerator, node.compose(**context.values())) + value = await generator.asend(None) + else: + generator = None + value = await node.compose(**context.values()) # type: ignore + + return NodeSession(value, context.sessions, generator) + + class NodeSession: def __init__( self, @@ -44,33 +67,10 @@ async def close_all(self, with_value: typing.Any | None = None) -> None: await session.close(with_value) -async def compose_node(_node: type[Node], update: UpdateCute, ready_context: dict[str, NodeSession] | None = None) -> NodeSession: - node = _node.as_node() - - context = NodeCollection(ready_context.copy() if ready_context else {}) - - for name, subnode in node.get_sub_nodes().items(): - if subnode is UpdateCute: - context.sessions[name] = NodeSession(update, {}) - else: - context.sessions[name] = await compose_node(subnode, update) - - generator: typing.AsyncGenerator | None - - if node.is_generator(): - generator = typing.cast(typing.AsyncGenerator, node.compose(**context.values())) - value = await generator.asend(None) - else: - generator = None - value = await node.compose(**context.values()) # type: ignore - - return NodeSession(value, context.sessions, generator) - - class Composition: nodes: dict[str, type[Node]] - def __init__(self, func: typing.Callable, is_blocking: bool) -> None: + def __init__(self, func: typing.Callable[..., typing.Any], is_blocking: bool) -> None: self.func = func self.nodes = { name: parameter.annotation @@ -100,4 +100,4 @@ async def __call__(self, **kwargs: typing.Any) -> typing.Any: return await self.func(**magic_bundle(self.func, kwargs, start_idx=0, bundle_ctx=False)) # type: ignore -__all__ = ("NodeCollection", "NodeSession", "compose_node", "Composition") +__all__ = ("NodeCollection", "NodeSession", "Composition", "compose_node") diff --git a/telegrinder/node/polymorphic.py b/telegrinder/node/polymorphic.py index 3d53914..608fbe0 100644 --- a/telegrinder/node/polymorphic.py +++ b/telegrinder/node/polymorphic.py @@ -10,19 +10,21 @@ class Polymorphic: @classmethod - async def compose(cls, update: UpdateNode): - impls: list[typing.Callable] = get_impls(cls) - for impl in impls: + async def compose(cls, update: UpdateNode) -> typing.Any: + for impl in get_impls(cls): composition = Composition(impl, True) node_collection = await composition.compose_nodes(update) if node_collection is None: continue + result = composition.func(cls, **node_collection.values()) if inspect.isawaitable(result): - result = await result # noqa + result = await result + await node_collection.close_all(with_value=result) return result - raise ComposeError("No implementation found") + + raise ComposeError("No implementation found.") __all__ = ("Polymorphic", "impl") diff --git a/telegrinder/node/rule.py b/telegrinder/node/rule.py index bb9816a..338038f 100755 --- a/telegrinder/node/rule.py +++ b/telegrinder/node/rule.py @@ -38,10 +38,11 @@ def is_generator(cls) -> typing.Literal[False]: def __new__(cls, *rules: ABCRule) -> type[Node]: return type("_RuleNode", (cls,), {"dataclass": dict, "rules": rules}) # type: ignore - def __class_getitem__(cls, item: tuple[ABCRule, ...]) -> typing.Self: - if not isinstance(item, tuple): - item = (item,) - return cls(*item) + def __class_getitem__(cls, items: ABCRule | tuple[ABCRule, ...]) -> typing.Self: + if not isinstance(items, tuple): + items = (items,) + assert all(isinstance(rule, ABCRule) for rule in items), "All items must be instances of 'ABCRule'." + return cls(*items) @staticmethod def generate_node_dataclass(cls_: type["RuleChain"]): # noqa: ANN205 diff --git a/telegrinder/node/source.py b/telegrinder/node/source.py index e68bf03..190fd3d 100755 --- a/telegrinder/node/source.py +++ b/telegrinder/node/source.py @@ -16,7 +16,7 @@ class Source(Polymorphic, DataNode): api: API chat: Chat - thread_id: Option[int] = dataclasses.field(default_factory=Nothing) + thread_id: Option[int] = dataclasses.field(default_factory=lambda: Nothing()) @impl async def compose_message(cls, message: MessageNode) -> typing.Self: @@ -30,7 +30,7 @@ async def compose_message(cls, message: MessageNode) -> typing.Self: async def compose_callback_query(cls, callback_query: CallbackQueryNode) -> typing.Self: return cls( api=callback_query.ctx_api, - chat=callback_query.message.expect(ComposeError).only(Message).expect(ComposeError).chat, + chat=callback_query.chat.expect(ComposeError), thread_id=callback_query.message_thread_id, ) diff --git a/telegrinder/node/tools/generator.py b/telegrinder/node/tools/generator.py index ad36565..4ff0dee 100755 --- a/telegrinder/node/tools/generator.py +++ b/telegrinder/node/tools/generator.py @@ -22,7 +22,7 @@ def error_on_none(value: T | None) -> T: def generate_node( subnodes: tuple[type[Node], ...], func: typing.Callable[..., typing.Any], - casts: tuple[typing.Callable, ...] = (cast_false_to_none, error_on_none), + casts: tuple[typing.Callable[[typing.Any], typing.Any], ...] = (cast_false_to_none, error_on_none), ) -> type[ContainerNode]: async def compose(**kw: typing.Any) -> typing.Any: args = await ContainerNode.compose(**kw) diff --git a/telegrinder/tools/magic.py b/telegrinder/tools/magic.py index 8676c01..5be3942 100755 --- a/telegrinder/tools/magic.py +++ b/telegrinder/tools/magic.py @@ -9,14 +9,9 @@ T = typing.TypeVar("T", bound=ABCRule) FuncType: typing.TypeAlias = types.FunctionType | typing.Callable[..., typing.Any] -TRANSLATIONS_KEY: typing.Final[str] = "_translations" - -Cls = typing.TypeVar("Cls") -P = typing.ParamSpec("P") -R = typing.TypeVar("R", covariant=True) - -IMPL_MARK = "_is_impl" +TRANSLATIONS_KEY: typing.Final[str] = "_translations" +IMPL_MARK: typing.Final[str] = "_is_impl" def resolve_arg_names(func: FuncType, start_idx: int = 1) -> tuple[str, ...]: return func.__code__.co_varnames[start_idx : func.__code__.co_argcount] @@ -48,11 +43,8 @@ def magic_bundle( return args -def get_cached_translation(rule: "T", locale: str) -> typing.Optional["T"]: - translations = getattr(rule, TRANSLATIONS_KEY, {}) - if not translations or locale not in translations: - return None - return translations[locale] +def get_cached_translation(rule: "T", locale: str) -> "T | None": + return getattr(rule, TRANSLATIONS_KEY, {}).get(locale) def cache_translation(base_rule: "T", locale: str, translated_rule: "T") -> None: @@ -60,18 +52,17 @@ def cache_translation(base_rule: "T", locale: str, translated_rule: "T") -> None setattr(base_rule, TRANSLATIONS_KEY, {locale: translated_rule, **translations}) -def get_impls(cls: type) -> list[typing.Callable]: +def get_impls(cls: type[typing.Any]) -> list[typing.Callable[..., typing.Any]]: functions = [func.__func__ for func in cls.__dict__.values() if hasattr(func, "__func__")] return [impl for impl in functions if getattr(impl, IMPL_MARK, False) is True] -if typing.TYPE_CHECKING: - impl = classmethod # type: ignore -else: - def impl(method: typing.Callable[typing.Concatenate[type[Cls], P], R]) -> typing.Callable[P, R]: - bound_method = classmethod(method) - setattr(method, IMPL_MARK, True) - return bound_method # type: ignore +@typing.cast(typing.Callable[..., type[classmethod]], lambda f: f) +def impl(method): # noqa + bound_method = classmethod(method) + setattr(method, IMPL_MARK, True) + return bound_method + __all__ = ( "TRANSLATIONS_KEY", @@ -80,6 +71,7 @@ def impl(method: typing.Callable[typing.Concatenate[type[Cls], P], R]) -> typing "get_default_args", "get_default_args", "magic_bundle", + "impl", "resolve_arg_names", "to_str", )