Skip to content

Commit

Permalink
Added stop_polling & changed way polling is run for (graceful) shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
coder2020official committed Aug 31, 2024
1 parent 08afe40 commit 393ca6f
Showing 1 changed file with 43 additions and 64 deletions.
107 changes: 43 additions & 64 deletions telebot/async_telebot.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ def __init__(self, token: str, parse_mode: Optional[str]=None, offset: Optional[
util.validate_token(self.token)

self.bot_id: Union[int, None] = util.extract_bot_id(self.token) # subject to change, unspecified

self.__polling: Optional[asyncio.Event] = None
self._stop_event = asyncio.Event()

self._update_tasks_set = set()


@property
Expand Down Expand Up @@ -317,72 +322,34 @@ async def polling(self, non_stop: bool=True, skip_pending=False, interval: int=0
await self.skip_updates()

if restart_on_change:
self._setup_change_detector(path_to_watch)
self._setup_change_detector(path_to_watch)

tasks = [] # only polling & event task
# we will stop polling when either of these two fail/complete:
# 1. polling task: due to exception, etc
# 2. stop_event: due to stop_polling call which causes _stop_event.set()
tasks.append(asyncio.create_task(self._process_polling(non_stop, interval, timeout, request_timeout, allowed_updates)))
tasks.append(asyncio.create_task(self._stop_event.wait()))
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

pending = pending.union(self._update_tasks_set)
for task in pending:
try:
task.cancel()
except asyncio.CancelledError: # handled just in case, not necessary
await task # cleanup

await self._process_polling(non_stop, interval, timeout, request_timeout, allowed_updates)
await asyncio.gather(*done, return_exceptions=True)

async def infinity_polling(self, timeout: Optional[int]=20, skip_pending: Optional[bool]=False, request_timeout: Optional[int]=None,
logger_level: Optional[int]=logging.ERROR, allowed_updates: Optional[List[str]]=None,
restart_on_change: Optional[bool]=False, path_to_watch: Optional[str]=None, *args, **kwargs):
"""
Wrap polling with infinite loop and exception handling to avoid bot stops polling.
.. note::
Install watchdog and psutil before using restart_on_change option.
:param timeout: Timeout in seconds for get_updates(Defaults to None)
:type timeout: :obj:`int`
:param skip_pending: skip old updates
:type skip_pending: :obj:`bool`
:param request_timeout: Aiohttp's request timeout. Defaults to 5 minutes(aiohttp.ClientTimeout).
:type request_timeout: :obj:`int`
:param logger_level: Custom logging level for infinity_polling logging.
Use logger levels from logging as a value. None/NOTSET = no error logging
:type logger_level: :obj:`int`
:param allowed_updates: A list of the update types you want your bot to receive.
For example, specify [“message”, “edited_channel_post”, “callback_query”] to only receive updates of these types.
See util.update_types for a complete list of available update types.
Specify an empty list to receive all update types except chat_member (default).
If not specified, the previous setting will be used.
Please note that this parameter doesn't affect updates created before the call to the get_updates,
so unwanted updates may be received for a short period of time.
:type allowed_updates: :obj:`list` of :obj:`str`
:param restart_on_change: Restart a file on file(s) change. Defaults to False
:type restart_on_change: :obj:`bool`
:param path_to_watch: Path to watch for changes. Defaults to current directory
:type path_to_watch: :obj:`str`
:return: None
Deprecated. Use polling instead.
"""
if skip_pending:
await self.skip_updates()
self._polling = True

if restart_on_change:
self._setup_change_detector(path_to_watch)

while self._polling:
try:
await self._process_polling(non_stop=True, timeout=timeout, request_timeout=request_timeout,
allowed_updates=allowed_updates, *args, **kwargs)
except Exception as e:
if logger_level and logger_level >= logging.ERROR:
logger.error("Infinity polling exception: %s", self.__hide_token(str(e)))
if logger_level and logger_level >= logging.DEBUG:
logger.error("Exception traceback:\n%s", self.__hide_token(traceback.format_exc()))
await asyncio.sleep(3)
continue
if logger_level and logger_level >= logging.INFO:
logger.error("Infinity polling: polling exited")
if logger_level and logger_level >= logging.INFO:
logger.error("Break infinity polling")
# logger_level is useless & not used;
await self.polling(non_stop=True, skip_pending=skip_pending, interval=0, timeout=timeout, request_timeout=request_timeout,
allowed_updates=allowed_updates, restart_on_change=restart_on_change, path_to_watch=path_to_watch, *args, **kwargs)

async def _handle_exception(self, exception: Exception) -> bool:
if self.exception_handler is None:
Expand Down Expand Up @@ -410,6 +377,13 @@ async def _handle_error_interval(self, error_interval: float):
error_interval = 60
return error_interval

async def stop_polling(self):
"""
Stop polling.
"""
self._stop_event.set()
self.__polling.clear()

async def _process_polling(self, non_stop: bool=False, interval: int=0, timeout: int=20,
request_timeout: int=None, allowed_updates: Optional[List[str]]=None):
"""
Expand Down Expand Up @@ -440,12 +414,13 @@ async def _process_polling(self, non_stop: bool=False, interval: int=0, timeout:

logger.info('Starting your bot with username: [@%s]', self.user.username)

self._polling = True

error_interval = 0.25

self.__polling = asyncio.Event()
self.__polling.set()

try:
while self._polling:
while self.__polling.is_set():
try:
updates = await self.get_updates(offset=self.offset, allowed_updates=allowed_updates, timeout=timeout, request_timeout=request_timeout)
if updates:
Expand Down Expand Up @@ -499,7 +474,7 @@ async def _process_polling(self, non_stop: bool=False, interval: int=0, timeout:
else:
break
finally:
self._polling = False
self.__polling.clear() # clear polling event
await self.close_session()
logger.warning('Polling is stopped.')

Expand All @@ -518,7 +493,11 @@ async def _process_updates(self, handlers, messages, update_type):
tasks = []
middlewares = await self._get_middlewares(update_type)
for message in messages:
tasks.append(self._run_middlewares_and_handlers(message, handlers, middlewares, update_type))
task = asyncio.create_task(self._run_middlewares_and_handlers(message, handlers, middlewares, update_type))
tasks.append(task)
task.add_done_callback(self._update_tasks_set.discard)
self._update_tasks_set.add(task)

await asyncio.gather(*tasks)

async def _run_middlewares_and_handlers(self, message, handlers, middlewares, update_type):
Expand Down

0 comments on commit 393ca6f

Please sign in to comment.