Skip to content

Commit

Permalink
RSDK-8341: log on main thread (#690)
Browse files Browse the repository at this point in the history
Co-authored-by: sean yu <[email protected]>
  • Loading branch information
purplenicole730 and hexbabe committed Aug 2, 2024
1 parent b5ba049 commit c9049ed
Showing 1 changed file with 61 additions and 17 deletions.
78 changes: 61 additions & 17 deletions src/viam/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from copy import copy
from datetime import datetime
from logging import DEBUG, ERROR, FATAL, INFO, WARN, WARNING # noqa: F401
from threading import Lock, Thread
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Union

from grpclib.exceptions import StreamTerminatedError
Expand All @@ -19,22 +20,61 @@
_MODULE_PARENT: Optional["RobotClient"] = None


class _SingletonEventLoopThread:
_instance = None
_lock = Lock()
_ready_event = asyncio.Event()
_thread = None

def __new__(cls):
# Ensure singleton precondition
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super(_SingletonEventLoopThread, cls).__new__(cls)
cls._instance._loop = None
cls._instance._thread = Thread(target=cls._instance._run)
cls._instance._thread.start()
return cls._instance

def _run(self):
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._ready_event.set()
self._loop.run_forever()

def stop(self):
if self._loop is not None:
self._loop.call_soon_threadsafe(self._loop.stop)
self._loop.close()

def get_loop(self):
if self._loop is None:
raise RuntimeError("Event loop is None. Did you call .start() and .wait_until_ready()?")
return self._loop

async def wait_until_ready(self):
await self._ready_event.wait()


class _ModuleHandler(logging.Handler):
_parent: "RobotClient"
_logger: logging.Logger
_worker: _SingletonEventLoopThread

def __init__(self, parent: "RobotClient"):
super().__init__()
self._parent = parent
self._logger = logging.getLogger("ModuleLogger")
addHandlers(self._logger, True)
super().__init__()
self._logger.setLevel(self.level)
self._worker = _SingletonEventLoopThread()

def setLevel(self, level: Union[int, str]) -> None:
self._logger.setLevel(level)
return super().setLevel(level)

def handle_task_result(self, task: asyncio.Task):
async def handle_task_result(self, task: asyncio.Task):
try:
_ = task.result()
except (asyncio.CancelledError, asyncio.InvalidStateError, StreamTerminatedError):
Expand All @@ -48,24 +88,28 @@ def emit(self, record: logging.LogRecord):
time = datetime.fromtimestamp(record.created)

try:
assert self._parent is not None
try:
loop = asyncio.get_event_loop()
loop.create_task(
self._parent.log(name, record.levelname, time, message, stack), name=f"{viam._TASK_PREFIX}-LOG-{record.created}"
).add_done_callback(self.handle_task_result)
except RuntimeError:
# If the log is coming from a thread that doesn't have an event loop, create and set a new one.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.create_task(
self._parent.log(name, record.levelname, time, message, stack), name=f"{viam._TASK_PREFIX}-LOG-{record.created}"
).add_done_callback(self.handle_task_result)
loop = self._worker.get_loop()
asyncio.run_coroutine_threadsafe(
self._asynchronously_emit(record, name, message, stack, time),
loop,
)
except Exception as err:
# If the module log fails, log using stdout/stderr handlers
self._logger.error(f"ModuleLogger failed for {record.name} - {err}")
self._logger.log(record.levelno, message)

async def _asynchronously_emit(self, record: logging.LogRecord, name: str, message: str, stack: str, time: datetime):
await self._worker.wait_until_ready()
task = self._worker.get_loop().create_task(
self._parent.log(name, record.levelname, time, message, stack),
name=f"{viam._TASK_PREFIX}-LOG-{record.created}",
)
task.add_done_callback(lambda t: asyncio.run_coroutine_threadsafe(self.handle_task_result(t), self._worker.get_loop()))

def close(self):
self._worker.stop()
super().close()


class _ColorFormatter(logging.Formatter):
MAPPING = {
Expand All @@ -76,8 +120,8 @@ class _ColorFormatter(logging.Formatter):
"CRITICAL": 41, # white on red bg
}

def __init__(self, patern):
logging.Formatter.__init__(self, patern)
def __init__(self, pattern):
logging.Formatter.__init__(self, pattern)

def format(self, record):
colored_record = copy(record)
Expand Down

0 comments on commit c9049ed

Please sign in to comment.