Skip to content

Commit

Permalink
feat: add custom logger that includes the origin
Browse files Browse the repository at this point in the history
When running servers, use a custom structlog logger that binds
an extra parameter called origin. Can be: generic_server, copilot_proxy

Closes: #301
  • Loading branch information
yrobla committed Dec 18, 2024
1 parent 99f7489 commit afabff3
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 36 deletions.
23 changes: 16 additions & 7 deletions src/codegate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import Dict, Optional

import click
import structlog
from uvicorn.config import Config as UvicornConfig
from uvicorn.server import Server

Expand All @@ -20,6 +19,7 @@
from codegate.providers.copilot.provider import CopilotProvider
from codegate.server import init_app
from codegate.storage.utils import restore_storage_backup
from codegate.logger.logger import OriginLogger


class UvicornServer:
Expand All @@ -33,7 +33,9 @@ def __init__(self, config: UvicornConfig, server: Server):
self._startup_complete = asyncio.Event()
self._shutdown_event = asyncio.Event()
self._should_exit = False
self.logger = structlog.get_logger("codegate")

logger_obj = OriginLogger("generic_server")
self.logger = logger_obj.logger

async def serve(self) -> None:
"""Start the uvicorn server and handle shutdown gracefully."""
Expand Down Expand Up @@ -84,8 +86,10 @@ async def cleanup(self) -> None:

def validate_port(ctx: click.Context, param: click.Parameter, value: int) -> int:
"""Validate the port number is in valid range."""
logger = structlog.get_logger("codegate")
logger.debug(f"Validating port number: {value}")
cli_logger_obj = OriginLogger("cli")
cli_logger = cli_logger_obj.logger

