Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement: multiple save and load for mlflow registry #416

Merged
merged 7 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ target/
#mlflow
/.mlruns
*.db
mlruns/
mlartifacts/

# Jupyter Notebook
.ipynb_checkpoints
Expand Down Expand Up @@ -169,4 +171,7 @@ cython_debug/
# Mac related
*.DS_Store

# vscode
.vscode/

.python-version
126 changes: 121 additions & 5 deletions numalogic/registry/mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@
from enum import Enum
from typing import Optional, Any

import mlflow.pyfunc
import mlflow.pytorch
import mlflow.sklearn
import mlflow
from mlflow.entities.model_registry import ModelVersion
from mlflow.exceptions import RestException
from mlflow.exceptions import RestException, MlflowException
from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST
from mlflow.tracking import MlflowClient

Expand Down Expand Up @@ -187,6 +185,43 @@ def load(
self._save_in_cache(model_key, artifact_data)
return artifact_data

def load_multiple(
    self,
    skeys: KEYS,
    dkeys: KEYS,
) -> Optional[ArtifactData]:
    """
    Load a group of artifacts that was stored as a single pyfunc model.

    The model is fetched through ``load`` with ``artifact_type="pyfunc"``,
    unwrapped to its underlying python model, and that model's
    ``dict_artifacts`` attribute is returned as the artifact.

    Args:
    ----
        skeys (KEYS): Static key fields as a list or tuple of strings.
        dkeys (KEYS): Dynamic key fields as a list or tuple of strings.

    Returns
    -------
        Optional[ArtifactData]: ArtifactData whose ``artifact`` is the dictionary
        of artifacts, with ``metadata`` and ``extras`` copied from the loaded
        model. None if ``load`` returned None, or if unwrapping failed with
        anything other than an MlflowException.

    Raises
    ------
        TypeError: If unwrapping the loaded model raises an MlflowException.
    """
    pyfunc_data = self.load(skeys=skeys, dkeys=dkeys, artifact_type="pyfunc")
    if pyfunc_data is None:
        return None

    try:
        composite = pyfunc_data.artifact.unwrap_python_model()
    except MlflowException as err:
        # The only failure that is surfaced to the caller.
        raise TypeError("The loaded model is not a valid pyfunc Python model.") from err
    except AttributeError:
        _LOGGER.exception("The loaded model does not have an unwrap_python_model method")
        return None
    except Exception:
        _LOGGER.exception("Unexpected error occurred while unwrapping python model.")
        return None

    # Not guarded: an unwrapped model without ``dict_artifacts`` raises AttributeError here.
    return ArtifactData(
        artifact=composite.dict_artifacts,
        metadata=pyfunc_data.metadata,
        extras=pyfunc_data.extras,
    )

@staticmethod
def __log_mlflow_err(mlflow_err: RestException, model_key: str) -> None:
if ErrorCode.Value(mlflow_err.error_code) == RESOURCE_DOES_NOT_EXIST:
Expand Down Expand Up @@ -225,7 +260,10 @@ def save(
handler = self.handler_from_type(artifact_type)
try:
mlflow.start_run(run_id=run_id)
handler.log_model(artifact, "model", registered_model_name=model_key)
if artifact_type == "pyfunc":
handler.log_model("model", python_model=artifact, registered_model_name=model_key)
else:
handler.log_model(artifact, "model", registered_model_name=model_key)
if metadata:
mlflow.log_params(metadata)
model_version = self.transition_stage(skeys=skeys, dkeys=dkeys)
Expand All @@ -238,6 +276,42 @@ def save(
finally:
mlflow.end_run()

def save_multiple(
    self,
    skeys: KEYS,
    dkeys: KEYS,
    dict_artifacts: dict[str, artifact_t],
    **metadata: META_VT,
) -> Optional[ModelVersion]:
    """
    Save several artifacts into the mlflow registry as one pyfunc model.

    All artifacts are wrapped in a single CompositeModel, which is then handed
    to ``save`` with ``artifact_type="pyfunc"``. The metadata is passed both to
    the CompositeModel and to ``save``.

    Args:
    ----
        skeys (KEYS): Static key fields as a list or tuple of strings.
        dkeys (KEYS): Dynamic key fields as a list or tuple of strings.
        dict_artifacts (dict[str, artifact_t]): Dictionary of artifacts to save.
        **metadata (META_VT): Additional metadata to be saved with the artifacts.

    Returns
    -------
        Optional[ModelVersion]: Whatever ``save`` returns for the composite model.

    """
    artifact_count = len(dict_artifacts)
    if artifact_count == 1:
        # Advisory only: the single artifact is still wrapped and saved below.
        _LOGGER.warning(
            "Only one artifact present in dict_artifacts. Saving directly is recommended."
        )

    composite = CompositeModel(skeys=skeys, dict_artifacts=dict_artifacts, **metadata)
    save_kwargs = {
        "skeys": skeys,
        "dkeys": dkeys,
        "artifact": composite,
        "artifact_type": "pyfunc",
    }
    return self.save(**save_kwargs, **metadata)

@staticmethod
def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool:
"""Returns whether the given artifact is stale or not, i.e. if
Expand Down Expand Up @@ -338,3 +412,45 @@ def __load_artifacts(
version_info.version,
)
return model, metadata


class CompositeModel(mlflow.pyfunc.PythonModel):
    """A pyfunc python model that only acts as a container for several artifacts.

    Subclassing `mlflow.pyfunc.PythonModel` lets a dictionary of artifacts be
    saved to and loaded from the MLflow registry as one pyfunc model. It is not
    meant to be used for inference; `predict` always raises.

    Args:
        skeys (KEYS): The static keys of the artifacts.
        dict_artifacts (dict[str, artifact_t]): A dictionary mapping a name to
            each artifact.
        **metadata (META_VT): Additional metadata associated with the artifacts.

    Methods
    -------
        predict: Not implemented for our use case.

    Attributes
    ----------
        skeys (KEYS): The static keys of the artifacts.
        dict_artifacts (dict[str, artifact_t]): A dictionary mapping a name to
            each artifact.
        metadata (dict[str, META_VT]): The metadata keyword arguments, kept as a dict.
    """

    # NOTE(review): presumably has no memory effect if the PythonModel base class
    # defines no __slots__ of its own (instances would still get a __dict__) — confirm.
    __slots__ = ("skeys", "dict_artifacts", "metadata")

    def __init__(self, skeys: KEYS, dict_artifacts: dict[str, artifact_t], **metadata: META_VT):
        self.skeys = skeys
        self.dict_artifacts = dict_artifacts
        # ``metadata`` is the collected keyword arguments, i.e. a plain dict.
        self.metadata = metadata

    def predict(self, context, model_input, params: Optional[dict[str, Any]] = None):
        """
        Always raise; CompositeModel is a storage container, not a predictor.

        The arguments are accepted but never used.

        Raises
        ------
            NotImplementedError: On every call.
        """
        raise NotImplementedError("The predict method is not implemented for CompositeModel.")
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.13.2"
version = "0.13.3"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down
124 changes: 124 additions & 0 deletions tests/registry/_mlflow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
from mlflow.store.entities import PagedList
from sklearn.preprocessing import StandardScaler
from torch import tensor
from mlflow.models import Model

from numalogic.models.autoencoder.variants.vanilla import VanillaAE
from numalogic.models.threshold import StdDevThreshold
from numalogic.registry.mlflow_registry import CompositeModel


def create_model():
Expand Down Expand Up @@ -135,6 +138,103 @@ def mock_log_model_sklearn(*_, **__):
)


def mock_log_model_pyfunc(*_, **__):
    """Return a canned ``ModelInfo`` describing a logged pyfunc model.

    All positional and keyword arguments are accepted and ignored.
    """
    pyfunc_flavor = {"model_data": "data", "pyfunc_version": "1.11.0", "code": None}
    python_function_flavor = {
        "pickle_module_name": "mlflow.pyfunc.pickle_module",
        "loader_module": "mlflow.pyfunc",
        "python_version": "3.8.5",
        "data": "data",
        "env": "conda.yaml",
    }
    return ModelInfo(
        artifact_path="model",
        flavors={"pyfunc": pyfunc_flavor, "python_function": python_function_flavor},
        model_uri="runs:/a7c0b376530b40d7b23e6ce2081c899c/model",
        model_uuid="a7c0b376530b40d7b23e6ce2081c899c",
        run_id="a7c0b376530b40d7b23e6ce2081c899c",
        saved_input_example_info=None,
        signature_dict=None,
        utc_time_created="2022-05-23 22:35:59.557372",
        mlflow_version="2.0.1",
        signature=None,
    )


def mock_load_model_pyfunc(*_, **__):
    """Return a canned ``PyFuncModel`` whose implementation holds a CompositeModel.

    All positional and keyword arguments are accepted and ignored. The
    implementation object is a ``TestObject`` exposing the CompositeModel as its
    ``python_model`` attribute.
    """
    python_function_flavor = {
        "cloudpickle_version": "3.0.0",
        "code": None,
        "env": {"conda": "conda.yaml", "virtualenv": "python_env.yaml"},
        "loader_module": "mlflow.pyfunc.model",
        "python_model": "python_model.pkl",
        "python_version": "3.10.14",
        "streamable": False,
    }
    model_meta = Model(
        artifact_path="model",
        flavors={"python_function": python_function_flavor},
        model_size_bytes=8912,
        model_uuid="ae27ecc166c94c01a4f4dccaf84ca5dc",
        run_id="7e85a3fa46d44e668c840f3dddc909c3",
        utc_time_created="2024-09-18 17:12:41.501209",
        mlflow_version="2.16.0",
    )
    # NOTE(review): the key "precrocessing" looks like a typo for "preprocessing";
    # kept as-is in case a test depends on the exact key — confirm.
    composite = CompositeModel(
        skeys=["error"],
        dict_artifacts={
            "inference": VanillaAE(10),
            "precrocessing": StandardScaler(),
            "threshold": StdDevThreshold(),
        },
        learning_rate=0.01,
    )
    return mlflow.pyfunc.PyFuncModel(
        model_meta=model_meta,
        model_impl=TestObject(python_model=composite),
    )


def mock_load_model_pyfunc_type_error(*_, **__):
    """Return a canned ``PyFuncModel`` backed by a pytorch wrapper instead of a python model.

    All positional and keyword arguments are accepted and ignored.
    Presumably used to exercise the failure path when unwrapping the python
    model (per the function name) — confirm against the tests that use it.
    """
    python_function_flavor = {
        "cloudpickle_version": "3.0.0",
        "code": None,
        "env": {"conda": "conda.yaml", "virtualenv": "python_env.yaml"},
        "loader_module": "mlflow.pytorch.model",
        "python_model": "python_model.pkl",
        "python_version": "3.10.14",
        "streamable": False,
    }
    model_meta = Model(
        artifact_path="model",
        flavors={"python_function": python_function_flavor},
        model_size_bytes=8912,
        model_uuid="ae27ecc166c94c01a4f4dccaf84ca5dc",
        run_id="7e85a3fa46d44e668c840f3dddc909c3",
        utc_time_created="2024-09-18 17:12:41.501209",
        mlflow_version="2.16.0",
    )
    torch_impl = mlflow.pytorch._PyTorchWrapper(VanillaAE(10), device="cpu")
    return mlflow.pyfunc.PyFuncModel(model_meta=model_meta, model_impl=torch_impl)


def mock_transition_stage(*_, **__):
return ModelVersion(
creation_timestamp=1653402941169,
Expand Down Expand Up @@ -303,6 +403,25 @@ def return_sklearn_rundata():
)


def return_pyfunc_rundata():
    """Return a canned ``Run`` in RUNNING state with one ``learning_rate`` param."""
    info = RunInfo(
        artifact_uri="mlflow-artifacts:/0/a7c0b376530b40d7b23e6ce2081c899c/artifacts/model",
        end_time=None,
        experiment_id="0",
        lifecycle_stage="active",
        run_id="a7c0b376530b40d7b23e6ce2081c899c",
        run_uuid="a7c0b376530b40d7b23e6ce2081c899c",
        start_time=1658788772612,
        status="RUNNING",
        user_id="lol",
    )
    # Same param name and value as the metadata used by mock_load_model_pyfunc.
    logged_params = [mlflow.entities.Param("learning_rate", "0.01")]
    data = RunData(metrics={}, tags={}, params=logged_params)
    return Run(run_info=info, run_data=data)


def return_pytorch_rundata_dict():
return Run(
run_info=RunInfo(
Expand All @@ -318,3 +437,8 @@ def return_pytorch_rundata_dict():
),
run_data=RunData(metrics={}, tags={}, params=[mlflow.entities.Param("lr", "0.001")]),
)


class TestObject(mlflow.pyfunc.PythonModel):
    """Minimal model implementation that exposes a wrapped model as ``python_model``.

    Used as the ``model_impl`` of a ``PyFuncModel`` in the mocks of this module.

    Args:
        python_model: The object to expose through the ``python_model`` attribute.
    """

    # The class name matches pytest's default ``Test*`` pattern. Opt out explicitly
    # so pytest never tries to collect this helper (and warns about its __init__)
    # if the name is imported into a test module.
    __test__ = False

    def __init__(self, python_model):
        self.python_model = python_model
Loading
Loading