Skip to content

Commit

Permalink
refactor: optimize logging and add test cases.
Browse files Browse the repository at this point in the history
  • Loading branch information
YongSangUn committed Sep 4, 2024
1 parent 289371a commit 8ec9aec
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 11 deletions.
39 changes: 28 additions & 11 deletions src/pyqcloud_sdk/logging.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,39 @@
import logging
import sys
from typing import Optional, TextIO

DEFAULT_LOG_LEVEL = logging.WARNING
DEFAULT_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

def get_logger(name: Optional[str] = None) -> logging.Logger:
logger = logging.getLogger(name or __name__)
if not logger.handlers:
logger.addHandler(logging.NullHandler())
if logger.level == logging.NOTSET:
logger.setLevel(DEFAULT_LOG_LEVEL)
return logger

def setup_logging(level=DEFAULT_LOG_LEVEL, format=DEFAULT_LOG_FORMAT, stream=sys.stderr):
"""
Configures the root logger.

Args:
level (int, optional): The logging level. Defaults to logging.WARNING.
format (str, optional): The log message format. Defaults to DEFAULT_LOG_FORMAT.
stream (TextIOWrapper, optional): The output stream. Defaults to sys.stderr.
"""
def setup_logging(
level: int = DEFAULT_LOG_LEVEL,
format: str = DEFAULT_LOG_FORMAT,
stream: TextIO = sys.stderr,
logger_name: Optional[str] = None,
) -> None:
logger = get_logger(logger_name)
logger.setLevel(level)

# Remove all existing handlers
for handler in logger.handlers[:]:
logger.removeHandler(handler)

handler = logging.StreamHandler(stream)
handler.setFormatter(logging.Formatter(format))
logging.basicConfig(level=level, handlers=[handler])
logger.addHandler(handler)

# Prevent the log messages from being passed to the root logger
logger.propagate = False


logger = get_logger()
78 changes: 78 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import unittest
from unittest.mock import patch

from pyqcloud_sdk.config import Config # Adjust the import path as needed
from pyqcloud_sdk.logging import logger


class TestConfig(unittest.TestCase):
def setUp(self):
self.config = Config()

def test_init(self):
"""Test the initialization of Config object."""
self.assertIsNone(self.config.Module)
self.assertIsNone(self.config.Version)
self.assertIsNone(self.config.EndPoint)
self.assertIsNone(self.config.Region)
self.assertIsNone(self.config.SecretId)
self.assertIsNone(self.config.SecretKey)

def test_deserialize_all_fields(self):
"""Test deserialization with all fields present."""
config_dict = {
"Module": "TestModule",
"Version": "1.0",
"EndPoint": "test.endpoint.com",
"Region": "test-region",
"SecretId": "test-secret-id",
"SecretKey": "test-secret-key",
}
self.config._deserialize(config_dict)

self.assertEqual(self.config.Module, "TestModule")
self.assertEqual(self.config.Version, "1.0")
self.assertEqual(self.config.EndPoint, "test.endpoint.com")
self.assertEqual(self.config.Region, "test-region")
self.assertEqual(self.config.SecretId, "test-secret-id")
self.assertEqual(self.config.SecretKey, "test-secret-key")

def test_deserialize_partial_fields(self):
"""Test deserialization with only some fields present."""
config_dict = {
"Module": "TestModule",
"Version": "1.0",
}
self.config._deserialize(config_dict)

self.assertEqual(self.config.Module, "TestModule")
self.assertEqual(self.config.Version, "1.0")
self.assertIsNone(self.config.EndPoint)
self.assertIsNone(self.config.Region)
self.assertIsNone(self.config.SecretId)
self.assertIsNone(self.config.SecretKey)

def test_deserialize_extra_fields(self):
"""Test deserialization with extra fields."""
config_dict = {"Module": "TestModule", "ExtraField1": "extra1", "ExtraField2": "extra2"}

with patch("pyqcloud_sdk.config.logger") as mock_logger:
self.config._deserialize(config_dict)
mock_logger.warning.assert_called_once_with("ExtraField1,ExtraField2 fields are useless.")

self.assertEqual(self.config.Module, "TestModule")

def test_deserialize_empty_dict(self):
"""Test deserialization with an empty dictionary."""
self.config._deserialize({})

self.assertIsNone(self.config.Module)
self.assertIsNone(self.config.Version)
self.assertIsNone(self.config.EndPoint)
self.assertIsNone(self.config.Region)
self.assertIsNone(self.config.SecretId)
self.assertIsNone(self.config.SecretKey)


if __name__ == "__main__":
unittest.main()
83 changes: 83 additions & 0 deletions tests/test_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import unittest
import logging
import io
from contextlib import redirect_stdout

from pyqcloud_sdk.logging import setup_logging, get_logger # Adjust import as needed


class TestLogging(unittest.TestCase):
def setUp(self):
# Reset the root logger before each test
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
logging.root.setLevel(logging.WARNING)

def test_default_logging_setup(self):
logger = get_logger("test_logger")
self.assertIsInstance(logger, logging.Logger)
self.assertEqual(logger.level, logging.WARNING)
self.assertTrue(any(isinstance(h, logging.NullHandler) for h in logger.handlers))

def test_custom_logging_setup(self):
log_stream = io.StringIO()
setup_logging(level=logging.DEBUG, stream=log_stream, logger_name="custom_logger")
logger = get_logger("custom_logger")

logger.debug("Debug message")
logger.info("Info message")
logger.warning("Warning message")

log_output = log_stream.getvalue()
self.assertIn("DEBUG - custom_logger - Debug message", log_output)
self.assertIn("INFO - custom_logger - Info message", log_output)
self.assertIn("WARNING - custom_logger - Warning message", log_output)

def test_log_levels(self):
log_stream = io.StringIO()
setup_logging(level=logging.WARNING, stream=log_stream)
logger = get_logger()

logger.debug("Debug message")
logger.info("Info message")
logger.warning("Warning message")
logger.error("Error message")

log_output = log_stream.getvalue()
self.assertNotIn("Debug message", log_output)
self.assertNotIn("Info message", log_output)
self.assertIn("WARNING", log_output)
self.assertIn("ERROR", log_output)

def test_custom_format(self):
log_stream = io.StringIO()
custom_format = "%(levelname)s: %(message)s"
setup_logging(level=logging.INFO, format=custom_format, stream=log_stream)
logger = get_logger()

logger.info("Test message")

log_output = log_stream.getvalue()
self.assertEqual(log_output.strip(), "INFO: Test message")

def test_multiple_loggers(self):
log_stream1 = io.StringIO()
log_stream2 = io.StringIO()

setup_logging(level=logging.INFO, stream=log_stream1, logger_name="logger1")
setup_logging(level=logging.ERROR, stream=log_stream2, logger_name="logger2")

logger1 = get_logger("logger1")
logger2 = get_logger("logger2")

logger1.info("Info from logger1")
logger2.info("Info from logger2")
logger2.error("Error from logger2")

self.assertIn("Info from logger1", log_stream1.getvalue())
self.assertNotIn("Info from logger2", log_stream2.getvalue())
self.assertIn("Error from logger2", log_stream2.getvalue())


# if __name__ == "__main__":
# unittest.main()

0 comments on commit 8ec9aec

Please sign in to comment.