From aa3b4bb04f3ea114eed58a21965b446252be356e Mon Sep 17 00:00:00 2001 From: Kushal Batra <34571348+s0nicboOm@users.noreply.github.com> Date: Wed, 4 Sep 2024 15:12:46 -0700 Subject: [PATCH] fix: support sklearn BaseEstimators (#407) The transformers and threshold model inherit BaseEstimator class. Fixes: > mlflow saving of models --------- Signed-off-by: Kushal Batra --- numalogic/registry/mlflow_registry.py | 4 ++-- pyproject.toml | 2 +- tests/registry/_mlflow_utils.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/numalogic/registry/mlflow_registry.py b/numalogic/registry/mlflow_registry.py index b9909881..97e39d73 100644 --- a/numalogic/registry/mlflow_registry.py +++ b/numalogic/registry/mlflow_registry.py @@ -21,9 +21,9 @@ from mlflow.exceptions import RestException from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST from mlflow.tracking import MlflowClient -from sklearn.base import BaseEstimator from torch import nn +from numalogic.base import BaseThresholdModel, BaseTransformer from numalogic.registry import ArtifactManager, ArtifactData from numalogic.registry.artifact import ArtifactCache from numalogic.tools.exceptions import ModelVersionError @@ -104,7 +104,7 @@ def __init__( def handler_from_obj(artifact: artifact_t): if isinstance(artifact, nn.Module): return mlflow.pytorch - if isinstance(artifact, BaseEstimator): + if isinstance(artifact, (BaseThresholdModel, BaseTransformer)): return mlflow.sklearn return mlflow.pyfunc diff --git a/pyproject.toml b/pyproject.toml index 057f5e2a..0ecbf55b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "numalogic" -version = "0.12.4" +version = "0.13.0" description = "Collection of operational Machine Learning models and tools." authors = ["Numalogic Developers"] packages = [{ include = "numalogic" }] diff --git a/tests/registry/_mlflow_utils.py b/tests/registry/_mlflow_utils.py index fc75804d..4b61eddc 100644 --- a/tests/registry/_mlflow_utils.py +++ b/tests/registry/_mlflow_utils.py @@ -7,10 +7,11 @@ from mlflow.entities.model_registry import ModelVersion from mlflow.models.model import ModelInfo from mlflow.store.entities import PagedList -from sklearn.ensemble import RandomForestRegressor from sklearn.preprocessing import StandardScaler from torch import tensor +from numalogic.models.threshold import StdDevThreshold + def create_model(): x = torch.linspace(-math.pi, math.pi, 2000) @@ -33,8 +34,7 @@ def create_model(): def model_sklearn(): - params = {"n_estimators": 5, "random_state": 42} - return RandomForestRegressor(**params) + return StdDevThreshold() def mock_log_state_dict(*_, **__):