From 4d05d95841ca492846a158efb0e2bcea39500c37 Mon Sep 17 00:00:00 2001 From: Leila Wang <77892032+yleilawang@users.noreply.github.com> Date: Thu, 26 Sep 2024 14:46:32 -0400 Subject: [PATCH] Implement: multiple save and load for mlflow registry (#416) 1. Implemented save_multiple and load_multiple for mlflow registry 2. Test cases for implementation. --------- Signed-off-by: Leila Wang --- .gitignore | 5 + numalogic/registry/mlflow_registry.py | 126 ++++++++++++++++- pyproject.toml | 2 +- tests/registry/_mlflow_utils.py | 124 +++++++++++++++++ tests/registry/test_mlflow_registry.py | 179 ++++++++++++++++++++++++- 5 files changed, 426 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 3db98ce0..c50fd26c 100644 --- a/.gitignore +++ b/.gitignore @@ -78,6 +78,8 @@ target/ #mlflow /.mlruns *.db +mlruns/ +mlartifacts/ # Jupyter Notebook .ipynb_checkpoints @@ -169,4 +171,7 @@ cython_debug/ # Mac related *.DS_Store +# vscode +.vscode/ + .python-version diff --git a/numalogic/registry/mlflow_registry.py b/numalogic/registry/mlflow_registry.py index 43c5b519..3b406e69 100644 --- a/numalogic/registry/mlflow_registry.py +++ b/numalogic/registry/mlflow_registry.py @@ -15,11 +15,9 @@ from enum import Enum from typing import Optional, Any -import mlflow.pyfunc -import mlflow.pytorch -import mlflow.sklearn +import mlflow from mlflow.entities.model_registry import ModelVersion -from mlflow.exceptions import RestException +from mlflow.exceptions import RestException, MlflowException from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST from mlflow.tracking import MlflowClient @@ -187,6 +185,43 @@ def load( self._save_in_cache(model_key, artifact_data) return artifact_data + def load_multiple( + self, + skeys: KEYS, + dkeys: KEYS, + ) -> Optional[ArtifactData]: + """ + Load multiple artifacts from the registry for pyfunc models. + Args: + skeys (KEYS): The static keys of the artifacts to load. 
+ dkeys: dynamic key fields as list/tuple of strings. + + Returns + ------- + Optional[ArtifactData]: The loaded ArtifactData object if available otherwise None. + ArtifactData should contain a dictionary of artifacts. + """ + loaded_model = self.load(skeys=skeys, dkeys=dkeys, artifact_type="pyfunc") + if loaded_model is None: + return None + + try: + unwrapped_composite_model = loaded_model.artifact.unwrap_python_model() + except MlflowException as e: + raise TypeError("The loaded model is not a valid pyfunc Python model.") from e + except AttributeError: + _LOGGER.exception("The loaded model does not have an unwrap_python_model method") + return None + except Exception: + _LOGGER.exception("Unexpected error occurred while unwrapping python model.") + return None + + return ArtifactData( + artifact=unwrapped_composite_model.dict_artifacts, + metadata=loaded_model.metadata, + extras=loaded_model.extras, + ) + @staticmethod def __log_mlflow_err(mlflow_err: RestException, model_key: str) -> None: if ErrorCode.Value(mlflow_err.error_code) == RESOURCE_DOES_NOT_EXIST: @@ -225,7 +260,10 @@ def save( handler = self.handler_from_type(artifact_type) try: mlflow.start_run(run_id=run_id) - handler.log_model(artifact, "model", registered_model_name=model_key) + if artifact_type == "pyfunc": + handler.log_model("model", python_model=artifact, registered_model_name=model_key) + else: + handler.log_model(artifact, "model", registered_model_name=model_key) if metadata: mlflow.log_params(metadata) model_version = self.transition_stage(skeys=skeys, dkeys=dkeys) @@ -238,6 +276,42 @@ def save( finally: mlflow.end_run() + def save_multiple( + self, + skeys: KEYS, + dkeys: KEYS, + dict_artifacts: dict[str, artifact_t], + **metadata: META_VT, + ) -> Optional[ModelVersion]: + """ + Saves multiple artifacts into mlflow registry. The artifacts are wrapped in a + single CompositeModel and saved together as one pyfunc model. + + Args: + ---- + skeys (KEYS): Static key fields as a list or tuple of strings. 
+ dkeys (KEYS): Dynamic key fields as a list or tuple of strings. + dict_artifacts (dict[str, artifact_t]): Dictionary of artifacts to save. + **metadata (META_VT): Additional metadata to be saved with the artifacts. + + Returns + ------- + Optional[ModelVersion]: An instance of the MLflow ModelVersion. + + """ + if len(dict_artifacts) == 1: + _LOGGER.warning( + "Only one artifact present in dict_artifacts. Saving directly is recommended." + ) + multiple_artifacts = CompositeModel(skeys=skeys, dict_artifacts=dict_artifacts, **metadata) + return self.save( + skeys=skeys, + dkeys=dkeys, + artifact=multiple_artifacts, + artifact_type="pyfunc", + **metadata, + ) + @staticmethod def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool: """Returns whether the given artifact is stale or not, i.e. if @@ -338,3 +412,45 @@ def __load_artifacts( version_info.version, ) return model, metadata + + +class CompositeModel(mlflow.pyfunc.PythonModel): + """A composite model that represents multiple artifacts. + + This class extends the `mlflow.pyfunc.PythonModel` class and is used to store and load + multiple artifacts in the MLflow registry. It provides a convenient way to manage and + organize multiple artifacts associated with a single model. + + Args: + skeys (KEYS): The static keys of the artifacts. + dict_artifacts (dict[str, artifact_t]): A dictionary mapping string keys to + the artifacts being stored. + **metadata (META_VT): Additional metadata associated with the artifacts. + + Methods + ------- + predict: Not implemented for our use case. + + Attributes + ---------- + skeys (KEYS): The static keys of the artifacts. + dict_artifacts (dict[str, artifact_t]): A dictionary mapping string keys to + the artifacts being stored. + metadata (dict[str, META_VT]): Additional metadata associated with the artifacts. 
+ """ + + __slots__ = ("skeys", "dict_artifacts", "metadata") + + def __init__(self, skeys: KEYS, dict_artifacts: dict[str, artifact_t], **metadata: META_VT): + self.skeys = skeys + self.dict_artifacts = dict_artifacts + self.metadata = metadata + + def predict(self, context, model_input, params: Optional[dict[str, Any]] = None): + """ + Predict method is not implemented for our use case. + + The CompositeModel class is designed to store and load multiple artifacts, + and the predict method is not required for this functionality. + """ + raise NotImplementedError("The predict method is not implemented for CompositeModel.") diff --git a/pyproject.toml b/pyproject.toml index 222b544e..5fb0804b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "numalogic" -version = "0.13.2" +version = "0.13.3" description = "Collection of operational Machine Learning models and tools." authors = ["Numalogic Developers"] packages = [{ include = "numalogic" }] diff --git a/tests/registry/_mlflow_utils.py b/tests/registry/_mlflow_utils.py index 4b61eddc..3afbc96a 100644 --- a/tests/registry/_mlflow_utils.py +++ b/tests/registry/_mlflow_utils.py @@ -9,8 +9,11 @@ from mlflow.store.entities import PagedList from sklearn.preprocessing import StandardScaler from torch import tensor +from mlflow.models import Model +from numalogic.models.autoencoder.variants.vanilla import VanillaAE from numalogic.models.threshold import StdDevThreshold +from numalogic.registry.mlflow_registry import CompositeModel def create_model(): @@ -135,6 +138,103 @@ def mock_log_model_sklearn(*_, **__): ) +def mock_log_model_pyfunc(*_, **__): + return ModelInfo( + artifact_path="model", + flavors={ + "pyfunc": {"model_data": "data", "pyfunc_version": "1.11.0", "code": None}, + "python_function": { + "pickle_module_name": "mlflow.pyfunc.pickle_module", + "loader_module": "mlflow.pyfunc", + "python_version": "3.8.5", + "data": "data", + "env": "conda.yaml", + }, + }, + 
model_uri="runs:/a7c0b376530b40d7b23e6ce2081c899c/model", + model_uuid="a7c0b376530b40d7b23e6ce2081c899c", + run_id="a7c0b376530b40d7b23e6ce2081c899c", + saved_input_example_info=None, + signature_dict=None, + utc_time_created="2022-05-23 22:35:59.557372", + mlflow_version="2.0.1", + signature=None, + ) + + +def mock_load_model_pyfunc(*_, **__): + artifact_path = "model" + flavors = { + "python_function": { + "cloudpickle_version": "3.0.0", + "code": None, + "env": {"conda": "conda.yaml", "virtualenv": "python_env.yaml"}, + "loader_module": "mlflow.pyfunc.model", + "python_model": "python_model.pkl", + "python_version": "3.10.14", + "streamable": False, + } + } + model_size_bytes = 8912 + model_uuid = "ae27ecc166c94c01a4f4dccaf84ca5dc" + run_id = "7e85a3fa46d44e668c840f3dddc909c3" + utc_time_created = "2024-09-18 17:12:41.501209" + model = Model( + artifact_path=artifact_path, + flavors=flavors, + model_size_bytes=model_size_bytes, + model_uuid=model_uuid, + run_id=run_id, + utc_time_created=utc_time_created, + mlflow_version="2.16.0", + ) + return mlflow.pyfunc.PyFuncModel( + model_meta=model, + model_impl=TestObject( + python_model=CompositeModel( + skeys=["error"], + dict_artifacts={ + "inference": VanillaAE(10), + "precrocessing": StandardScaler(), + "threshold": StdDevThreshold(), + }, + **{"learning_rate": 0.01}, + ) + ), + ) + + +def mock_load_model_pyfunc_type_error(*_, **__): + artifact_path = "model" + flavors = { + "python_function": { + "cloudpickle_version": "3.0.0", + "code": None, + "env": {"conda": "conda.yaml", "virtualenv": "python_env.yaml"}, + "loader_module": "mlflow.pytorch.model", + "python_model": "python_model.pkl", + "python_version": "3.10.14", + "streamable": False, + } + } + model_size_bytes = 8912 + model_uuid = "ae27ecc166c94c01a4f4dccaf84ca5dc" + run_id = "7e85a3fa46d44e668c840f3dddc909c3" + utc_time_created = "2024-09-18 17:12:41.501209" + model = Model( + artifact_path=artifact_path, + flavors=flavors, + 
model_size_bytes=model_size_bytes, + model_uuid=model_uuid, + run_id=run_id, + utc_time_created=utc_time_created, + mlflow_version="2.16.0", + ) + return mlflow.pyfunc.PyFuncModel( + model_meta=model, model_impl=mlflow.pytorch._PyTorchWrapper(VanillaAE(10), device="cpu") + ) + + def mock_transition_stage(*_, **__): return ModelVersion( creation_timestamp=1653402941169, @@ -303,6 +403,25 @@ def return_sklearn_rundata(): ) +def return_pyfunc_rundata(): + return Run( + run_info=RunInfo( + artifact_uri="mlflow-artifacts:/0/a7c0b376530b40d7b23e6ce2081c899c/artifacts/model", + end_time=None, + experiment_id="0", + lifecycle_stage="active", + run_id="a7c0b376530b40d7b23e6ce2081c899c", + run_uuid="a7c0b376530b40d7b23e6ce2081c899c", + start_time=1658788772612, + status="RUNNING", + user_id="lol", + ), + run_data=RunData( + metrics={}, tags={}, params=[mlflow.entities.Param("learning_rate", "0.01")] + ), + ) + + def return_pytorch_rundata_dict(): return Run( run_info=RunInfo( @@ -318,3 +437,8 @@ def return_pytorch_rundata_dict(): ), run_data=RunData(metrics={}, tags={}, params=[mlflow.entities.Param("lr", "0.001")]), ) + + +class TestObject(mlflow.pyfunc.PythonModel): + def __init__(self, python_model): + self.python_model = python_model diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py index 8482c6e8..716265b6 100644 --- a/tests/registry/test_mlflow_registry.py +++ b/tests/registry/test_mlflow_registry.py @@ -3,6 +3,9 @@ from contextlib import contextmanager from unittest.mock import patch, Mock +import mlflow.pytorch # noqa: F401 +import mlflow.pyfunc # noqa: F401 +import mlflow.sklearn # noqa: F401 from freezegun import freeze_time from mlflow import ActiveRun from mlflow.exceptions import RestException @@ -11,11 +14,15 @@ from sklearn.preprocessing import StandardScaler from numalogic.models.autoencoder.variants import VanillaAE +from numalogic.models.threshold._std import StdDevThreshold from numalogic.registry import 
MLflowRegistry, ArtifactData, LocalLRUCache -from numalogic.registry.mlflow_registry import ModelStage +from numalogic.registry.mlflow_registry import CompositeModel, ModelStage from tests.registry._mlflow_utils import ( + mock_load_model_pyfunc, + mock_load_model_pyfunc_type_error, + mock_log_model_pyfunc, model_sklearn, create_model, mock_log_model_pytorch, @@ -23,6 +30,7 @@ mock_get_model_version, mock_transition_stage, mock_log_model_sklearn, + return_pyfunc_rundata, return_pytorch_rundata_dict, return_empty_rundata, mock_list_of_model_version, @@ -56,22 +64,171 @@ def test_construct_key(self): self.assertEqual("model_:nnet::error1", key) @patch("mlflow.pytorch.log_model", mock_log_model_pytorch) - @patch("mlflow.log_param", mock_log_state_dict) + @patch("mlflow.log_params", mock_log_state_dict) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) - @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version) + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) def test_save_model(self): ml = MLflowRegistry(TRACKING_URI) skeys = self.skeys dkeys = self.dkeys status = ml.save( - skeys=skeys, dkeys=dkeys, artifact=self.model, run_id="1234", artifact_type="pytorch" + skeys=skeys, + dkeys=dkeys, + artifact=self.model, + run_id="1234", + artifact_type="pytorch", + **{"lr": 0.01}, + ) + mock_status = "READY" + self.assertEqual(mock_status, status.status) + + @patch("mlflow.pyfunc.log_model", mock_log_model_pyfunc) + @patch("mlflow.log_params", mock_log_state_dict) + @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata()))) + @patch("mlflow.active_run", 
Mock(return_value=return_pyfunc_rundata())) + @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) + @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) + def test_save_multiple_models_pyfunc(self): + ml = MLflowRegistry(TRACKING_URI) + status = ml.save_multiple( + skeys=self.skeys, + dict_artifacts={ + "inference": VanillaAE(10), + "precrocessing": StandardScaler(), + "threshold": StdDevThreshold(), + }, + dkeys=["unique", "sorted"], + **{"learning_rate": 0.01}, ) + self.assertIsNotNone(status) mock_status = "READY" self.assertEqual(mock_status, status.status) + @patch("mlflow.pyfunc.log_model", mock_log_model_pyfunc) + @patch("mlflow.log_params", mock_log_state_dict) + @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata()))) + @patch("mlflow.active_run", Mock(return_value=return_pyfunc_rundata())) + @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) + @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) + def test_save_multiple_models_when_only_one_model(self): + ml = MLflowRegistry(TRACKING_URI) + with self.assertLogs(level="WARNING"): + ml.save_multiple( + skeys=self.skeys, + dict_artifacts={ + "inference": VanillaAE(10), + }, + dkeys=["unique", "sorted"], + **{"learning_rate": 0.01}, + ) + + @patch("mlflow.pyfunc.log_model", mock_log_model_pyfunc) + @patch("mlflow.log_params", mock_log_state_dict) + @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata()))) + @patch("mlflow.active_run", Mock(return_value=return_pyfunc_rundata())) + @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) + @patch("mlflow.tracking.MlflowClient.get_latest_versions", 
mock_get_model_version) + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) + @patch("mlflow.pyfunc.load_model", mock_load_model_pyfunc) + @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pyfunc_rundata())) + def test_load_multiple_models_when_pyfunc_model_exist(self): + ml = MLflowRegistry(TRACKING_URI) + skeys = self.skeys + dkeys = ["unique", "sorted"] + ml.save_multiple( + skeys=self.skeys, + dict_artifacts={ + "inference": VanillaAE(10), + "precrocessing": StandardScaler(), + "threshold": StdDevThreshold(), + }, + dkeys=["unique", "sorted"], + **{"learning_rate": 0.01}, + ) + data = ml.load_multiple(skeys=skeys, dkeys=dkeys) + self.assertIsNotNone(data.metadata) + self.assertIsInstance(data, ArtifactData) + self.assertIsInstance(data.artifact, dict) + self.assertIsInstance(data.artifact["inference"], VanillaAE) + self.assertIsInstance(data.artifact["precrocessing"], StandardScaler) + + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) + @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) + @patch( + "mlflow.tracking.MlflowClient.get_latest_versions", + Mock(return_value=PagedList(items=[], token=None)), + ) + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) + @patch("mlflow.pyfunc.load_model", Mock(return_value=None)) + @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata())) + def test_load_model_when_no_model_pyfunc(self): + fake_skeys = ["Fakemodel_"] + fake_dkeys = ["error"] + ml = MLflowRegistry(TRACKING_URI) + with self.assertLogs(level="ERROR") as log: + o = ml.load_multiple(skeys=fake_skeys, dkeys=fake_dkeys) + self.assertIsNone(o) + self.assertTrue(log.output) + + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) + @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", 
mock_transition_stage) + @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) + @patch("mlflow.tracking.MlflowClient.get_model_version", mock_get_model_version_obj) + @patch( + "mlflow.pyfunc.load_model", + Mock( + return_value=CompositeModel( + skeys=["error"], + dict_artifacts={ + "inference": VanillaAE(10), + "precrocessing": StandardScaler(), + "threshold": StdDevThreshold(), + }, + **{"learning_rate": 0.01}, + ) + ), + ) + @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pyfunc_rundata())) + def test_load_multiple_attribute_error(self): + ml = MLflowRegistry(TRACKING_URI) + skeys = self.skeys + dkeys = ["unique", "sorted"] + with self.assertLogs(level="ERROR") as log: + result = ml.load_multiple(skeys=skeys, dkeys=dkeys) + self.assertIsNone(result) + self.assertTrue( + any( + "The loaded model does not have an unwrap_python_model method" in message + for message in log.output + ) + ) + + @patch("mlflow.pytorch.log_model", mock_log_model_pytorch) + @patch("mlflow.log_params", mock_log_state_dict) + @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) + @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) + @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) + @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) + @patch("mlflow.pyfunc.load_model", mock_load_model_pyfunc_type_error) + @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pyfunc_rundata())) + def test_load_multiple_type_error(self): + ml = MLflowRegistry(TRACKING_URI) + ml.save( + skeys=self.skeys, + dkeys=self.dkeys, + artifact=self.model, + artifact_type="pytorch", + **{"lr": 0.01}, + ) + with self.assertRaises(TypeError): + ml.load_multiple(skeys=self.skeys, dkeys=self.dkeys) + 
@patch("mlflow.sklearn.log_model", mock_log_model_sklearn) @patch("mlflow.log_param", mock_log_state_dict) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_sklearn_rundata()))) @@ -407,6 +564,20 @@ def test_cache_loading(self): key = MLflowRegistry.construct_key(self.skeys, self.dkeys) self.assertIsNotNone(ml._load_from_cache(key)) + @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata()))) + @patch("mlflow.active_run", Mock(return_value=return_pyfunc_rundata())) + @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) + @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) + @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pyfunc_rundata())) + @patch("mlflow.pyfunc.load_model", mock_load_model_pyfunc) + def test_cache_loading_pyfunc(self): + cache_registry = LocalLRUCache(ttl=50000) + ml = MLflowRegistry(TRACKING_URI, cache_registry=cache_registry) + dkeys = ["unique", "sorted"] + ml.load_multiple(skeys=self.skeys, dkeys=dkeys) + key = MLflowRegistry.construct_key(self.skeys, dkeys) + self.assertIsNotNone(ml._load_from_cache(key)) + if __name__ == "__main__": unittest.main()