From 3701066a176457e2a8598024335d9d0944bfea2c Mon Sep 17 00:00:00 2001 From: Viktor Taskov Date: Wed, 26 Jul 2023 15:53:50 +0100 Subject: [PATCH] Lowercase global vars, they are not constants --- mlserver/context.py | 12 ++++++------ mlserver/logging.py | 6 +++--- mlserver/metrics/context.py | 6 +++--- tests/test_context.py | 14 +++++++------- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/mlserver/context.py b/mlserver/context.py index 99c49471f..81d8fa1a9 100644 --- a/mlserver/context.py +++ b/mlserver/context.py @@ -3,21 +3,21 @@ from .settings import ModelSettings -MODEL_NAME_VAR: ContextVar[str] = ContextVar("model_name") -MODEL_VERSION_VAR: ContextVar[str] = ContextVar("model_version") +model_name_var: ContextVar[str] = ContextVar("model_name") +model_version_var: ContextVar[str] = ContextVar("model_version") @contextmanager def model_context(model_settings: ModelSettings): - model_name_token = MODEL_NAME_VAR.set(model_settings.name) + model_name_token = model_name_var.set(model_settings.name) model_version = "" if model_settings.version: model_version = model_settings.version - model_version_token = MODEL_VERSION_VAR.set(model_version) + model_version_token = model_version_var.set(model_version) try: yield finally: - MODEL_NAME_VAR.reset(model_name_token) - MODEL_VERSION_VAR.reset(model_version_token) + model_name_var.reset(model_name_token) + model_version_var.reset(model_version_token) diff --git a/mlserver/logging.py b/mlserver/logging.py index 3c68d02ef..ca2ad9183 100644 --- a/mlserver/logging.py +++ b/mlserver/logging.py @@ -7,7 +7,7 @@ from typing import Optional, Dict, Union import logging.config -from .context import MODEL_NAME_VAR, MODEL_VERSION_VAR +from .context import model_name_var, model_version_var from .settings import Settings LoggerName = "mlserver" @@ -75,8 +75,8 @@ def _format_structured_model_details(name: str, version: str) -> str: return model_details def format(self, record: logging.LogRecord) -> str: - model_name = MODEL_NAME_VAR.get("") - model_version = MODEL_VERSION_VAR.get("") + model_name = model_name_var.get("") + model_version = model_version_var.get("") record.model = ( self._format_structured_model_details(model_name, model_version) diff --git a/mlserver/metrics/context.py b/mlserver/metrics/context.py index 3b58f17fe..ccb828c11 100644 --- a/mlserver/metrics/context.py +++ b/mlserver/metrics/context.py @@ -2,7 +2,7 @@ from .registry import REGISTRY from .errors import InvalidModelContext -from ..context import MODEL_NAME_VAR, MODEL_VERSION_VAR +from ..context import model_name_var, model_version_var SELDON_MODEL_NAME_LABEL = "model_name" SELDON_MODEL_VERSION_LABEL = "model_version" @@ -28,8 +28,8 @@ def register(name: str, description: str) -> Histogram: def _get_labels_from_context() -> dict: try: - model_name = MODEL_NAME_VAR.get() - model_version = MODEL_VERSION_VAR.get() + model_name = model_name_var.get() + model_version = model_version_var.get() return { SELDON_MODEL_NAME_LABEL: model_name, SELDON_MODEL_VERSION_LABEL: model_version, diff --git a/tests/test_context.py b/tests/test_context.py index 27fc071e0..e7e8fa301 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -3,7 +3,7 @@ from typing import Optional from mlserver.settings import ModelSettings, ModelParameters -from mlserver.context import model_context, MODEL_NAME_VAR, MODEL_VERSION_VAR +from mlserver.context import model_context, model_name_var, model_version_var from .fixtures import SumModel @@ -29,19 +29,19 @@ def test_model_context(name: str, version: Optional[str], expected_version: str) ) with pytest.raises(LookupError): - _ = MODEL_NAME_VAR.get() + _ = model_name_var.get() with pytest.raises(LookupError): - _ = MODEL_VERSION_VAR.get() + _ = model_version_var.get() with model_context(model_settings): - var_name = MODEL_NAME_VAR.get() - var_version = MODEL_VERSION_VAR.get() + var_name = model_name_var.get() + var_version = model_version_var.get() assert var_name == name assert var_version == expected_version with pytest.raises(LookupError): - _ = MODEL_NAME_VAR.get() + _ = model_name_var.get() with pytest.raises(LookupError): - _ = MODEL_VERSION_VAR.get() + _ = model_version_var.get()