diff --git a/python/idsse_common/idsse/common/log_util.py b/python/idsse_common/idsse/common/log_util.py index c1c4edd..fbaf6f8 100644 --- a/python/idsse_common/idsse/common/log_util.py +++ b/python/idsse_common/idsse/common/log_util.py @@ -10,7 +10,7 @@ # Mackenzie Grimes (2) # # ------------------------------------------------------------------------------ -# pylint: disable=too-few-public-methods +# pylint: disable=too-few-public-methods,missing-class-docstring import logging import time diff --git a/python/idsse_common/idsse/common/rabbitmq_utils.py b/python/idsse_common/idsse/common/rabbitmq_utils.py index a933a40..6a75ccd 100644 --- a/python/idsse_common/idsse/common/rabbitmq_utils.py +++ b/python/idsse_common/idsse/common/rabbitmq_utils.py @@ -11,6 +11,7 @@ # # ---------------------------------------------------------------------------------- +import contextvars import logging import logging.config import uuid @@ -344,6 +345,11 @@ def threadsafe_nack( threadsafe_call(channel, lambda: channel.basic_nack(delivery_tag, requeue=requeue)) +def _set_context(context): + for var, value in context.items(): + var.set(value) + + class Consumer(Thread): """ RabbitMQ consumer, runs in own thread to not block heartbeat. A thread pool @@ -352,6 +358,10 @@ class Consumer(Thread): shutdown. The start() and stop() methods should be called from the same thread as the one used to create the instance. """ + + # pylint: disable=too-many-instance-attributes + # Eight is reasonable in this case. + def __init__( self, conn_params: Conn, @@ -361,6 +371,7 @@ def __init__( **kwargs, ): super().__init__(*args, **kwargs) + self.context = contextvars.copy_context() self.daemon = True self._tpx = ThreadPoolExecutor(max_workers=num_message_handlers) self._conn_params = conn_params @@ -382,6 +393,7 @@ def __init__( self.channel.basic_qos(prefetch_count=1) def run(self): + _set_context(self.context) logger.info('Start Consuming... (to stop press CTRL+C)') self.channel.start_consuming() @@ -428,6 +440,7 @@ def __init__( **kwargs, ): super().__init__(*args, **kwargs) + self.context = contextvars.copy_context() self.daemon = True self._is_running = True self._exch = exch_params @@ -454,6 +467,7 @@ def __init__( self.channel.confirm_delivery() def run(self): + _set_context(self.context) logger.info('Starting publisher') while self._is_running: if self.connection and self.connection.is_open: diff --git a/python/idsse_common/test/test_log_util.py b/python/idsse_common/test/test_log_util.py index 7920d0f..b1f1429 100644 --- a/python/idsse_common/test/test_log_util.py +++ b/python/idsse_common/test/test_log_util.py @@ -10,9 +10,13 @@ # # ---------------------------------------------------------------------------------- # pylint: disable=missing-function-docstring,redefined-outer-name,invalid-name,unused-argument +# pylint: disable=missing-class-docstring +import contextvars import logging import logging.config +import threading +import time from datetime import datetime, UTC from uuid import uuid4 as uuid @@ -77,12 +81,11 @@ def test_get_default_log_config_with_corr_id(capsys): logger.debug(msg=EXAMPLE_LOG_MESSAGE) stdout = capsys.readouterr().out # capture std output from test run - # should not be logging DEBUG if default log config handled level correctly assert stdout == '' + logger.info(msg=EXAMPLE_LOG_MESSAGE) stdout = capsys.readouterr().out - assert EXAMPLE_LOG_MESSAGE in stdout assert corr_id in stdout @@ -94,3 +97,46 @@ def test_get_default_log_config_no_corr_id(capsys): logger.debug('hello world') stdout = capsys.readouterr().out assert corr_id not in stdout + + +def test_getting_logs_from_threaded_func(capsys): + logging.config.dictConfig(get_default_log_config('INFO', True)) + set_corr_id_context_var(EXAMPLE_ORIGINATOR, key=EXAMPLE_UUID) + + def worker(): + logger = logging.getLogger(__name__) + logger.info(EXAMPLE_LOG_MESSAGE) + + # Create and start the thread + thread = threading.Thread(target=contextvars.copy_context().run, args=(worker,)) + thread.start() + + time.sleep(.1) + stdout = capsys.readouterr().out + assert EXAMPLE_LOG_MESSAGE in stdout + + +def test_getting_logs_from_thread_class(capsys): + logging.config.dictConfig(get_default_log_config('INFO', True)) + set_corr_id_context_var(EXAMPLE_ORIGINATOR, key=EXAMPLE_UUID) + + def set_context(context): + for var, value in context.items(): + var.set(value) + + class MyThread(threading.Thread): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.context = contextvars.copy_context() + + def run(self): + set_context(self.context) + logger = logging.getLogger(f'{__name__}::{self.__class__.__name__}') + logger.info(EXAMPLE_LOG_MESSAGE) + + thread = MyThread() + thread.start() + thread.join() + + stdout = capsys.readouterr().out + assert EXAMPLE_LOG_MESSAGE in stdout