diff --git a/python/interpret-core/dev-requirements.txt b/python/interpret-core/dev-requirements.txt
index c65ba3cca..b036427de 100644
--- a/python/interpret-core/dev-requirements.txt
+++ b/python/interpret-core/dev-requirements.txt
@@ -1,2 +1,3 @@
 # Required dependencies
 .[required,debug,notebook,plotly,lime,sensitivity,shap,ebm,linear,decisiontree,treeinterpreter,dash,skoperules,testing]
+mlflow
diff --git a/python/interpret-core/interpret/glassbox/mlflow/.gitignore b/python/interpret-core/interpret/glassbox/mlflow/.gitignore
new file mode 100644
index 000000000..e6ee07331
--- /dev/null
+++ b/python/interpret-core/interpret/glassbox/mlflow/.gitignore
@@ -0,0 +1 @@
+mlruns
diff --git a/python/interpret-core/interpret/glassbox/mlflow/__init__.py b/python/interpret-core/interpret/glassbox/mlflow/__init__.py
new file mode 100644
index 000000000..86eb871eb
--- /dev/null
+++ b/python/interpret-core/interpret/glassbox/mlflow/__init__.py
@@ -0,0 +1,77 @@
+import json
+import os
+import yaml
+
+from tempfile import TemporaryDirectory
+
+import interpret
+
+
+def load_model(*args, **kwargs):
+    # Thin pass-through to mlflow.pyfunc.load_model; imported lazily so
+    # interpret does not hard-require mlflow.
+    import mlflow.pyfunc
+    return mlflow.pyfunc.load_model(*args, **kwargs)
+
+
+def _sanitize_explanation_data(data):  # TODO Explanations should have a to_json()
+    # Recursively convert explanation payloads into JSON-serializable
+    # python primitives (dicts/lists/scalars).
+    if isinstance(data, dict):
+        for key, val in data.items():
+            data[key] = _sanitize_explanation_data(data[key])
+        return data
+    elif isinstance(data, list):
+        return [_sanitize_explanation_data(x) for x in data]  # fixed: call, was `[x]` subscript (TypeError)
+    else:
+        # numpy type conversion to python https://stackoverflow.com/questions/9452775 primitive
+        return data.item() if hasattr(data, "item") else data
+
+
+def _load_pyfunc(path):
+    # mlflow pyfunc loader hook: unpickle the model saved by _save_model.
+    import cloudpickle as pickle
+    with open(os.path.join(path, "model.pkl"), "rb") as f:
+        return pickle.load(f)
+
+
+def _save_model(model, output_path):
+    # Pickle the glassbox model and write its global explanation as JSON
+    # alongside it, so the artifact is self-describing.
+    import cloudpickle as pickle
+    if not os.path.exists(output_path):
+        os.mkdir(output_path)
+    with open(os.path.join(output_path, "model.pkl"), "wb") as stream:
+        pickle.dump(model, stream)
+    try:
+        with open(os.path.join(output_path, "global_explanation.json"), "w") as stream:
+            data = model.explain_global().data(-1)["mli"]
+            if isinstance(data, list):
+                data = data[0]
+            if "global" not in data["explanation_type"]:
+                raise Exception("Invalid explanation, not global")
+            for key in data:
+                if isinstance(data[key], list):
+                    data[key] = [float(x) for x in data[key]]
+            json.dump(data, stream)
+    except ValueError as e:
+        raise Exception("Unsupported glassbox model type {}. Failed with error {}.".format(type(model), e))
+
+
+def log_model(path, model):
+    # Log a glassbox model to the active mlflow run as a pyfunc artifact,
+    # pinning interpret and cloudpickle versions in the conda env.
+    try:
+        import mlflow.pyfunc
+    except ImportError as e:
+        raise Exception("Could not log_model to mlflow. Missing mlflow dependency, pip install mlflow to resolve the error: {}.".format(e))
+    import cloudpickle as pickle
+
+    with TemporaryDirectory() as tempdir:
+        _save_model(model, tempdir)
+
+        conda_env = {"name": "mlflow-env",
+                     "channels": ["defaults"],
+                     "dependencies": ["pip",
+                                      {"pip": [
+                                          "interpret=={}".format(interpret.version.__version__),
+                                          "cloudpickle=={}".format(pickle.__version__)]
+                                      }
+                                     ]
+                     }
+        conda_path = os.path.join(tempdir, "conda.yaml")  # TODO Open issue and bug fix for dict support
+        with open(conda_path, "w") as stream:
+            yaml.dump(conda_env, stream)
+        mlflow.pyfunc.log_model(path, loader_module="interpret.glassbox.mlflow", data_path=tempdir, conda_env=conda_path)
diff --git a/python/interpret-core/interpret/glassbox/mlflow/test/test_save_load_model.py b/python/interpret-core/interpret/glassbox/mlflow/test/test_save_load_model.py
new file mode 100644
index 000000000..f2d84548c
--- /dev/null
+++ b/python/interpret-core/interpret/glassbox/mlflow/test/test_save_load_model.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2019 Microsoft Corporation
+# Distributed under the MIT software license
+
+import json
+import os
+
+import pytest
+
+from sklearn.datasets import load_breast_cancer, load_boston
+from sklearn.linear_model import LogisticRegression as SKLogistic
+from sklearn.linear_model import Lasso as SKLinear
+
+from interpret.glassbox.linear import LogisticRegression, LinearRegression
+from interpret.glassbox.mlflow import load_model, log_model
+
+
+@pytest.fixture()
+def glassbox_model():
+    boston = load_boston()
+    return LinearRegression(feature_names=boston.feature_names, random_state=1)
+
+
+@pytest.fixture()
+def model():
+    return SKLinear(random_state=1)
+
+
+def test_linear_regression_save_load(glassbox_model, model):
+    boston = load_boston()
+    X, y = boston.data, boston.target
+
+    model.fit(X, y)
+    glassbox_model.fit(X, y)
+
+    save_location = "save_location"
+    log_model(save_location, glassbox_model)
+
+
+    import mlflow
+    glassbox_model_loaded = load_model("runs:/{}/{}".format(mlflow.active_run().info.run_id, save_location))
+
+    name = "name"
+    explanation_glassbox_data = glassbox_model.explain_global(name).data(-1)["mli"]
+    explanation_glassbox_data_loaded = glassbox_model_loaded.explain_global(name).data(-1)["mli"]
+    assert explanation_glassbox_data == explanation_glassbox_data_loaded