Skip to content

Commit

Permalink
feat: add the is_waiting_for for WaiterMachine (#79)
Browse files Browse the repository at this point in the history
  • Loading branch information
prostomarkeloff committed Jun 17, 2024
1 parent aa6a271 commit c0c6bba
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 3 deletions.
2 changes: 2 additions & 0 deletions examples/easy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
logger.set_level("INFO")



@bot.on.message(Text("/start"))
async def start(message: Message):
me = (await api.get_me()).unwrap().first_name
Expand All @@ -25,6 +26,7 @@ async def start(message: Message):
bot.dispatch.message,
message,
Text(["fine", "bad"], ignore_case=True),
exit=MessageReplyHandler("Oh, ok, exiting state...", Text("/exit")),
default=MessageReplyHandler("Fine or bad"),
)

Expand Down
9 changes: 6 additions & 3 deletions telegrinder/bot/dispatch/waiter_machine/machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ async def wait(
*rules: ABCRule[EventModel],
default: Behaviour[EventModel] | None = None,
on_drop: Behaviour[EventModel] | None = None,
exit: Behaviour[EventModel] | None = None,
expiration: datetime.timedelta | float | None = None,
) -> ShortStateContext[EventModel]:
if isinstance(expiration, int | float):
Expand All @@ -79,6 +80,7 @@ async def wait(
if isinstance(linked, tuple)
else (linked.ctx_api, state_view.get_state_key(linked))
) # type: ignore
api, key = linked if isinstance(linked, tuple) else (linked.ctx_api, state_view.get_state_key(linked)) # type: ignore
if not key:
raise RuntimeError("Unable to get state key.")

Expand All @@ -92,6 +94,7 @@ async def wait(
expiration=expiration,
default_behaviour=default,
on_drop_behaviour=on_drop,
exit_behaviour=exit,
)

if view_name not in self.storage:
Expand All @@ -113,11 +116,11 @@ async def call_behaviour(
update: Update,
behaviour: Behaviour[EventModel] | None = None,
**context: typing.Any,
) -> None:
) -> bool:
# TODO: support view as a behaviour

if behaviour is None:
return
return False

ctx = Context(**context)
if isinstance(event, asyncio.Event):
Expand All @@ -129,6 +132,6 @@ async def call_behaviour(

if await behaviour.check(event.api, update, ctx):
await behaviour.run(event, ctx)

return True

__all__ = ("WaiterMachine",)
12 changes: 12 additions & 0 deletions telegrinder/bot/dispatch/waiter_machine/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ async def pre(self, event: EventType, ctx: Context) -> bool:
await self.machine.drop(self.view, short_state.key, ctx.raw_update, **preset_context.copy())
return True

# before running the handler we check if the user wants to exit waiting
if short_state.exit_behaviour is not None:
if await self.machine.call_behaviour(
self.view,
event,
ctx.raw_update,
behaviour=short_state.exit_behaviour,
**preset_context,
):
await self.machine.drop(self.view, short_state.key, ctx.raw_update, **preset_context.copy())
return True

handler = FuncHandler(
self.pass_runtime,
list(short_state.rules),
Expand Down
1 change: 1 addition & 0 deletions telegrinder/bot/dispatch/waiter_machine/short_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class ShortState(typing.Generic[EventModel]):
)
default_behaviour: Behaviour[EventModel] | None = dataclasses.field(default=None)
on_drop_behaviour: Behaviour[EventModel] | None = dataclasses.field(default=None)
exit_behaviour: Behaviour[EventModel] | None = dataclasses.field(default=None)
expiration_date: datetime.datetime | None = dataclasses.field(init=False)
context: ShortStateContext[EventModel] | None = dataclasses.field(default=None, init=False)

Expand Down

0 comments on commit c0c6bba

Please sign in to comment.