From 8ec9aecfc11d1a53032eab28089f81b7f95574b3 Mon Sep 17 00:00:00 2001 From: "Rylan.Loong" Date: Wed, 4 Sep 2024 15:20:46 +0800 Subject: [PATCH] refactor: optimize logging and add test cases. --- src/pyqcloud_sdk/logging.py | 39 ++++++++++++----- tests/test_config.py | 78 ++++++++++++++++++++++++++++++++++ tests/test_logging.py | 83 +++++++++++++++++++++++++++++++++++++ 3 files changed, 189 insertions(+), 11 deletions(-) create mode 100644 tests/test_config.py create mode 100644 tests/test_logging.py diff --git a/src/pyqcloud_sdk/logging.py b/src/pyqcloud_sdk/logging.py index 1d6a010..5a6d253 100644 --- a/src/pyqcloud_sdk/logging.py +++ b/src/pyqcloud_sdk/logging.py @@ -1,22 +1,39 @@ import logging import sys +from typing import Optional, TextIO DEFAULT_LOG_LEVEL = logging.WARNING DEFAULT_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" -logger = logging.getLogger(__name__) -logger.addHandler(logging.NullHandler()) +def get_logger(name: Optional[str] = None) -> logging.Logger: + logger = logging.getLogger(name or __name__) + if not logger.handlers: + logger.addHandler(logging.NullHandler()) + if logger.level == logging.NOTSET: + logger.setLevel(DEFAULT_LOG_LEVEL) + return logger -def setup_logging(level=DEFAULT_LOG_LEVEL, format=DEFAULT_LOG_FORMAT, stream=sys.stderr): - """ - Configures the root logger. - Args: - level (int, optional): The logging level. Defaults to logging.WARNING. - format (str, optional): The log message format. Defaults to DEFAULT_LOG_FORMAT. - stream (TextIOWrapper, optional): The output stream. Defaults to sys.stderr. - """ +def setup_logging( + level: int = DEFAULT_LOG_LEVEL, + format: str = DEFAULT_LOG_FORMAT, + stream: TextIO = sys.stderr, + logger_name: Optional[str] = None, +) -> None: + logger = get_logger(logger_name) + logger.setLevel(level) + + # Remove all existing handlers + for handler in logger.handlers[:]: + logger.removeHandler(handler) + handler = logging.StreamHandler(stream) handler.setFormatter(logging.Formatter(format)) - logging.basicConfig(level=level, handlers=[handler]) + logger.addHandler(handler) + + # Prevent the log messages from being passed to the root logger + logger.propagate = False + + +logger = get_logger() diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..afffd51 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,78 @@ +import unittest +from unittest.mock import patch + +from pyqcloud_sdk.config import Config # Adjust the import path as needed +from pyqcloud_sdk.logging import logger + + +class TestConfig(unittest.TestCase): + def setUp(self): + self.config = Config() + + def test_init(self): + """Test the initialization of Config object.""" + self.assertIsNone(self.config.Module) + self.assertIsNone(self.config.Version) + self.assertIsNone(self.config.EndPoint) + self.assertIsNone(self.config.Region) + self.assertIsNone(self.config.SecretId) + self.assertIsNone(self.config.SecretKey) + + def test_deserialize_all_fields(self): + """Test deserialization with all fields present.""" + config_dict = { + "Module": "TestModule", + "Version": "1.0", + "EndPoint": "test.endpoint.com", + "Region": "test-region", + "SecretId": "test-secret-id", + "SecretKey": "test-secret-key", + } + self.config._deserialize(config_dict) + + self.assertEqual(self.config.Module, "TestModule") + self.assertEqual(self.config.Version, "1.0") + self.assertEqual(self.config.EndPoint, "test.endpoint.com") + self.assertEqual(self.config.Region, "test-region") + self.assertEqual(self.config.SecretId, "test-secret-id") + self.assertEqual(self.config.SecretKey, "test-secret-key") + + def test_deserialize_partial_fields(self): + """Test deserialization with only some fields present.""" + config_dict = { + "Module": "TestModule", + "Version": "1.0", + } + self.config._deserialize(config_dict) + + self.assertEqual(self.config.Module, "TestModule") + self.assertEqual(self.config.Version, "1.0") + self.assertIsNone(self.config.EndPoint) + self.assertIsNone(self.config.Region) + self.assertIsNone(self.config.SecretId) + self.assertIsNone(self.config.SecretKey) + + def test_deserialize_extra_fields(self): + """Test deserialization with extra fields.""" + config_dict = {"Module": "TestModule", "ExtraField1": "extra1", "ExtraField2": "extra2"} + + with patch("pyqcloud_sdk.config.logger") as mock_logger: + self.config._deserialize(config_dict) + mock_logger.warning.assert_called_once_with("ExtraField1,ExtraField2 fields are useless.") + + self.assertEqual(self.config.Module, "TestModule") + + def test_deserialize_empty_dict(self): + """Test deserialization with an empty dictionary.""" + self.config._deserialize({}) + + self.assertIsNone(self.config.Module) + self.assertIsNone(self.config.Version) + self.assertIsNone(self.config.EndPoint) + self.assertIsNone(self.config.Region) + self.assertIsNone(self.config.SecretId) + self.assertIsNone(self.config.SecretKey) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_logging.py b/tests/test_logging.py new file mode 100644 index 0000000..eb18762 --- /dev/null +++ b/tests/test_logging.py @@ -0,0 +1,83 @@ +import unittest +import logging +import io +from contextlib import redirect_stdout + +from pyqcloud_sdk.logging import setup_logging, get_logger # Adjust import as needed + + +class TestLogging(unittest.TestCase): + def setUp(self): + # Reset the root logger before each test + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + logging.root.setLevel(logging.WARNING) + + def test_default_logging_setup(self): + logger = get_logger("test_logger") + self.assertIsInstance(logger, logging.Logger) + self.assertEqual(logger.level, logging.WARNING) + self.assertTrue(any(isinstance(h, logging.NullHandler) for h in logger.handlers)) + + def test_custom_logging_setup(self): + log_stream = io.StringIO() + setup_logging(level=logging.DEBUG, stream=log_stream, logger_name="custom_logger") + logger = get_logger("custom_logger") + + logger.debug("Debug message") + logger.info("Info message") + logger.warning("Warning message") + + log_output = log_stream.getvalue() + self.assertIn("DEBUG - custom_logger - Debug message", log_output) + self.assertIn("INFO - custom_logger - Info message", log_output) + self.assertIn("WARNING - custom_logger - Warning message", log_output) + + def test_log_levels(self): + log_stream = io.StringIO() + setup_logging(level=logging.WARNING, stream=log_stream) + logger = get_logger() + + logger.debug("Debug message") + logger.info("Info message") + logger.warning("Warning message") + logger.error("Error message") + + log_output = log_stream.getvalue() + self.assertNotIn("Debug message", log_output) + self.assertNotIn("Info message", log_output) + self.assertIn("WARNING", log_output) + self.assertIn("ERROR", log_output) + + def test_custom_format(self): + log_stream = io.StringIO() + custom_format = "%(levelname)s: %(message)s" + setup_logging(level=logging.INFO, format=custom_format, stream=log_stream) + logger = get_logger() + + logger.info("Test message") + + log_output = log_stream.getvalue() + self.assertEqual(log_output.strip(), "INFO: Test message") + + def test_multiple_loggers(self): + log_stream1 = io.StringIO() + log_stream2 = io.StringIO() + + setup_logging(level=logging.INFO, stream=log_stream1, logger_name="logger1") + setup_logging(level=logging.ERROR, stream=log_stream2, logger_name="logger2") + + logger1 = get_logger("logger1") + logger2 = get_logger("logger2") + + logger1.info("Info from logger1") + logger2.info("Info from logger2") + logger2.error("Error from logger2") + + self.assertIn("Info from logger1", log_stream1.getvalue()) + self.assertNotIn("Info from logger2", log_stream2.getvalue()) + self.assertIn("Error from logger2", log_stream2.getvalue()) + + +# if __name__ == "__main__": +# unittest.main()