From c4d3bccf41c09ded93a227682b0890f83f7e655a Mon Sep 17 00:00:00 2001 From: Kushal Batra <34571348+s0nicboOm@users.noreply.github.com> Date: Thu, 5 Sep 2024 12:58:05 -0700 Subject: [PATCH] fix: support all sklearn classes (#408) 1. Support artifact-type in mlflow registry from user. 2. Update docs --------- Signed-off-by: Kushal Batra --- docs/ml-flow.md | 6 +-- examples/multi_udf/src/udf/train.py | 13 +++-- numalogic/registry/mlflow_registry.py | 29 ++++++----- tests/registry/test_mlflow_registry.py | 67 +++++++++++++++++++------- 4 files changed, 76 insertions(+), 39 deletions(-) diff --git a/docs/ml-flow.md b/docs/ml-flow.md index a7d7a9fc..55d67196 100644 --- a/docs/ml-flow.md +++ b/docs/ml-flow.md @@ -23,7 +23,7 @@ Numalogic provides `MLflowRegistry`, to save and load models to/from MLflow. Here, `tracking_uri` is the uri where mlflow server is running. The `static_keys` and `dynamic_keys` are used to form a unique key for the model. -The `artifact` would be the model or transformer object that needs to be saved. +The `artifact` would be the model or transformer object that needs to be saved. Artifact saving also takes in 'artifact_type', which is the type of the artifact being saved. Currently, 'pytorch', 'sklearn' and 'pyfunc' are supported. A dictionary of metadata can also be saved along with the artifact. ```python from numalogic.registry import MLflowRegistry @@ -37,13 +37,13 @@ dynamic_keys = ["vanilla", "seq10"] registry = MLflowRegistry(tracking_uri="http://0.0.0.0:5000") registry.save( - skeys=static_keys, dkeys=dynamic_keys, artifact=model, seq_len=10, lr=0.001 + skeys=static_keys, dkeys=dynamic_keys, artifact_type='pytorch', artifact=model, seq_len=10, lr=0.001 ) ``` ### Model loading -Once, the models are save to MLflow, the `load` function of `MLflowRegistry` can be used to load the model. +Once the models are saved to MLflow, the `load` function of `MLflowRegistry` can be used to load the model. 
Just as artifacts are saved with an 'artifact_type', the same type must be passed to the `load` function. ```python from numalogic.registry import MLflowRegistry diff --git a/examples/multi_udf/src/udf/train.py b/examples/multi_udf/src/udf/train.py index ff137c6c..cb9eae83 100644 --- a/examples/multi_udf/src/udf/train.py +++ b/examples/multi_udf/src/udf/train.py @@ -34,10 +34,15 @@ def __init__(self): self.model_key = "ae::model" def _save_artifact( - self, model, skeys: list[str], dkeys: list[str], _: Optional[TimeseriesTrainer] = None + self, + model, + skeys: list[str], + dkeys: list[str], + artifact_type: str, + _: Optional[TimeseriesTrainer] = None, ) -> None: """Saves the model in the registry.""" - self.registry.save(skeys=skeys, dkeys=dkeys, artifact=model) + self.registry.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type=artifact_type) @staticmethod def _fit_preprocess(data: pd.DataFrame) -> npt.NDArray[float]: @@ -93,8 +98,8 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: thresh_clf = self._fit_threshold(train_reconerr.numpy()) # Save to registry - self._save_artifact(model, ["ae"], ["model"], trainer) - self._save_artifact(thresh_clf, ["thresh_clf"], ["model"]) + self._save_artifact(model, ["ae"], ["model"], "pytorch", trainer) + self._save_artifact(thresh_clf, ["thresh_clf"], ["model"], artifact_type="sklearn") LOGGER.info("%s - Model Saving complete", payload.uuid) # Train is the last vertex in the graph diff --git a/numalogic/registry/mlflow_registry.py b/numalogic/registry/mlflow_registry.py index 97e39d73..8db09624 100644 --- a/numalogic/registry/mlflow_registry.py +++ b/numalogic/registry/mlflow_registry.py @@ -21,9 +21,7 @@ from mlflow.exceptions import RestException from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST from mlflow.tracking import MlflowClient -from torch import nn -from numalogic.base import BaseThresholdModel, BaseTransformer from numalogic.registry import 
ArtifactManager, ArtifactData from numalogic.registry.artifact import ArtifactCache from numalogic.tools.exceptions import ModelVersionError @@ -65,7 +63,8 @@ class MLflowRegistry(ArtifactManager): >>> data = [[0, 0], [0, 0], [1, 1], [1, 1]] >>> scaler = StandardScaler.fit(data) >>> registry = MLflowRegistry(tracking_uri="http://0.0.0.0:8080") - >>> registry.save(skeys=["model"], dkeys=["AE"], artifact=VanillaAE(10)) + >>> registry.save(skeys=["model"], dkeys=["AE"], artifact=VanillaAE(10), + ... artifact_type="pytorch") >>> artifact_data = registry.load(skeys=["model"], dkeys=["AE"], artifact_type="pytorch") """ @@ -100,17 +99,13 @@ def __init__( self.model_stage = model_stage self.cache_registry = cache_registry - @staticmethod - def handler_from_obj(artifact: artifact_t): - if isinstance(artifact, nn.Module): - return mlflow.pytorch - if isinstance(artifact, (BaseThresholdModel, BaseTransformer)): - return mlflow.sklearn - return mlflow.pyfunc - @staticmethod def handler_from_type(artifact_type: str): """Helper method to return the right handler given the artifact type.""" + if not artifact_type: + raise ValueError( + "Artifact Type not provided. Options include: {pytorch, sklearn, pyfunc}" + ) if artifact_type == "pytorch": return mlflow.pytorch if artifact_type == "sklearn": @@ -137,9 +132,9 @@ def load( self, skeys: KEYS, dkeys: KEYS, + artifact_type: Optional[str] = None, latest: bool = True, version: Optional[str] = None, - artifact_type: str = "pytorch", ) -> Optional[ArtifactData]: """Load the artifact from the registry. The artifact is loaded from the cache if available. @@ -149,7 +144,8 @@ def load( dkeys: Dynamic keys latest: Load the latest version of the model (default = True) version: Version of the model to load (default = None) - artifact_type: Type of the artifact to load (default = "pytorch"). + artifact_type: Type of the artifact to load. Options include: pytorch, pyfunc + and sklearn. 
Returns ------- @@ -205,6 +201,7 @@ def save( dkeys: KEYS, artifact: artifact_t, run_id: Optional[str] = None, + artifact_type: Optional[str] = None, **metadata: META_VT, ) -> Optional[ModelVersion]: """Saves the artifact into mlflow registry and updates version. @@ -216,13 +213,15 @@ def save( artifact: primary artifact to be saved run_id: mlflow run id metadata: additional metadata surrounding the artifact that needs to be saved. + artifact_type: Type of the artifact to save. Options include: pytorch, pyfunc + and sklearn. Returns ------- mlflow ModelVersion instance """ model_key = self.construct_key(skeys, dkeys) - handler = self.handler_from_obj(artifact) + handler = self.handler_from_type(artifact_type) try: mlflow.start_run(run_id=run_id) handler.log_model(artifact, "model", registered_model_name=model_key) @@ -241,7 +240,7 @@ def save( @staticmethod def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool: """Returns whether the given artifact is stale or not, i.e. if - more time has elasped since it was last retrained. + more time has elapsed since it was last retrained. 
Args: ---- diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py index dcf41c37..2058a4dc 100644 --- a/tests/registry/test_mlflow_registry.py +++ b/tests/registry/test_mlflow_registry.py @@ -64,7 +64,9 @@ def test_save_model(self): ml = MLflowRegistry(TRACKING_URI) skeys = self.skeys dkeys = self.dkeys - status = ml.save(skeys=skeys, dkeys=dkeys, artifact=self.model, run_id="1234") + status = ml.save( + skeys=skeys, dkeys=dkeys, artifact=self.model, run_id="1234", artifact_type="pytorch" + ) mock_status = "READY" self.assertEqual(mock_status, status.status) @@ -79,7 +81,7 @@ def test_save_model_sklearn(self): ml = MLflowRegistry(TRACKING_URI) skeys = self.skeys dkeys = self.dkeys - status = ml.save(skeys=skeys, dkeys=dkeys, artifact=model) + status = ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="sklearn") mock_status = "READY" self.assertEqual(mock_status, status.status) @@ -96,7 +98,7 @@ def test_load_model_when_pytorch_model_exist1(self): ml = MLflowRegistry(TRACKING_URI) skeys = self.skeys dkeys = self.dkeys - ml.save(skeys=skeys, dkeys=dkeys, artifact=model, **{"lr": 0.01}) + ml.save(skeys=skeys, dkeys=dkeys, artifact=model, **{"lr": 0.01}, artifact_type="pytorch") data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch") self.assertIsNotNone(data.metadata) self.assertIsInstance(data.artifact, VanillaAE) @@ -113,7 +115,7 @@ def test_load_model_when_pytorch_model_exist2(self): ml = MLflowRegistry(TRACKING_URI, models_to_retain=2) skeys = self.skeys dkeys = self.dkeys - ml.save(skeys=skeys, dkeys=dkeys, artifact=model) + ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="pytorch") data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch") self.assertEqual(data.metadata, {}) self.assertIsInstance(data.artifact, VanillaAE) @@ -139,8 +141,8 @@ def test_load_model_when_sklearn_model_exist(self): skeys = self.skeys dkeys = self.dkeys scaler = StandardScaler() - ml.save(skeys=skeys, 
dkeys=dkeys, artifact=scaler) - data = ml.load(skeys=skeys, dkeys=dkeys) + ml.save(skeys=skeys, dkeys=dkeys, artifact=scaler, artifact_type="sklearn") + data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="sklearn") print(data) self.assertIsInstance(data.artifact, StandardScaler) self.assertEqual(data.metadata, {}) @@ -158,8 +160,8 @@ def test_load_model_with_version(self): ml = MLflowRegistry(TRACKING_URI) skeys = self.skeys dkeys = self.dkeys - ml.save(skeys=skeys, dkeys=dkeys, artifact=model) - data = ml.load(skeys=skeys, dkeys=dkeys, version="5", latest=False) + ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="pytorch") + data = ml.load(skeys=skeys, dkeys=dkeys, version="5", latest=False, artifact_type="pytorch") self.assertIsInstance(data.artifact, VanillaAE) self.assertEqual(data.metadata, {}) @@ -175,7 +177,7 @@ def test_staging_model_load_error(self): ml = MLflowRegistry(TRACKING_URI, model_stage=ModelStage.STAGE) skeys = self.skeys dkeys = self.dkeys - ml.load(skeys=skeys, dkeys=dkeys) + ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch") self.assertRaises(ModelVersionError) @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) @@ -188,7 +190,7 @@ def test_both_version_latest_model_with_version(self): skeys = self.skeys dkeys = self.dkeys with self.assertRaises(ValueError): - ml.load(skeys=skeys, dkeys=dkeys, latest=False) + ml.load(skeys=skeys, dkeys=dkeys, latest=False, artifact_type="pytorch") @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @@ -211,7 +213,11 @@ def test_load_model_when_no_model_02(self): fake_dkeys = ["error"] ml = MLflowRegistry(TRACKING_URI) with self.assertLogs(level="ERROR") as log: - ml.load(skeys=fake_skeys, dkeys=fake_dkeys, artifact_type="pytorch") + ml.load( + skeys=fake_skeys, + dkeys=fake_dkeys, + artifact_type="pytorch", + ) 
self.assertTrue(log.output) @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) @@ -237,6 +243,9 @@ def test_no_implementation(self): with self.assertLogs(level="ERROR") as log: ml.load(skeys=fake_skeys, dkeys=fake_dkeys, artifact_type="somerandom") self.assertTrue(log.output) + with self.assertLogs(level="ERROR") as log: + ml.load(skeys=fake_skeys, dkeys=fake_dkeys) + self.assertTrue(log.output) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) @@ -252,7 +261,7 @@ def test_delete_model_when_model_exist(self): ml = MLflowRegistry(TRACKING_URI) skeys = self.skeys dkeys = self.dkeys - ml.save(skeys=skeys, dkeys=dkeys, artifact=model, **{"lr": 0.01}) + ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="pytorch", **{"lr": 0.01}) ml.delete(skeys=skeys, dkeys=dkeys, version="5") with self.assertLogs(level="ERROR") as log: ml.load(skeys=skeys, dkeys=dkeys) @@ -276,7 +285,9 @@ def test_save_failed(self): ml = MLflowRegistry(TRACKING_URI) with self.assertLogs(level="ERROR") as log: - ml.save(skeys=fake_skeys, dkeys=fake_dkeys, artifact=self.model) + ml.save( + skeys=fake_skeys, dkeys=fake_dkeys, artifact=self.model, artifact_type="pytorch" + ) self.assertTrue(log.output) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @@ -290,7 +301,11 @@ def test_load_no_model_found(self): ml = MLflowRegistry(TRACKING_URI) skeys = self.skeys dkeys = self.dkeys - data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch") + data = ml.load( + skeys=skeys, + dkeys=dkeys, + artifact_type="pytorch", + ) self.assertIsNone(data) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @@ -317,7 +332,13 @@ def test_load_other_mlflow_err(self): def test_is_model_stale_true(self): model = self.model ml = MLflowRegistry(TRACKING_URI) - ml.save(skeys=self.skeys, 
dkeys=self.dkeys, artifact=model, **{"lr": 0.01}) + ml.save( + skeys=self.skeys, + dkeys=self.dkeys, + artifact=model, + **{"lr": 0.01}, + artifact_type="pytorch", + ) data = ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch") self.assertTrue(ml.is_artifact_stale(data, 12)) @@ -332,7 +353,13 @@ def test_is_model_stale_true(self): def test_is_model_stale_false(self): model = self.model ml = MLflowRegistry(TRACKING_URI) - ml.save(skeys=self.skeys, dkeys=self.dkeys, artifact=model, **{"lr": 0.01}) + ml.save( + skeys=self.skeys, + dkeys=self.dkeys, + artifact=model, + **{"lr": 0.01}, + artifact_type="pytorch", + ) data = ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch") with freeze_time("2022-05-24 10:30:00"): self.assertFalse(ml.is_artifact_stale(data, 12)) @@ -365,7 +392,13 @@ def test_cache(self): def test_cache_loading(self): cache_registry = LocalLRUCache(ttl=50000) ml = MLflowRegistry(TRACKING_URI, cache_registry=cache_registry) - ml.save(skeys=self.skeys, dkeys=self.dkeys, artifact=self.model, **{"lr": 0.01}) + ml.save( + skeys=self.skeys, + dkeys=self.dkeys, + artifact=self.model, + **{"lr": 0.01}, + artifact_type="pytorch", + ) ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch") key = MLflowRegistry.construct_key(self.skeys, self.dkeys) self.assertIsNotNone(ml._load_from_cache(key))