cli_logger.debug(f"Validating port number: {value}")
if value is not None and not (1 <= value <= 65535):
raise click.BadParameter("Port must be between 1 and 65535")
return value
Expand Down Expand Up @@ -296,7 +300,8 @@ def serve(

# Set up logging first
setup_logging(cfg.log_level, cfg.log_format)
logger = structlog.get_logger("codegate")
cli_logger_obj = OriginLogger("cli")
logger = cli_logger_obj.logger

init_db_sync(cfg.db_path)

Expand Down Expand Up @@ -327,7 +332,9 @@ def serve(
click.echo(f"Configuration error: {e}", err=True)
sys.exit(1)
except Exception as e:
logger = structlog.get_logger("codegate")
cli_logger_obj = OriginLogger("cli")
logger = cli_logger_obj.logger

logger.exception("Unexpected error occurred")
click.echo(f"Error: {e}", err=True)
sys.exit(1)
Expand All @@ -336,7 +343,9 @@ def serve(
async def run_servers(cfg: Config, app) -> None:
"""Run the codegate server."""
try:
logger = structlog.get_logger("codegate")
cli_logger_obj = OriginLogger("cli")
logger = cli_logger_obj.logger

logger.info(
"Starting server",
extra={
Expand Down
10 changes: 10 additions & 0 deletions src/codegate/codegate_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@ def _missing_(cls, value: str) -> Optional["LogFormat"]:
)


def add_origin(logger, log_method, event_dict):
# Add 'origin' if it's bound to the logger but not explicitly in the event dict
if 'origin' not in event_dict and hasattr(logger, '_context'):
origin = logger._context.get('origin')
if origin:
event_dict['origin'] = origin
return event_dict


def setup_logging(
log_level: Optional[LogLevel] = None, log_format: Optional[LogFormat] = None
) -> logging.Logger:
Expand All @@ -74,6 +83,7 @@ def setup_logging(
shared_processors = [
structlog.processors.add_log_level,
structlog.processors.TimeStamper(fmt="%Y-%m-%dT%H:%M:%S.%03dZ", utc=True),
add_origin,
structlog.processors.CallsiteParameterAdder(
[
structlog.processors.CallsiteParameter.MODULE,
Expand Down
6 changes: 6 additions & 0 deletions src/codegate/logger/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import structlog


class OriginLogger:
def __init__(self, origin: str):
self.logger = structlog.get_logger().bind(origin=origin)
9 changes: 5 additions & 4 deletions src/codegate/providers/copilot/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Dict, List, Optional, Tuple, Union
from urllib.parse import unquote, urljoin, urlparse

import structlog
from src.codegate.logger.logger import OriginLogger
from litellm.types.utils import Delta, ModelResponse, StreamingChoices

from codegate.ca.codegate_ca import CertificateAuthority
Expand All @@ -22,7 +22,8 @@
)
from codegate.providers.copilot.streaming import SSEProcessor

logger = structlog.get_logger("codegate")
logger_obj = OriginLogger("copilot_proxy")
logger = logger_obj.logger

# Constants
MAX_BUFFER_SIZE = 10 * 1024 * 1024 # 10MB
Expand Down Expand Up @@ -637,7 +638,7 @@ async def get_target_url(path: str) -> Optional[str]:
# Check for prefix match
for route in VALIDATED_ROUTES:
# For prefix matches, keep the rest of the path
remaining_path = path[len(route.path) :]
remaining_path = path[len(route.path):]
logger.debug(f"Remaining path: {remaining_path}")
# Make sure we don't end up with double slashes
if remaining_path and remaining_path.startswith("/"):
Expand Down Expand Up @@ -791,7 +792,7 @@ def data_received(self, data: bytes) -> None:
self._proxy_transport_write(headers)
logger.debug(f"Headers sent: {headers}")

data = data[header_end + 4 :]
data = data[header_end + 4:]

self._process_chunk(data)

Expand Down
40 changes: 15 additions & 25 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,27 +190,17 @@ def cli_runner():
return CliRunner()


@pytest.fixture
def mock_logging(mocker):
return mocker.patch("your_cli_module.structlog.get_logger")


@pytest.fixture
def mock_setup_logging(mocker):
return mocker.patch("your_cli_module.setup_logging")


def test_serve_default_options(cli_runner):
"""Test serve command with default options."""
# Use patches for run_servers and logging setup
with (
patch("src.codegate.cli.run_servers") as mock_run,
patch("src.codegate.cli.structlog.get_logger") as mock_logging,
patch("src.codegate.cli.OriginLogger") as mock_origin_logger,
patch("src.codegate.cli.setup_logging") as mock_setup_logging,
):

logger_instance = MagicMock()
mock_logging.return_value = logger_instance
mock_origin_logger.return_value = logger_instance

# Invoke the CLI command
result = cli_runner.invoke(cli, ["serve"])
Expand All @@ -222,7 +212,7 @@ def test_serve_default_options(cli_runner):
mock_setup_logging.assert_called_once_with(LogLevel.INFO, LogFormat.JSON)

# Check if logging was done correctly
mock_logging.assert_called_with("codegate")
mock_origin_logger.assert_called_with("cli")

# Validate run_servers was called once
mock_run.assert_called_once()
Expand All @@ -232,12 +222,12 @@ def test_serve_custom_options(cli_runner):
"""Test serve command with custom options."""
with (
patch("src.codegate.cli.run_servers") as mock_run,
patch("src.codegate.cli.structlog.get_logger") as mock_logging,
patch("src.codegate.cli.OriginLogger") as mock_origin_logger,
patch("src.codegate.cli.setup_logging") as mock_setup_logging,
):

logger_instance = MagicMock()
mock_logging.return_value = logger_instance
mock_origin_logger.return_value = logger_instance

# Invoke the CLI command with custom options
result = cli_runner.invoke(
Expand Down Expand Up @@ -272,7 +262,7 @@ def test_serve_custom_options(cli_runner):
mock_setup_logging.assert_called_once_with(LogLevel.DEBUG, LogFormat.TEXT)

# Assert logger got called with the expected module name
mock_logging.assert_called_with("codegate")
mock_origin_logger.assert_called_with("cli")

# Validate run_servers was called once
mock_run.assert_called_once()
Expand Down Expand Up @@ -332,20 +322,20 @@ def test_serve_with_config_file(cli_runner, temp_config_file):
"""Test serve command with config file."""
with (
patch("src.codegate.cli.run_servers") as mock_run,
patch("src.codegate.cli.structlog.get_logger") as mock_logging,
patch("src.codegate.cli.OriginLogger") as mock_origin_logger,
patch("src.codegate.cli.setup_logging") as mock_setup_logging,
):

logger_instance = MagicMock()
mock_logging.return_value = logger_instance
mock_origin_logger.return_value = logger_instance

# Invoke the CLI command with the configuration file
result = cli_runner.invoke(cli, ["serve", "--config", str(temp_config_file)])

# Assertions to ensure the CLI ran successfully
assert result.exit_code == 0
mock_setup_logging.assert_called_once_with(LogLevel.DEBUG, LogFormat.JSON)
mock_logging.assert_called_with("codegate")
mock_origin_logger.assert_called_with("cli")

# Validate that run_servers was called with the expected configuration
mock_run.assert_called_once()
Expand Down Expand Up @@ -380,12 +370,12 @@ def test_serve_priority_resolution(cli_runner: CliRunner, temp_config_file: Path
with (
patch.dict(os.environ, {"LOG_LEVEL": "INFO", "PORT": "9999"}, clear=True),
patch("src.codegate.cli.run_servers") as mock_run,
patch("src.codegate.cli.structlog.get_logger") as mock_logging,
patch("src.codegate.cli.OriginLogger") as mock_origin_logger,
patch("src.codegate.cli.setup_logging") as mock_setup_logging,
):
# Set up mock logger
logger_instance = MagicMock()
mock_logging.return_value = logger_instance
mock_origin_logger.return_value = logger_instance

# Execute CLI command with specific options overriding environment and config file settings
result = cli_runner.invoke(
Expand Down Expand Up @@ -420,7 +410,7 @@ def test_serve_priority_resolution(cli_runner: CliRunner, temp_config_file: Path

# Ensure logging setup was called with the highest priority settings (CLI arguments)
mock_setup_logging.assert_called_once_with("ERROR", "TEXT")
mock_logging.assert_called_with("codegate")
mock_origin_logger.assert_called_with("cli")

# Verify that the run_servers was called with the overridden settings
config_arg = mock_run.call_args[0][0] # Assuming Config is the first positional arg
Expand Down Expand Up @@ -448,12 +438,12 @@ def test_serve_certificate_options(cli_runner: CliRunner) -> None:
"""Test serve command with certificate options."""
with (
patch("src.codegate.cli.run_servers") as mock_run,
patch("src.codegate.cli.structlog.get_logger") as mock_logging,
patch("src.codegate.cli.OriginLogger") as mock_origin_logger,
patch("src.codegate.cli.setup_logging") as mock_setup_logging,
):
# Set up mock logger
logger_instance = MagicMock()
mock_logging.return_value = logger_instance
mock_origin_logger.return_value = logger_instance

# Execute CLI command with certificate options
result = cli_runner.invoke(
Expand All @@ -478,7 +468,7 @@ def test_serve_certificate_options(cli_runner: CliRunner) -> None:

# Ensure logging setup was called with expected arguments
mock_setup_logging.assert_called_once_with("INFO", "JSON")
mock_logging.assert_called_with("codegate")
mock_origin_logger.assert_called_with("cli")

# Verify that run_servers was called with the provided certificate options
config_arg = mock_run.call_args[0][0] # Assuming Config is the first positional arg
Expand Down

0 comments on commit afabff3

Please sign in to comment.