Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
timoniq committed Jul 12, 2024
1 parent b31385e commit f79e5fb
Show file tree
Hide file tree
Showing 10 changed files with 149 additions and 93 deletions.
79 changes: 0 additions & 79 deletions examples/aesthetics.py

This file was deleted.

22 changes: 17 additions & 5 deletions examples/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,30 @@
import aiosqlite # type: ignore

from telegrinder.modules import logger
from telegrinder.node import PER_EVENT, ScalarNode
from telegrinder.node import ScalarNode, per_call


# Pure node examples
class DB(ScalarNode, aiosqlite.Connection):
# DB connection will be only opened once per event
scope = PER_EVENT

@classmethod
async def compose(cls) -> typing.AsyncGenerator[aiosqlite.Connection, None]:
connection = await aiosqlite.connect("test.db")
logger.info("Opening connection")
yield connection
logger.info("Closing connection")
await connection.close()


# DB connection will be only opened once per event
# to change this and resolve node each time when needed use:
@per_call
class DB2(ScalarNode, aiosqlite.Connection): ...


async def create_tables() -> None:
async with aiosqlite.connect("test.db") as conn:
await conn.execute(
"create table if not exists admins("
"id integer primary key autoincrement, "
"telegram_id text unique"
")"
)
83 changes: 83 additions & 0 deletions examples/with_nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from examples.nodes import DB, create_tables
from telegrinder import API, Message, Telegrinder, Token, node
from telegrinder.bot.dispatch import Context
from telegrinder.bot.rules import ABCRule, Markup, Text
from telegrinder.modules import logger

api = API(token=Token.from_env())
bot = Telegrinder(api)
logger.set_level("DEBUG")


class IsChat(ABCRule):
async def check(self, source: node.Source) -> bool:
return source.chat.id < 0


class IsAdmin(ABCRule):
async def check(self, source: node.Source, db: DB, context: Context) -> bool:
result = await db.execute("select * from admins where telegram_id = ?", (source.from_user.id,))
context["is_admin"] = True
return bool(await result.fetchone())


async def promote(user_id: int, *, db: DB) -> None:
await db.execute("insert into admins(telegram_id) values (?) on conflict do nothing", (user_id,))
await db.commit()


@bot.on.message(IsChat())
async def photo_in_chat_handler(message: Message, p: node.Photo):
photo_size = p.sizes[-1]
await message.answer("Photo ratio H/W: {}".format(photo_size.height / photo_size.width))


### Two handlers below require DB node
# DB node is marked with scope = PER_EVENT
# therefore DB node will only be resolved once on event
# If you run this code you will see that connection opening
# and closing is going to happen only once.


@bot.on.message()
async def photo_handler(
message: Message,
db: DB,
photo: node.Photo,
) -> None:
await message.answer("Got a photo in private message")


@bot.on.message()
async def integer_handler(
message: Message,
db: DB,
i: node.TextInteger,
) -> None:
await message.answer(f"{i} + 3 = {i + 3}")


@bot.on.message(IsAdmin(), Text("/op"))
async def add_admin_handler(message: Message, db: DB) -> str | None:
if not message.reply_to_message:
return "Need reply"
await promote(message.reply_to_message.unwrap().from_user.id, db=db)
await message.answer("Done")


@bot.on.message(Markup("/getadmin <token>"))
async def getadmin_handler(message: Message, token: str, db: DB) -> str:
print(token)
if token != api.token:
return "Wrong token"
await promote(message.from_user.id, db=db)
return "Done"


@bot.on.message(Text("/amiadmin"), IsAdmin().optional())
async def amiadmin_handler(message: Message, is_admin: bool = False):
await message.answer("You are " + ("not " if not is_admin else "") + "an admin")


