diff --git a/src/viam/logging.py b/src/viam/logging.py index b61d8efd0..a9eafa9ff 100644 --- a/src/viam/logging.py +++ b/src/viam/logging.py @@ -4,6 +4,7 @@ from copy import copy from datetime import datetime from logging import DEBUG, ERROR, FATAL, INFO, WARN, WARNING # noqa: F401 +from threading import Lock, Thread from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Union from grpclib.exceptions import StreamTerminatedError @@ -19,22 +20,61 @@ _MODULE_PARENT: Optional["RobotClient"] = None +class _SingletonEventLoopThread: + _instance = None + _lock = Lock() + _ready_event = asyncio.Event() + _thread = None + + def __new__(cls): + # Ensure singleton precondition + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super(_SingletonEventLoopThread, cls).__new__(cls) + cls._instance._loop = None + cls._instance._thread = Thread(target=cls._instance._run) + cls._instance._thread.start() + return cls._instance + + def _run(self): + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + self._ready_event.set() + self._loop.run_forever() + + def stop(self): + if self._loop is not None: + self._loop.call_soon_threadsafe(self._loop.stop) + self._loop.close() + + def get_loop(self): + if self._loop is None: + raise RuntimeError("Event loop is None. Did you call .start() and .wait_until_ready()?") + return self._loop + + async def wait_until_ready(self): + await self._ready_event.wait() + + class _ModuleHandler(logging.Handler): _parent: "RobotClient" _logger: logging.Logger + _worker: _SingletonEventLoopThread def __init__(self, parent: "RobotClient"): + super().__init__() self._parent = parent self._logger = logging.getLogger("ModuleLogger") addHandlers(self._logger, True) - super().__init__() self._logger.setLevel(self.level) + self._worker = _SingletonEventLoopThread() def setLevel(self, level: Union[int, str]) -> None: self._logger.setLevel(level) return super().setLevel(level) - def handle_task_result(self, task: asyncio.Task): + async def handle_task_result(self, task: asyncio.Task): try: _ = task.result() except (asyncio.CancelledError, asyncio.InvalidStateError, StreamTerminatedError): @@ -48,24 +88,28 @@ def emit(self, record: logging.LogRecord): time = datetime.fromtimestamp(record.created) try: - assert self._parent is not None - try: - loop = asyncio.get_event_loop() - loop.create_task( - self._parent.log(name, record.levelname, time, message, stack), name=f"{viam._TASK_PREFIX}-LOG-{record.created}" - ).add_done_callback(self.handle_task_result) - except RuntimeError: - # If the log is coming from a thread that doesn't have an event loop, create and set a new one. - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.create_task( - self._parent.log(name, record.levelname, time, message, stack), name=f"{viam._TASK_PREFIX}-LOG-{record.created}" - ).add_done_callback(self.handle_task_result) + loop = self._worker.get_loop() + asyncio.run_coroutine_threadsafe( + self._asynchronously_emit(record, name, message, stack, time), + loop, + ) except Exception as err: # If the module log fails, log using stdout/stderr handlers self._logger.error(f"ModuleLogger failed for {record.name} - {err}") self._logger.log(record.levelno, message) + async def _asynchronously_emit(self, record: logging.LogRecord, name: str, message: str, stack: str, time: datetime): + await self._worker.wait_until_ready() + task = self._worker.get_loop().create_task( + self._parent.log(name, record.levelname, time, message, stack), + name=f"{viam._TASK_PREFIX}-LOG-{record.created}", + ) + task.add_done_callback(lambda t: asyncio.run_coroutine_threadsafe(self.handle_task_result(t), self._worker.get_loop())) + + def close(self): + self._worker.stop() + super().close() + class _ColorFormatter(logging.Formatter): MAPPING = { @@ -76,8 +120,8 @@ class _ColorFormatter(logging.Formatter): "CRITICAL": 41, # white on red bg } - def __init__(self, patern): - logging.Formatter.__init__(self, patern) + def __init__(self, pattern): + logging.Formatter.__init__(self, pattern) def format(self, record): colored_record = copy(record)