From cf68aa1817f7698e643fe9603742fad54402c258 Mon Sep 17 00:00:00 2001 From: Viktor Taskov Date: Tue, 25 Jul 2023 17:31:39 +0100 Subject: [PATCH] Override formatter on subsequent configurations when settings object is passed --- mlserver/logging.py | 35 +++++++++++++++++----------- tests/test_logging.py | 54 +++++++++++++++++++++++++++++++++++++++---- 2 files changed, 72 insertions(+), 17 deletions(-) diff --git a/mlserver/logging.py b/mlserver/logging.py index 92258e509..3c68d02ef 100644 --- a/mlserver/logging.py +++ b/mlserver/logging.py @@ -14,6 +14,8 @@ logger = logging.getLogger(LoggerName) +_STREAM_HANDLER_NAME = "stdout_stream_handler" + def get_logger(): return logger @@ -44,13 +46,15 @@ class ModelLoggerFormatter(logging.Formatter): '"message": "%(message)s" %(model)s}' ) - def __init__(self, use_structured_logging: bool): + def __init__(self, settings: Optional[Settings]): + self.use_structured_logging = ( + settings is not None and settings.use_structured_logging + ) super().__init__( self._STRUCTURED_FORMAT - if use_structured_logging + if self.use_structured_logging else self._UNSTRUCTURED_FORMAT ) - self.use_structured_logging = use_structured_logging @staticmethod def _format_unstructured_model_details(name: str, version: str) -> str: @@ -83,21 +87,26 @@ def format(self, record: logging.LogRecord) -> str: return super().format(record) +def _find_handler( + logger: logging.Logger, handler_name: str +) -> Optional[logging.Handler]: + for h in logger.handlers: + if h.get_name() == handler_name: + return h + return None + + def configure_logger(settings: Optional[Settings] = None): logger = get_logger() # Don't add handler twice - if not logger.handlers: - stream_handler = StreamHandler(sys.stdout) - - use_structured_logging = False - if settings and settings.use_structured_logging: - use_structured_logging = True - - formatter = ModelLoggerFormatter(use_structured_logging) - stream_handler.setFormatter(formatter) + handler = _find_handler(logger, _STREAM_HANDLER_NAME) + if handler is None: + handler = StreamHandler(sys.stdout) + handler.set_name(_STREAM_HANDLER_NAME) + logger.addHandler(handler) - logger.addHandler(stream_handler) + handler.setFormatter(ModelLoggerFormatter(settings)) logger.setLevel(logging.INFO) if settings and settings.debug: diff --git a/tests/test_logging.py b/tests/test_logging.py index 26d41980d..f39d97f4b 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -8,6 +8,7 @@ ModelLoggerFormatter, configure_logger, logger, + _STREAM_HANDLER_NAME, ) from mlserver.settings import ModelParameters, Settings from tests.fixtures import SumModel @@ -44,9 +45,15 @@ ], ) def test_model_logging_formatter_unstructured( - name: str, version: str, expected_model_fmt: str, fmt_present_in_all: bool, caplog + name: str, + version: str, + expected_model_fmt: str, + fmt_present_in_all: bool, + settings: Settings, + caplog, ): - caplog.handler.setFormatter(ModelLoggerFormatter(use_structured_logging=False)) + settings.use_structured_logging = False + caplog.handler.setFormatter(ModelLoggerFormatter(settings)) caplog.set_level(INFO) model_settings = ModelSettings( @@ -96,9 +103,15 @@ def test_model_logging_formatter_unstructured( ], ) def test_model_logging_formatter_structured( - name: str, version: str, expected_model_fmt: str, fmt_present_in_all: bool, caplog + name: str, + version: str, + expected_model_fmt: str, + fmt_present_in_all: bool, + settings: Settings, + caplog, ): - caplog.handler.setFormatter(ModelLoggerFormatter(use_structured_logging=True)) + settings.use_structured_logging = True + caplog.handler.setFormatter(ModelLoggerFormatter(settings)) caplog.set_level(INFO) model_settings = ModelSettings( @@ -136,3 +149,36 @@ def test_log_level_gets_persisted(debug: bool, settings: Settings, caplog): assert test_log_message in caplog.text else: assert test_log_message not in caplog.text + + +def test_configure_logger_when_called_multiple_times_with_same_logger(settings): + logger = configure_logger() + + assert len(logger.handlers) == 1 + handler = logger.handlers[0] + assert handler.name == _STREAM_HANDLER_NAME + assert ( + hasattr(handler.formatter, "use_structured_logging") + and handler.formatter.use_structured_logging is False + ) + + logger = configure_logger(settings) + + assert len(logger.handlers) == 1 + handler = logger.handlers[0] + assert handler.name == _STREAM_HANDLER_NAME + assert ( + hasattr(handler.formatter, "use_structured_logging") + and handler.formatter.use_structured_logging is False + ) + + settings.use_structured_logging = True + logger = configure_logger(settings) + + assert len(logger.handlers) == 1 + handler = logger.handlers[0] + assert handler.name == _STREAM_HANDLER_NAME + assert ( + hasattr(handler.formatter, "use_structured_logging") + and handler.formatter.use_structured_logging is True + )