bot.loop_wrapper.lifespan.on_startup(create_tables())
bot.run_forever()
2 changes: 1 addition & 1 deletion telegrinder/bot/polling/polling.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ async def listen(self) -> typing.AsyncGenerator[list[Update], None]:
exit(6)
else:
logger.warning(
"Server disconnected, waiting 5 seconds to reconnetion...",
"Server disconnected, waiting 5 seconds to reconnet...",
)
reconn_counter += 1
await asyncio.sleep(self.reconnection_timeout)
Expand Down
20 changes: 19 additions & 1 deletion telegrinder/bot/rules/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,17 @@ async def bounding_check(
kw[k] = ctx
else:
raise LookupError(
f"Cannot bound {k!r} to '{self.__class__.__name__}.check()', because it cannot be resolved."
f"Cannot bound {k!r} of type {v!r} to '{self.__class__.__name__}.check()', because it cannot be resolved."
)

return await self.check(**kw)

def optional(self) -> "ABCRule":
return self | Always()

def should_fail(self) -> "ABCRule":
return self & Never()

def __init_subclass__(cls, requires: list["ABCRule"] | None = None) -> None:
"""Merges requirements from inherited classes and rule-specific requirements."""

Expand Down Expand Up @@ -161,10 +167,22 @@ async def check(self, event: Update, ctx: Context) -> bool:
return not await check_rule(event.ctx_api, self.rule, event, ctx_copy)


class Never(ABCRule):
async def check(self) -> typing.Literal[False]:
return False


class Always(ABCRule):
async def check(self) -> typing.Literal[True]:
return True


__all__ = (
"ABCRule",
"AndRule",
"NotRule",
"OrRule",
"with_caching_translations",
"Never",
"Always",
)
4 changes: 2 additions & 2 deletions telegrinder/bot/rules/mention.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from telegrinder.types.enums import MessageEntityType

from .abc import ABCRule, Message
from .message import Message, MessageRule
from .text import HasText


class HasMention(ABCRule, requires=[HasText()]):
class HasMention(MessageRule, requires=[HasText()]):
async def check(self, message: Message) -> bool:
if not message.entities.unwrap_or_none():
return False
Expand Down
4 changes: 3 additions & 1 deletion telegrinder/node/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .container import ContainerNode
from .message import MessageNode
from .rule import RuleChain
from .scope import PER_CALL, PER_EVENT, NodeScope
from .scope import PER_CALL, PER_EVENT, NodeScope, per_call, per_event
from .source import Source
from .text import Text, TextInteger
from .tools import generate_node
Expand Down Expand Up @@ -36,4 +36,6 @@
"NodeScope",
"PER_CALL",
"PER_EVENT",
"per_call",
"per_event",
)
7 changes: 5 additions & 2 deletions telegrinder/node/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class ComposeError(BaseException): ...

class Node(abc.ABC):
node: str = "node"
scope: NodeScope = NodeScope.PER_CALL
scope: NodeScope = NodeScope.PER_EVENT

@classmethod
@abc.abstractmethod
Expand Down Expand Up @@ -77,7 +77,10 @@ def create_class(name, bases, dct):
return type(
"Scalar",
(SCALAR_NODE,),
{"as_node": classmethod(lambda cls: create_node(cls, bases, dct))},
{
"as_node": classmethod(lambda cls: create_node(cls, bases, dct)),
"scope": Node.scope,
},
)

class ScalarNode(ScalarNodeProto, abc.ABC, metaclass=create_class):
Expand Down
4 changes: 2 additions & 2 deletions telegrinder/node/composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ async def compose_node(
else:
context.sessions[name] = await compose_node(subnode, update, ctx)

if getattr(_node, "scope", None) is NodeScope.PER_EVENT:
node_ctx[_node] = context.sessions[name]
if getattr(subnode, "scope", None) is NodeScope.PER_EVENT:
node_ctx[subnode] = context.sessions[name]

generator: typing.AsyncGenerator | None

Expand Down
17 changes: 17 additions & 0 deletions telegrinder/node/scope.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
import enum
import typing

if typing.TYPE_CHECKING:
from .base import Node


T = typing.TypeVar("T", bound=type["Node"])


class NodeScope(enum.Enum):
Expand All @@ -8,3 +15,13 @@ class NodeScope(enum.Enum):

PER_EVENT = NodeScope.PER_EVENT
PER_CALL = NodeScope.PER_CALL


def per_call(node: T) -> T:
setattr(node, "scope", PER_CALL)
return node


def per_event(node: T) -> T:
setattr(node, "scope", PER_EVENT)
return node

0 comments on commit f79e5fb

Please sign in to comment.