diff --git a/Dockerfile b/Dockerfile
index f8e9e88c..60f11aa5 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -6,6 +6,9 @@ ENV USE_OTEL=${USE_OTEL}
 # Set the USE_MSSQL env variable to true to enable SQL Server support
 ARG USE_MSSQL=true
 ENV USE_MSSQL=${USE_MSSQL}
+# Set default log config
+ARG LOG_CONFIG=/code/assets/production_log_config.json
+ENV LOG_CONFIG=${LOG_CONFIG}
 
 # Upgrade system packages and install curl
 RUN apt-get update && apt-get upgrade -y && apt-get install curl -y
@@ -27,7 +30,7 @@ RUN mkdir -p /code/src/recordlinker
 # Copy over just the pyproject.toml file and install the dependencies; doing this
 # before copying the rest of the code allows for caching of the dependencies
 COPY ./pyproject.toml /code/pyproject.toml
-RUN pip install '.'
+RUN pip install '.[prod]'
 
 # Conditionally install OpenTelemetry packages if USE_OTEL is true
 RUN if [ "$USE_OTEL" = "true" ]; then \
@@ -46,9 +49,7 @@ EXPOSE 8080
 # Conditionally run the application with or without OpenTelemetry
 CMD if [ "$USE_OTEL" = "true" ]; then \
         opentelemetry-instrument --service_name recordlinker \
-        uvicorn recordlinker.main:app --app-dir src --host 0 --port 8080 \
-        --log-config src/recordlinker/log_config.yml; \
+        uvicorn recordlinker.main:app --host 0 --port 8080; \
     else \
-        uvicorn recordlinker.main:app --app-dir src --host 0 --port 8080 \
-        --log-config src/recordlinker/log_config.yml; \
+        uvicorn recordlinker.main:app --host 0 --port 8080; \
     fi
diff --git a/assets/production_log_config.json b/assets/production_log_config.json
new file mode 100644
index 00000000..6f4b6afa
--- /dev/null
+++ b/assets/production_log_config.json
@@ -0,0 +1,70 @@
+{
+    "version": 1,
+    "disable_existing_loggers": false,
+    "filters": {
+        "correlation_id": {
+            "()": "asgi_correlation_id.CorrelationIdFilter",
+            "default_value": "-"
+        },
+        "dict_values": {
+            "()": "recordlinker.log.DictArgFilter"
+        }
+    },
+    "formatters": {
+        "default": {
+            "()": "recordlinker.log.JSONFormatter",
+            "format": "%(levelname)s %(name)s %(message)s %(correlation_id)s",
+            "timestamp": true
+        },
+        "access": {
+            "()": "recordlinker.log.JSONFormatter",
+            "fmt": "%(message)s",
+            "static_fields": {"message": "ACCESS"}
+        }
+    },
+    "handlers": {
+        "console": {
+            "formatter": "default",
+            "class": "logging.StreamHandler",
+            "filters": ["correlation_id"],
+            "stream": "ext://sys.stderr"
+        },
+        "access": {
+            "formatter": "access",
+            "class": "logging.StreamHandler",
+            "filters": ["dict_values"],
+            "stream": "ext://sys.stdout"
+        }
+    },
+    "loggers": {
+        "": {
+            "handlers": ["console"],
+            "level": "WARNING"
+        },
+        "uvicorn": {
+            "handlers": ["console"],
+            "level": "INFO",
+            "propagate": false
+        },
+        "uvicorn.error": {
+            "handlers": ["console"],
+            "level": "INFO",
+            "propagate": false
+        },
+        "uvicorn.access": {
+            "handlers": ["console"],
+            "level": "CRITICAL",
+            "propagate": false
+        },
+        "recordlinker": {
+            "handlers": ["console"],
+            "level": "INFO",
+            "propagate": false
+        },
+        "recordlinker.access": {
+            "handlers": ["access"],
+            "level": "INFO",
+            "propagate": false
+        }
+    }
+}
diff --git a/compose.yml b/compose.yml
index 17d85fdc..d8f6b195 100644
--- a/compose.yml
+++ b/compose.yml
@@ -29,6 +29,7 @@ services:
       args:
         USE_OTEL: ${USE_OTEL:-true}
         USE_MSSQL: ${USE_MSSQL:-false}
+        LOG_CONFIG: ""
     ports:
       - "8080:8080"
     environment:
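Reviewer note: the new production config is a standard `logging.config.dictConfig` payload, so it can be exercised outside of Docker. A minimal sketch (assumes python-json-logger and asgi-correlation-id are importable, since the config references classes from both packages):

```python
import json
import logging.config

# Load the JSON config the same way Settings.configure_logging() does.
with open("assets/production_log_config.json") as fobj:
    logging.config.dictConfig(json.load(fobj))

# "recordlinker" records now render as JSON on stderr; dict-style args passed
# to the "recordlinker.access" logger are lifted into JSON fields.
logging.getLogger("recordlinker").info("smoke test")
```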
diff --git a/pyproject.toml b/pyproject.toml
index e418c389..991c8755 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,6 +28,8 @@ dependencies = [
     "rapidfuzz",
     "opentelemetry-api",
     "opentelemetry-sdk",
+    "python-json-logger",
+    "asgi-correlation-id",
     # Database drivers
     "psycopg2-binary",  # PostgreSQL
     "PyMySQL",  # MySQL & MariaDB
@@ -37,6 +39,7 @@ dependencies = [
 
 [project.optional-dependencies]
 dev = [
+    # development-only dependencies here
     "fastapi[standard]",
     "pytest>=8.3",
     "pytest-cov",
@@ -46,7 +49,7 @@ dev = [
     "types-python-dateutil"
 ]
 prod = [
-    # List any additional production-only dependencies here
+    # production-only dependencies here
 ]
 
 [tool.setuptools]
@@ -94,5 +97,3 @@ env = [
 [tool.mypy]
 files = ["src"]
 mypy_path = ["src"]
-# TODO: Resolve static typing issues in main.py
-exclude = ["src/recordlinker/main.py", "src/recordlinker/linkage"]
diff --git a/scripts/local_server.sh b/scripts/local_server.sh
index ea28dc2f..40709b49 100755
--- a/scripts/local_server.sh
+++ b/scripts/local_server.sh
@@ -2,7 +2,7 @@
 #
 # Run the API server locally.
 #
-# Usage: scripts/local_service.sh
+# Usage: scripts/local_server.sh
 #
 # <port> is the port on which to run the API server. If not specified, the server
 # will run on port 8000.
@@ -14,4 +14,4 @@ cd "$(dirname "$0")/.."
 PORT=${1:-8000}
 
 # Start the API server
-uvicorn recordlinker.main:app --app-dir src --reload --reload-dir src/ --host 0 --port ${PORT} --log-config src/recordlinker/log_config.yml
+uvicorn recordlinker.main:app --reload --reload-dir src/ --port ${PORT}
diff --git a/scripts/open_pulls.sh b/scripts/open_pulls.sh
index ab3f53bf..fd9678ba 100755
--- a/scripts/open_pulls.sh
+++ b/scripts/open_pulls.sh
@@ -36,7 +36,9 @@ results=""
 while IFS= read -r pr; do
     number=$(echo "$pr" | jq -r '.number')
     # get the timestamp of the ready_for_review event
-    ready_for_review=$(gh_api issues/${number}/timeline | jq -r 'map(select(.event == "ready_for_review")) | .[0].created_at')
+    # TODO: this call assumes a PR will have fewer than 100 timeline events,
+    # which is not always the case; we should handle pagination
+    ready_for_review=$(gh_api "issues/${number}/timeline?per_page=100" | jq -r 'map(select(.event == "ready_for_review")) | .[0].created_at')
    # calculate the number of days since the PR was ready for review
    # only calculate if ready_for_review does not equal "null"
    if [ "$ready_for_review" == "null" ]; then
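On the open_pulls.sh TODO: raising per_page to 100 only papers over the problem for busy PRs. A rough sketch of genuine pagination, written in Python for readability (`ready_for_review_at` is a hypothetical helper, not part of this change):

```python
import requests


def ready_for_review_at(repo: str, number: int, token: str) -> str | None:
    """Walk the paginated issue timeline until a ready_for_review event appears."""
    url = f"https://api.github.com/repos/{repo}/issues/{number}/timeline"
    headers = {"Authorization": f"Bearer {token}"}
    page = 1
    while True:
        resp = requests.get(url, headers=headers, params={"per_page": 100, "page": page})
        resp.raise_for_status()
        events = resp.json()
        if not events:
            return None  # exhausted every page without finding the event
        for event in events:
            if event.get("event") == "ready_for_review":
                return event["created_at"]
        page += 1
```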
diff --git a/src/recordlinker/__init__.py b/src/recordlinker/__init__.py
index e69de29b..bc922c28 100644
--- a/src/recordlinker/__init__.py
+++ b/src/recordlinker/__init__.py
@@ -0,0 +1,2 @@
+# initialize the configuration early
+import recordlinker.config  # noqa: F401
diff --git a/src/recordlinker/config.py b/src/recordlinker/config.py
index 608eb539..33f25c57 100644
--- a/src/recordlinker/config.py
+++ b/src/recordlinker/config.py
@@ -1,3 +1,5 @@
+import json
+import logging.config
 import typing
 
 import pydantic
@@ -33,6 +35,10 @@ class Settings(pydantic_settings.BaseSettings):
         "above the connection pool size",
         default=10,
     )
+    log_config: typing.Optional[str] = pydantic.Field(
+        description="The path to the logging configuration file",
+        default="",
+    )
     initial_algorithms: str = pydantic.Field(
         description=(
             "The path to the initial algorithms file that is loaded on startup if the "
@@ -42,5 +48,53 @@ class Settings(pydantic_settings.BaseSettings):
         default="assets/initial_algorithms.json",
     )
 
+    def default_log_config(self) -> dict:
+        """
+        Return the default logging configuration.
+        """
+        return {
+            "version": 1,
+            "disable_existing_loggers": False,
+            "filters": {"key_value": {"()": "recordlinker.log.KeyValueFilter"}},
+            "formatters": {
+                "default": {
+                    "()": "uvicorn.logging.DefaultFormatter",
+                    "fmt": "%(levelprefix)s [%(asctime)s] ... %(message)s",
+                    "datefmt": "%H:%M:%S",
+                }
+            },
+            "handlers": {
+                "console": {
+                    "class": "logging.StreamHandler",
+                    "formatter": "default",
+                    "filters": ["key_value"],
+                    "stream": "ext://sys.stderr",
+                }
+            },
+            "loggers": {
+                "": {"handlers": ["console"], "level": "WARNING"},
+                "recordlinker": {"handlers": ["console"], "level": "INFO", "propagate": False},
+                "recordlinker.access": {"handlers": ["console"], "level": "CRITICAL", "propagate": False},
+            },
+        }
+
+    def configure_logging(self) -> None:
+        """
+        Configure logging based on the provided configuration file. If no configuration
+        file is provided, use the default configuration.
+        """
+        config = None
+        if self.log_config:
+            # Load logging config from the provided file
+            try:
+                with open(self.log_config, "r") as fobj:
+                    config = json.loads(fobj.read())
+            except Exception as exc:
+                raise ConfigurationError(
+                    f"Error loading log configuration: {self.log_config}"
+                ) from exc
+        logging.config.dictConfig(config or self.default_log_config())
+
 
 settings = Settings()  # type: ignore
+settings.configure_logging()
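Because Settings is a pydantic-settings model, the new log_config field can also be driven by the LOG_CONFIG environment variable that the Dockerfile and compose file wire up. A sketch of both paths (assumes the env var maps directly onto the field and that any other required settings are present in the environment):

```python
import os

from recordlinker.config import Settings

# Unset: configure_logging() falls back to default_log_config().
os.environ.pop("LOG_CONFIG", None)
Settings().configure_logging()

# Set: the JSON file is loaded via dictConfig; a bad path or invalid JSON
# raises ConfigurationError rather than failing silently.
os.environ["LOG_CONFIG"] = "assets/production_log_config.json"
Settings().configure_logging()
```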
diff --git a/src/recordlinker/linkage/__init__.py b/src/recordlinker/linkage/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/src/recordlinker/linkage/algorithms.py b/src/recordlinker/linkage/algorithms.py
deleted file mode 100644
index 21fa08aa..00000000
--- a/src/recordlinker/linkage/algorithms.py
+++ /dev/null
@@ -1,97 +0,0 @@
-# DEFAULT DIBBS ALGORITHMS
-# These algorithms and log odds scores are the updated values developed after
-# substantial statistical tuning.
-
-LOG_ODDS_SCORES = {
-    "address": 8.438284928858774,
-    "birthdate": 10.126641103800338,
-    "city": 2.438553006137189,
-    "first_name": 6.849475906891162,
-    "last_name": 6.350720397426025,
-    "mrn": 0.3051262572525359,
-    "sex": 0.7510419059643679,
-    "state": 0.022376768992488694,
-    "zip": 4.975031471124867,
-}
-FUZZY_THRESHOLDS = {
-    "first_name": 0.9,
-    "last_name": 0.9,
-    "birthdate": 0.95,
-    "address": 0.9,
-    "city": 0.92,
-    "zip": 0.95,
-}
-
-DIBBS_BASIC = [
-    {
-        "funcs": {
-            "first_name": "func:recordlinker.linkage.matchers.feature_match_fuzzy_string",
-            "last_name": "func:recordlinker.linkage.matchers.feature_match_exact",
-        },
-        "blocks": [
-            {"value": "birthdate"},
-            {"value": "mrn", "transformation": "last4"},
-            {"value": "sex"},
-        ],
-        "matching_rule": "func:recordlinker.linkage.matchers.eval_perfect_match",
-        "cluster_ratio": 0.9,
-        "kwargs": {"thresholds": FUZZY_THRESHOLDS},
-    },
-    {
-        "funcs": {
-            "address": "func:recordlinker.linkage.matchers.feature_match_fuzzy_string",
-            "birthdate": "func:recordlinker.linkage.matchers.feature_match_exact",
-        },
-        "blocks": [
-            {"value": "zip"},
-            {"value": "first_name", "transformation": "first4"},
-            {"value": "last_name", "transformation": "first4"},
-            {"value": "sex"},
-        ],
-        "matching_rule": "func:recordlinker.linkage.matchers.eval_perfect_match",
-        "cluster_ratio": 0.9,
-        "kwargs": {"thresholds": FUZZY_THRESHOLDS},
-    },
-]
-
-DIBBS_ENHANCED = [
-    {
-        "funcs": {
-            "first_name": "func:recordlinker.linkage.matchers.feature_match_log_odds_fuzzy_compare",
-            "last_name": "func:recordlinker.linkage.matchers.feature_match_log_odds_fuzzy_compare",
-        },
-        "blocks": [
-            {"value": "birthdate"},
-            {"value": "mrn", "transformation": "last4"},
-            {"value": "sex"},
-        ],
-        "matching_rule": "func:recordlinker.linkage.matchers.eval_log_odds_cutoff",
-        "cluster_ratio": 0.9,
-        "kwargs": {
-            "similarity_measure": "JaroWinkler",
-            "thresholds": FUZZY_THRESHOLDS,
-            "true_match_threshold": 12.2,
-            "log_odds": LOG_ODDS_SCORES,
-        },
-    },
-    {
-        "funcs": {
-            "address": "func:recordlinker.linkage.matchers.feature_match_log_odds_fuzzy_compare",
-            "birthdate": "func:recordlinker.linkage.matchers.feature_match_log_odds_fuzzy_compare",
-        },
-        "blocks": [
-            {"value": "zip"},
-            {"value": "first_name", "transformation": "first4"},
-            {"value": "last_name", "transformation": "first4"},
-            {"value": "sex"},
-        ],
-        "matching_rule": "func:recordlinker.linkage.matchers.eval_log_odds_cutoff",
-        "cluster_ratio": 0.9,
-        "kwargs": {
-            "similarity_measure": "JaroWinkler",
-            "thresholds": FUZZY_THRESHOLDS,
-            "true_match_threshold": 17.0,
-            "log_odds": LOG_ODDS_SCORES,
-        },
-    },
-]
diff --git a/src/recordlinker/log.py b/src/recordlinker/log.py
new file mode 100644
index 00000000..5bdf9753
--- /dev/null
+++ b/src/recordlinker/log.py
@@ -0,0 +1,44 @@
+import logging
+import typing
+
+import pythonjsonlogger.jsonlogger
+
+RESERVED_ATTRS = pythonjsonlogger.jsonlogger.RESERVED_ATTRS + ("taskName",)
+
+
+# Custom filter to transform log arguments into JSON fields
+class DictArgFilter(logging.Filter):
+    def filter(self, record):
+        """
+        Filter the log record to extract the dictionary arguments as fields.
+        """
+        # if the args are a dictionary, set the key-value pairs as attributes
+        if isinstance(record.args, dict):
+            for key, value in record.args.items():
+                setattr(record, key, value)
+        return True
+
+
+class KeyValueFilter(logging.Filter):
+    def filter(self, record):
+        """
+        Filter the log record to extract the key-value pairs from the log message.
+        """
+        for key, value in record.__dict__.items():
+            if key not in RESERVED_ATTRS:
+                record.msg = f"{record.msg} {key}={value}"
+        return True
+
+
+class JSONFormatter(pythonjsonlogger.jsonlogger.JsonFormatter):
+    """
+    A custom JSON formatter that excludes the taskName field by default.
+    """
+
+    def __init__(
+        self,
+        *args: typing.Any,
+        reserved_attrs: tuple[str, ...] = RESERVED_ATTRS,
+        **kwargs: typing.Any,
+    ):
+        super().__init__(*args, reserved_attrs=reserved_attrs, **kwargs)
diff --git a/src/recordlinker/log_config.yml b/src/recordlinker/log_config.yml
deleted file mode 100644
index 6f1bf1e2..00000000
--- a/src/recordlinker/log_config.yml
+++ /dev/null
@@ -1,34 +0,0 @@
-version: 1
-disable_existing_loggers: False
-formatters:
-  default:
-    # "()": uvicorn.logging.DefaultFormatter
-    format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-  access:
-    # "()": uvicorn.logging.AccessFormatter
-    format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-handlers:
-  default:
-    formatter: default
-    class: logging.StreamHandler
-    stream: ext://sys.stderr
-  access:
-    formatter: access
-    class: logging.StreamHandler
-    stream: ext://sys.stdout
-loggers:
-  uvicorn.error:
-    level: INFO
-    handlers:
-      - default
-    propagate: no
-  uvicorn.access:
-    level: INFO
-    handlers:
-      - access
-    propagate: no
-root:
-  level: INFO
-  handlers:
-    - default
-  propagate: no
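To see how the pieces in log.py compose, here is a minimal hand-wired sketch (the logger name and fields are illustrative):

```python
import logging

from recordlinker.log import DictArgFilter, JSONFormatter

# Wire DictArgFilter and JSONFormatter together by hand, mirroring what the
# "access" handler in production_log_config.json does via dictConfig.
handler = logging.StreamHandler()
handler.setFormatter(JSONFormatter("%(levelname)s %(message)s"))
handler.addFilter(DictArgFilter())

logger = logging.getLogger("demo")
logger.addHandler(handler)
logger.setLevel(logging.INFO)

# The dict argument is lifted onto the record, so "status_code" is emitted as
# a JSON field alongside "levelname" and "message".
logger.info("request finished", {"status_code": 200})
```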
+ """ + # Record the start time of the request + start_time = time.time() + # Process the request and get the response + response = await call_next(request) + data = { + # Record the end time after the response + "process_time": (time.time() - start_time) * 1000, + # Record the correlation ID, if present + "correlation_id": request.headers.get(CorrelationIdMiddleware.header_name, "-"), + # Log details of the request + "client_ip": getattr(request.client, "host", "-"), + "method": request.method, + "path": request.url.path, + "http_version": request.scope.get("http_version", "unknown"), + "status_code": response.status_code, + } + msg = ( + '[%(correlation_id)s] %(client_ip)s - "%(method)s %(path)s ' + 'HTTP/%(http_version)s" %(status_code)d %(process_time).2fms' + ) + # Log the message + ACCESS_LOGGER.info(msg, data) + return response diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 00000000..5a2e1b4e --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,14 @@ +import pytest + +from recordlinker import config + + +class TestSettings: + def test_configure_logging_invalid_file(self): + obj = config.Settings(log_config="invalid.json") + with pytest.raises(config.ConfigurationError): + obj.configure_logging() + + def test_configure_logging(self): + obj = config.Settings(log_config="assets/production_log_config.json") + assert obj.configure_logging() is None diff --git a/tests/unit/test_log.py b/tests/unit/test_log.py new file mode 100644 index 00000000..ae17948a --- /dev/null +++ b/tests/unit/test_log.py @@ -0,0 +1,90 @@ +import logging + +from recordlinker import log + + +class TestDictArgFilter: + def test_filter(self): + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test_log.py", + lineno=10, + exc_info=None, + msg="test", + args=[{"key": "value"}], + ) + assert log.DictArgFilter().filter(record) + assert record.key == "value" + + def test_no_dict_args(self): + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test_log.py", + lineno=10, + exc_info=None, + msg="test", + args=["value"], + ) + assert log.DictArgFilter().filter(record) + assert not hasattr(record, "value") + + +class TestKeyValueFilter: + def test_filter(self): + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test_log.py", + lineno=10, + exc_info=None, + msg="test", + args=[], + ) + record.key = "value" + assert log.KeyValueFilter().filter(record) + assert record.msg == "test key=value" + + def test_reserved_attrs(self): + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test_log.py", + lineno=10, + exc_info=None, + msg="test", + args=[], + ) + record.taskName = "task" + assert log.KeyValueFilter().filter(record) + assert record.msg == "test" + + +class TestJsonFormatter: + def test_format(self): + formatter = log.JSONFormatter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test_log.py", + lineno=10, + exc_info=None, + msg="test", + args=[], + ) + assert formatter.format(record) == '{"message": "test"}' + + def test_format_reserved_attrs(self): + formatter = log.JSONFormatter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test_log.py", + lineno=10, + exc_info=None, + msg="test", + args=[], + ) + record.taskName = "task" + assert formatter.format(record) == '{"message": "test"}' diff --git a/tests/unit/test_middleware.py b/tests/unit/test_middleware.py new file mode 100644 index 00000000..0da2299e --- 
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py
new file mode 100644
index 00000000..5a2e1b4e
--- /dev/null
+++ b/tests/unit/test_config.py
@@ -0,0 +1,14 @@
+import pytest
+
+from recordlinker import config
+
+
+class TestSettings:
+    def test_configure_logging_invalid_file(self):
+        obj = config.Settings(log_config="invalid.json")
+        with pytest.raises(config.ConfigurationError):
+            obj.configure_logging()
+
+    def test_configure_logging(self):
+        obj = config.Settings(log_config="assets/production_log_config.json")
+        assert obj.configure_logging() is None
diff --git a/tests/unit/test_log.py b/tests/unit/test_log.py
new file mode 100644
index 00000000..ae17948a
--- /dev/null
+++ b/tests/unit/test_log.py
@@ -0,0 +1,90 @@
+import logging
+
+from recordlinker import log
+
+
+class TestDictArgFilter:
+    def test_filter(self):
+        record = logging.LogRecord(
+            name="test",
+            level=logging.INFO,
+            pathname="test_log.py",
+            lineno=10,
+            exc_info=None,
+            msg="test",
+            args=[{"key": "value"}],
+        )
+        assert log.DictArgFilter().filter(record)
+        assert record.key == "value"
+
+    def test_no_dict_args(self):
+        record = logging.LogRecord(
+            name="test",
+            level=logging.INFO,
+            pathname="test_log.py",
+            lineno=10,
+            exc_info=None,
+            msg="test",
+            args=["value"],
+        )
+        assert log.DictArgFilter().filter(record)
+        assert not hasattr(record, "value")
+
+
+class TestKeyValueFilter:
+    def test_filter(self):
+        record = logging.LogRecord(
+            name="test",
+            level=logging.INFO,
+            pathname="test_log.py",
+            lineno=10,
+            exc_info=None,
+            msg="test",
+            args=[],
+        )
+        record.key = "value"
+        assert log.KeyValueFilter().filter(record)
+        assert record.msg == "test key=value"
+
+    def test_reserved_attrs(self):
+        record = logging.LogRecord(
+            name="test",
+            level=logging.INFO,
+            pathname="test_log.py",
+            lineno=10,
+            exc_info=None,
+            msg="test",
+            args=[],
+        )
+        record.taskName = "task"
+        assert log.KeyValueFilter().filter(record)
+        assert record.msg == "test"
+
+
+class TestJsonFormatter:
+    def test_format(self):
+        formatter = log.JSONFormatter()
+        record = logging.LogRecord(
+            name="test",
+            level=logging.INFO,
+            pathname="test_log.py",
+            lineno=10,
+            exc_info=None,
+            msg="test",
+            args=[],
+        )
+        assert formatter.format(record) == '{"message": "test"}'
+
+    def test_format_reserved_attrs(self):
+        formatter = log.JSONFormatter()
+        record = logging.LogRecord(
+            name="test",
+            level=logging.INFO,
+            pathname="test_log.py",
+            lineno=10,
+            exc_info=None,
+            msg="test",
+            args=[],
+        )
+        record.taskName = "task"
+        assert formatter.format(record) == '{"message": "test"}'
diff --git a/tests/unit/test_middleware.py b/tests/unit/test_middleware.py
new file mode 100644
index 00000000..0da2299e
--- /dev/null
+++ b/tests/unit/test_middleware.py
@@ -0,0 +1,41 @@
+import unittest.mock
+
+import starlette.applications
+
+from recordlinker import middleware
+
+
+class TestCorrelationIdMiddleware:
+    def test_default(self):
+        app = starlette.applications.Starlette()
+        obj = middleware.CorrelationIdMiddleware(app)
+        assert obj.transformer("1234567890") == "1234567890"
+        assert obj.transformer("123456789012345678") == "123456789012"
+
+    def test_custom_length(self):
+        app = starlette.applications.Starlette()
+        obj = middleware.CorrelationIdMiddleware(app, correlation_id_length=4)
+        assert obj.transformer("1234567890") == "1234"
+        assert obj.transformer("123456789012345678") == "1234"
+
+
+class TestAccessLogMiddleware:
+    def test_dispatch(self, client):
+        with unittest.mock.patch("recordlinker.middleware.ACCESS_LOGGER") as mock_logger:
+            response = client.get("/")
+            # Verify the response
+            assert response.status_code == 200
+            assert response.json() == {"status": "OK"}
+            assert len(mock_logger.info.mock_calls) == 1
+            expected = (
+                '[%(correlation_id)s] %(client_ip)s - "%(method)s %(path)s '
+                'HTTP/%(http_version)s" %(status_code)d %(process_time).2fms'
+            )
+            assert mock_logger.info.call_args[0][0] == expected
+            assert mock_logger.info.call_args[0][1]["client_ip"] == "testclient"
+            assert mock_logger.info.call_args[0][1]["method"] == "GET"
+            assert mock_logger.info.call_args[0][1]["path"] == "/"
+            assert mock_logger.info.call_args[0][1]["http_version"] == "1.1"
+            assert mock_logger.info.call_args[0][1]["status_code"] == 200
+            assert mock_logger.info.call_args[0][1]["process_time"] > 0
+            assert len(mock_logger.info.call_args[0][1]["correlation_id"]) == 12