Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
luwqz1 committed Jun 14, 2024
1 parent 0440eed commit 3edf3fa
Show file tree
Hide file tree
Showing 10 changed files with 67 additions and 64 deletions.
5 changes: 3 additions & 2 deletions examples/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from telegrinder.bot.dispatch import CompositionDispatch
from telegrinder.modules import logger
from telegrinder.node import Photo, RuleChain, ScalarNode, Source, Text, generate_node
from telegrinder.types.enums import ChatType

api = API(token=Token.from_env())
bot = Telegrinder(api, dispatch=CompositionDispatch())
logger.set_level("DEBUG")
logger.set_level("INFO")


@bot.loop_wrapper.lifespan.on_startup
Expand Down Expand Up @@ -45,7 +46,7 @@ async def photo_handler(photo: Photo, source: Source, db: DB):
# Container generate_noded node examples
@bot.on(
generate_node((Text,), lambda text: text == "hello"),
generate_node((Source,), lambda src: src.chat.username.unwrap_or_none() == "weirdlashes"),
generate_node((Source,), lambda src: src.chat.type == ChatType.PRIVATE),
)
async def hi_handler(source: Source):
await source.send("Hi !!")
Expand Down
11 changes: 10 additions & 1 deletion telegrinder/bot/cute_types/callback_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from telegrinder.msgspec_utils import Option, decoder
from telegrinder.types import (
CallbackQuery,
Chat,
InlineKeyboardMarkup,
InputFile,
InputMedia,
Expand Down Expand Up @@ -44,7 +45,7 @@ def is_topic_message(self) -> Option[bool]:
by the bot with the callback button that originated the query."""

return self.message.map(
lambda m: m.only().map(lambda m: m.is_topic_message.unwrap_or(False)).unwrap_or(False)
lambda m: m.only().map(lambda m: m.is_topic_message.unwrap_or(False)).unwrap_or(False),
)

@property
Expand All @@ -68,6 +69,14 @@ def message_id(self) -> Option[int]:
"""

return self.message.map(lambda m: m.v.message_id)

@property
def chat(self) -> Option[Chat]:
"""Optional. Chat the callback query originated from. This will be present
if the message is sent by the bot with the callback button that originated the query.
"""

return self.message.map(lambda m: m.v.chat)

def decode_callback_data(self) -> Option[dict[str, typing.Any]]:
if "cached_callback_data" in self.__dict__:
Expand Down
2 changes: 1 addition & 1 deletion telegrinder/bot/dispatch/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def load(self, external: typing.Self):
self.compositions.extend(external.compositions)

def __call__(self, *container_nodes: type[Node], is_blocking: bool = True):
def wrapper(func: typing.Callable):
def wrapper(func: typing.Callable[..., typing.Any]):
composition = Composition(func, is_blocking)
if container_nodes:
composition.nodes["container"] = ContainerNode.link_nodes(list(container_nodes))
Expand Down
4 changes: 1 addition & 3 deletions telegrinder/node/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
import inspect
import typing

ComposeResult: typing.TypeAlias = (
typing.Coroutine[typing.Any, typing.Any, typing.Any] | typing.AsyncGenerator[typing.Any, None]
)
ComposeResult: typing.TypeAlias = typing.Coroutine[typing.Any, typing.Any, typing.Any] | typing.AsyncGenerator[typing.Any, None]


class ComposeError(BaseException): ...
Expand Down
50 changes: 25 additions & 25 deletions telegrinder/node/composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,29 @@
from telegrinder.tools.magic import magic_bundle


async def compose_node(_node: type[Node], update: UpdateCute, ready_context: dict[str, "NodeSession"] | None = None) -> "NodeSession":
node = _node.as_node()

context = NodeCollection(ready_context.copy() if ready_context else {})

for name, subnode in node.get_sub_nodes().items():
if subnode is UpdateCute:
context.sessions[name] = NodeSession(update, {})
else:
context.sessions[name] = await compose_node(subnode, update)

generator: typing.AsyncGenerator | None

if node.is_generator():
generator = typing.cast(typing.AsyncGenerator, node.compose(**context.values()))
value = await generator.asend(None)
else:
generator = None
value = await node.compose(**context.values()) # type: ignore

return NodeSession(value, context.sessions, generator)


class NodeSession:
def __init__(
self,
Expand Down Expand Up @@ -44,33 +67,10 @@ async def close_all(self, with_value: typing.Any | None = None) -> None:
await session.close(with_value)


async def compose_node(_node: type[Node], update: UpdateCute, ready_context: dict[str, NodeSession] | None = None) -> NodeSession:
node = _node.as_node()

context = NodeCollection(ready_context.copy() if ready_context else {})

for name, subnode in node.get_sub_nodes().items():
if subnode is UpdateCute:
context.sessions[name] = NodeSession(update, {})
else:
context.sessions[name] = await compose_node(subnode, update)

generator: typing.AsyncGenerator | None

if node.is_generator():
generator = typing.cast(typing.AsyncGenerator, node.compose(**context.values()))
value = await generator.asend(None)
else:
generator = None
value = await node.compose(**context.values()) # type: ignore

return NodeSession(value, context.sessions, generator)


class Composition:
nodes: dict[str, type[Node]]

def __init__(self, func: typing.Callable, is_blocking: bool) -> None:
def __init__(self, func: typing.Callable[..., typing.Any], is_blocking: bool) -> None:
self.func = func
self.nodes = {
name: parameter.annotation
Expand Down Expand Up @@ -100,4 +100,4 @@ async def __call__(self, **kwargs: typing.Any) -> typing.Any:
return await self.func(**magic_bundle(self.func, kwargs, start_idx=0, bundle_ctx=False)) # type: ignore


__all__ = ("NodeCollection", "NodeSession", "compose_node", "Composition")
__all__ = ("NodeCollection", "NodeSession", "Composition", "compose_node")
12 changes: 7 additions & 5 deletions telegrinder/node/polymorphic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,21 @@

class Polymorphic:
@classmethod
async def compose(cls, update: UpdateNode):
impls: list[typing.Callable] = get_impls(cls)
for impl in impls:
async def compose(cls, update: UpdateNode) -> typing.Any:
for impl in get_impls(cls):
composition = Composition(impl, True)
node_collection = await composition.compose_nodes(update)
if node_collection is None:
continue

result = composition.func(cls, **node_collection.values())
if inspect.isawaitable(result):
result = await result # noqa
result = await result

await node_collection.close_all(with_value=result)
return result
raise ComposeError("No implementation found")

raise ComposeError("No implementation found.")


__all__ = ("Polymorphic", "impl")
9 changes: 5 additions & 4 deletions telegrinder/node/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ def is_generator(cls) -> typing.Literal[False]:
def __new__(cls, *rules: ABCRule) -> type[Node]:
return type("_RuleNode", (cls,), {"dataclass": dict, "rules": rules}) # type: ignore

def __class_getitem__(cls, item: tuple[ABCRule, ...]) -> typing.Self:
if not isinstance(item, tuple):
item = (item,)
return cls(*item)
def __class_getitem__(cls, items: ABCRule | tuple[ABCRule, ...]) -> typing.Self:
if not isinstance(items, tuple):
items = (items,)
assert all(isinstance(rule, ABCRule) for rule in items), "All items must be instances of 'ABCRule'."
return cls(*items)

@staticmethod
def generate_node_dataclass(cls_: type["RuleChain"]): # noqa: ANN205
Expand Down
4 changes: 2 additions & 2 deletions telegrinder/node/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
class Source(Polymorphic, DataNode):
api: API
chat: Chat
thread_id: Option[int] = dataclasses.field(default_factory=Nothing)
thread_id: Option[int] = dataclasses.field(default_factory=lambda: Nothing())

@impl
async def compose_message(cls, message: MessageNode) -> typing.Self:
Expand All @@ -30,7 +30,7 @@ async def compose_message(cls, message: MessageNode) -> typing.Self:
async def compose_callback_query(cls, callback_query: CallbackQueryNode) -> typing.Self:
return cls(
api=callback_query.ctx_api,
chat=callback_query.message.expect(ComposeError).only(Message).expect(ComposeError).chat,
chat=callback_query.chat.expect(ComposeError),
thread_id=callback_query.message_thread_id,
)

Expand Down
2 changes: 1 addition & 1 deletion telegrinder/node/tools/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def error_on_none(value: T | None) -> T:
def generate_node(
subnodes: tuple[type[Node], ...],
func: typing.Callable[..., typing.Any],
casts: tuple[typing.Callable, ...] = (cast_false_to_none, error_on_none),
casts: tuple[typing.Callable[[typing.Any], typing.Any], ...] = (cast_false_to_none, error_on_none),
) -> type[ContainerNode]:
async def compose(**kw: typing.Any) -> typing.Any:
args = await ContainerNode.compose(**kw)
Expand Down
32 changes: 12 additions & 20 deletions telegrinder/tools/magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,9 @@
T = typing.TypeVar("T", bound=ABCRule)

FuncType: typing.TypeAlias = types.FunctionType | typing.Callable[..., typing.Any]
TRANSLATIONS_KEY: typing.Final[str] = "_translations"

Cls = typing.TypeVar("Cls")
P = typing.ParamSpec("P")
R = typing.TypeVar("R", covariant=True)


IMPL_MARK = "_is_impl"
TRANSLATIONS_KEY: typing.Final[str] = "_translations"
IMPL_MARK: typing.Final[str] = "_is_impl"

def resolve_arg_names(func: FuncType, start_idx: int = 1) -> tuple[str, ...]:
return func.__code__.co_varnames[start_idx : func.__code__.co_argcount]
Expand Down Expand Up @@ -48,30 +43,26 @@ def magic_bundle(
return args


def get_cached_translation(rule: "T", locale: str) -> typing.Optional["T"]:
translations = getattr(rule, TRANSLATIONS_KEY, {})
if not translations or locale not in translations:
return None
return translations[locale]
def get_cached_translation(rule: "T", locale: str) -> "T | None":
return getattr(rule, TRANSLATIONS_KEY, {}).get(locale)


def cache_translation(base_rule: "T", locale: str, translated_rule: "T") -> None:
translations = getattr(base_rule, TRANSLATIONS_KEY, {})
setattr(base_rule, TRANSLATIONS_KEY, {locale: translated_rule, **translations})


def get_impls(cls: type) -> list[typing.Callable]:
def get_impls(cls: type[typing.Any]) -> list[typing.Callable[..., typing.Any]]:
functions = [func.__func__ for func in cls.__dict__.values() if hasattr(func, "__func__")]
return [impl for impl in functions if getattr(impl, IMPL_MARK, False) is True]


if typing.TYPE_CHECKING:
impl = classmethod # type: ignore
else:
def impl(method: typing.Callable[typing.Concatenate[type[Cls], P], R]) -> typing.Callable[P, R]:
bound_method = classmethod(method)
setattr(method, IMPL_MARK, True)
return bound_method # type: ignore
@typing.cast(typing.Callable[..., type[classmethod]], lambda f: f)
def impl(method): # noqa
bound_method = classmethod(method)
setattr(method, IMPL_MARK, True)
return bound_method


__all__ = (
"TRANSLATIONS_KEY",
Expand All @@ -80,6 +71,7 @@ def impl(method: typing.Callable[typing.Concatenate[type[Cls], P], R]) -> typing
"get_default_args",
"get_default_args",
"magic_bundle",
"impl",
"resolve_arg_names",
"to_str",
)

0 comments on commit 3edf3fa

Please sign in to comment.