Skip to content

Commit

Permalink
fix: support all sklearn classes (#408)
Browse files Browse the repository at this point in the history
1. Support artifact-type in mlflow registry from user.
2. Update docs

---------

Signed-off-by: Kushal Batra <[email protected]>
  • Loading branch information
s0nicboOm authored Sep 5, 2024
1 parent aa3b4bb commit c4d3bcc
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 39 deletions.
6 changes: 3 additions & 3 deletions docs/ml-flow.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Numalogic provides `MLflowRegistry`, to save and load models to/from MLflow.

Here, `tracking_uri` is the uri where mlflow server is running. The `static_keys` and `dynamic_keys` are used to form a unique key for the model.

The `artifact` would be the model or transformer object that needs to be saved.
The `artifact` would be the model or transformer object that needs to be saved. Artifact saving also takes in 'artifact_type' which is the type of the artifact being saved. Currently, 'pytorch', 'sklearn' and 'pyfunc' is supported.
A dictionary of metadata can also be saved along with the artifact.
```python
from numalogic.registry import MLflowRegistry
Expand All @@ -37,13 +37,13 @@ dynamic_keys = ["vanilla", "seq10"]

registry = MLflowRegistry(tracking_uri="http://0.0.0.0:5000")
registry.save(
skeys=static_keys, dkeys=dynamic_keys, artifact=model, seq_len=10, lr=0.001
skeys=static_keys, dkeys=dynamic_keys, artifact_type='pytorch', artifact=model, seq_len=10, lr=0.001
)
```

### Model loading

Once, the models are save to MLflow, the `load` function of `MLflowRegistry` can be used to load the model.
Once, the models are save to MLflow, the `load` function of `MLflowRegistry` can be used to load the model. Like how the artifacts were saved with 'artifact_type', the same type shall be passed to the `load` function as well.

```python
from numalogic.registry import MLflowRegistry
Expand Down
13 changes: 9 additions & 4 deletions examples/multi_udf/src/udf/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,15 @@ def __init__(self):
self.model_key = "ae::model"

def _save_artifact(
self, model, skeys: list[str], dkeys: list[str], _: Optional[TimeseriesTrainer] = None
self,
model,
skeys: list[str],
dkeys: list[str],
artifact_type: str,
_: Optional[TimeseriesTrainer] = None,
) -> None:
"""Saves the model in the registry."""
self.registry.save(skeys=skeys, dkeys=dkeys, artifact=model)
self.registry.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type=artifact_type)

@staticmethod
def _fit_preprocess(data: pd.DataFrame) -> npt.NDArray[float]:
Expand Down Expand Up @@ -93,8 +98,8 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
thresh_clf = self._fit_threshold(train_reconerr.numpy())

# Save to registry
self._save_artifact(model, ["ae"], ["model"], trainer)
self._save_artifact(thresh_clf, ["thresh_clf"], ["model"])
self._save_artifact(model, ["ae"], ["model"], "pytorch", trainer)
self._save_artifact(thresh_clf, ["thresh_clf"], ["model"], artifact_type="sklearn")
LOGGER.info("%s - Model Saving complete", payload.uuid)

# Train is the last vertex in the graph
Expand Down
29 changes: 14 additions & 15 deletions numalogic/registry/mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
from mlflow.exceptions import RestException
from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST
from mlflow.tracking import MlflowClient
from torch import nn

from numalogic.base import BaseThresholdModel, BaseTransformer
from numalogic.registry import ArtifactManager, ArtifactData
from numalogic.registry.artifact import ArtifactCache
from numalogic.tools.exceptions import ModelVersionError
Expand Down Expand Up @@ -65,7 +63,8 @@ class MLflowRegistry(ArtifactManager):
>>> data = [[0, 0], [0, 0], [1, 1], [1, 1]]
>>> scaler = StandardScaler.fit(data)
>>> registry = MLflowRegistry(tracking_uri="http://0.0.0.0:8080")
>>> registry.save(skeys=["model"], dkeys=["AE"], artifact=VanillaAE(10))
>>> registry.save(skeys=["model"], dkeys=["AE"], artifact=VanillaAE(10),
>>> artifact_type="pytorch")
>>> artifact_data = registry.load(skeys=["model"], dkeys=["AE"], artifact_type="pytorch")
"""

Expand Down Expand Up @@ -100,17 +99,13 @@ def __init__(
self.model_stage = model_stage
self.cache_registry = cache_registry

@staticmethod
def handler_from_obj(artifact: artifact_t):
if isinstance(artifact, nn.Module):
return mlflow.pytorch
if isinstance(artifact, (BaseThresholdModel, BaseTransformer)):
return mlflow.sklearn
return mlflow.pyfunc

@staticmethod
def handler_from_type(artifact_type: str):
"""Helper method to return the right handler given the artifact type."""
if not artifact_type:
raise ValueError(
"Artifact Type not provided. Options include: {pytorch, sklearn, pyfunc}"
)
if artifact_type == "pytorch":
return mlflow.pytorch
if artifact_type == "sklearn":
Expand All @@ -137,9 +132,9 @@ def load(
self,
skeys: KEYS,
dkeys: KEYS,
artifact_type: Optional[str] = None,
latest: bool = True,
version: Optional[str] = None,
artifact_type: str = "pytorch",
) -> Optional[ArtifactData]:
"""Load the artifact from the registry. The artifact is loaded from the cache if available.
Expand All @@ -149,7 +144,8 @@ def load(
dkeys: Dynamic keys
latest: Load the latest version of the model (default = True)
version: Version of the model to load (default = None)
artifact_type: Type of the artifact to load (default = "pytorch").
artifact_type: Type of the artifact to load. Options include: pytorch, pyfunc
and sklearn.
Returns
-------
Expand Down Expand Up @@ -205,6 +201,7 @@ def save(
dkeys: KEYS,
artifact: artifact_t,
run_id: Optional[str] = None,
artifact_type: Optional[str] = None,
**metadata: META_VT,
) -> Optional[ModelVersion]:
"""Saves the artifact into mlflow registry and updates version.
Expand All @@ -216,13 +213,15 @@ def save(
artifact: primary artifact to be saved
run_id: mlflow run id
metadata: additional metadata surrounding the artifact that needs to be saved.
artifact_type: Type of the artifact to save. Options include: pytorch, pyfunc
and sklearn.
Returns
-------
mlflow ModelVersion instance
"""
model_key = self.construct_key(skeys, dkeys)
handler = self.handler_from_obj(artifact)
handler = self.handler_from_type(artifact_type)
try:
mlflow.start_run(run_id=run_id)
handler.log_model(artifact, "model", registered_model_name=model_key)
Expand All @@ -241,7 +240,7 @@ def save(
@staticmethod
def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool:
"""Returns whether the given artifact is stale or not, i.e. if
more time has elasped since it was last retrained.
more time has elapsed since it was last retrained.
Args:
----
Expand Down
67 changes: 50 additions & 17 deletions tests/registry/test_mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def test_save_model(self):
ml = MLflowRegistry(TRACKING_URI)
skeys = self.skeys
dkeys = self.dkeys
status = ml.save(skeys=skeys, dkeys=dkeys, artifact=self.model, run_id="1234")
status = ml.save(
skeys=skeys, dkeys=dkeys, artifact=self.model, run_id="1234", artifact_type="pytorch"
)
mock_status = "READY"
self.assertEqual(mock_status, status.status)

Expand All @@ -79,7 +81,7 @@ def test_save_model_sklearn(self):
ml = MLflowRegistry(TRACKING_URI)
skeys = self.skeys
dkeys = self.dkeys
status = ml.save(skeys=skeys, dkeys=dkeys, artifact=model)
status = ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="sklearn")
mock_status = "READY"
self.assertEqual(mock_status, status.status)

Expand All @@ -96,7 +98,7 @@ def test_load_model_when_pytorch_model_exist1(self):
ml = MLflowRegistry(TRACKING_URI)
skeys = self.skeys
dkeys = self.dkeys
ml.save(skeys=skeys, dkeys=dkeys, artifact=model, **{"lr": 0.01})
ml.save(skeys=skeys, dkeys=dkeys, artifact=model, **{"lr": 0.01}, artifact_type="pytorch")
data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")
self.assertIsNotNone(data.metadata)
self.assertIsInstance(data.artifact, VanillaAE)
Expand All @@ -113,7 +115,7 @@ def test_load_model_when_pytorch_model_exist2(self):
ml = MLflowRegistry(TRACKING_URI, models_to_retain=2)
skeys = self.skeys
dkeys = self.dkeys
ml.save(skeys=skeys, dkeys=dkeys, artifact=model)
ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="pytorch")
data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")
self.assertEqual(data.metadata, {})
self.assertIsInstance(data.artifact, VanillaAE)
Expand All @@ -139,8 +141,8 @@ def test_load_model_when_sklearn_model_exist(self):
skeys = self.skeys
dkeys = self.dkeys
scaler = StandardScaler()
ml.save(skeys=skeys, dkeys=dkeys, artifact=scaler)
data = ml.load(skeys=skeys, dkeys=dkeys)
ml.save(skeys=skeys, dkeys=dkeys, artifact=scaler, artifact_type="sklearn")
data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="sklearn")
print(data)
self.assertIsInstance(data.artifact, StandardScaler)
self.assertEqual(data.metadata, {})
Expand All @@ -158,8 +160,8 @@ def test_load_model_with_version(self):
ml = MLflowRegistry(TRACKING_URI)
skeys = self.skeys
dkeys = self.dkeys
ml.save(skeys=skeys, dkeys=dkeys, artifact=model)
data = ml.load(skeys=skeys, dkeys=dkeys, version="5", latest=False)
ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="pytorch")
data = ml.load(skeys=skeys, dkeys=dkeys, version="5", latest=False, artifact_type="pytorch")
self.assertIsInstance(data.artifact, VanillaAE)
self.assertEqual(data.metadata, {})

Expand All @@ -175,7 +177,7 @@ def test_staging_model_load_error(self):
ml = MLflowRegistry(TRACKING_URI, model_stage=ModelStage.STAGE)
skeys = self.skeys
dkeys = self.dkeys
ml.load(skeys=skeys, dkeys=dkeys)
ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")
self.assertRaises(ModelVersionError)

@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
Expand All @@ -188,7 +190,7 @@ def test_both_version_latest_model_with_version(self):
skeys = self.skeys
dkeys = self.dkeys
with self.assertRaises(ValueError):
ml.load(skeys=skeys, dkeys=dkeys, latest=False)
ml.load(skeys=skeys, dkeys=dkeys, latest=False, artifact_type="pytorch")

@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
Expand All @@ -211,7 +213,11 @@ def test_load_model_when_no_model_02(self):
fake_dkeys = ["error"]
ml = MLflowRegistry(TRACKING_URI)
with self.assertLogs(level="ERROR") as log:
ml.load(skeys=fake_skeys, dkeys=fake_dkeys, artifact_type="pytorch")
ml.load(
skeys=fake_skeys,
dkeys=fake_dkeys,
artifact_type="pytorch",
)
self.assertTrue(log.output)

@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
Expand All @@ -237,6 +243,9 @@ def test_no_implementation(self):
with self.assertLogs(level="ERROR") as log:
ml.load(skeys=fake_skeys, dkeys=fake_dkeys, artifact_type="somerandom")
self.assertTrue(log.output)
with self.assertLogs(level="ERROR") as log:
ml.load(skeys=fake_skeys, dkeys=fake_dkeys)
self.assertTrue(log.output)

@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
Expand All @@ -252,7 +261,7 @@ def test_delete_model_when_model_exist(self):
ml = MLflowRegistry(TRACKING_URI)
skeys = self.skeys
dkeys = self.dkeys
ml.save(skeys=skeys, dkeys=dkeys, artifact=model, **{"lr": 0.01})
ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="pytorch", **{"lr": 0.01})
ml.delete(skeys=skeys, dkeys=dkeys, version="5")
with self.assertLogs(level="ERROR") as log:
ml.load(skeys=skeys, dkeys=dkeys)
Expand All @@ -276,7 +285,9 @@ def test_save_failed(self):

ml = MLflowRegistry(TRACKING_URI)
with self.assertLogs(level="ERROR") as log:
ml.save(skeys=fake_skeys, dkeys=fake_dkeys, artifact=self.model)
ml.save(
skeys=fake_skeys, dkeys=fake_dkeys, artifact=self.model, artifact_type="pytorch"
)
self.assertTrue(log.output)

@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
Expand All @@ -290,7 +301,11 @@ def test_load_no_model_found(self):
ml = MLflowRegistry(TRACKING_URI)
skeys = self.skeys
dkeys = self.dkeys
data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")
data = ml.load(
skeys=skeys,
dkeys=dkeys,
artifact_type="pytorch",
)
self.assertIsNone(data)

@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
Expand All @@ -317,7 +332,13 @@ def test_load_other_mlflow_err(self):
def test_is_model_stale_true(self):
model = self.model
ml = MLflowRegistry(TRACKING_URI)
ml.save(skeys=self.skeys, dkeys=self.dkeys, artifact=model, **{"lr": 0.01})
ml.save(
skeys=self.skeys,
dkeys=self.dkeys,
artifact=model,
**{"lr": 0.01},
artifact_type="pytorch",
)
data = ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch")
self.assertTrue(ml.is_artifact_stale(data, 12))

Expand All @@ -332,7 +353,13 @@ def test_is_model_stale_true(self):
def test_is_model_stale_false(self):
model = self.model
ml = MLflowRegistry(TRACKING_URI)
ml.save(skeys=self.skeys, dkeys=self.dkeys, artifact=model, **{"lr": 0.01})
ml.save(
skeys=self.skeys,
dkeys=self.dkeys,
artifact=model,
**{"lr": 0.01},
artifact_type="pytorch",
)
data = ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch")
with freeze_time("2022-05-24 10:30:00"):
self.assertFalse(ml.is_artifact_stale(data, 12))
Expand Down Expand Up @@ -365,7 +392,13 @@ def test_cache(self):
def test_cache_loading(self):
cache_registry = LocalLRUCache(ttl=50000)
ml = MLflowRegistry(TRACKING_URI, cache_registry=cache_registry)
ml.save(skeys=self.skeys, dkeys=self.dkeys, artifact=self.model, **{"lr": 0.01})
ml.save(
skeys=self.skeys,
dkeys=self.dkeys,
artifact=self.model,
**{"lr": 0.01},
artifact_type="pytorch",
)
ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch")
key = MLflowRegistry.construct_key(self.skeys, self.dkeys)
self.assertIsNotNone(ml._load_from_cache(key))
Expand Down

0 comments on commit c4d3bcc

Please sign in to comment.