Skip to content

Commit

Permalink
fix: support sklearn BaseEstimators (#407)
Browse files Browse the repository at this point in the history
The transformers and threshold model inherit BaseEstimator class.

Fixes:
> mlflow saving of models

---------

Signed-off-by: Kushal Batra <[email protected]>
  • Loading branch information
s0nicboOm authored Sep 4, 2024
1 parent eb9ceb1 commit aa3b4bb
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions numalogic/registry/mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
from mlflow.exceptions import RestException
from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST
from mlflow.tracking import MlflowClient
from sklearn.base import BaseEstimator
from torch import nn

from numalogic.base import BaseThresholdModel, BaseTransformer
from numalogic.registry import ArtifactManager, ArtifactData
from numalogic.registry.artifact import ArtifactCache
from numalogic.tools.exceptions import ModelVersionError
Expand Down Expand Up @@ -104,7 +104,7 @@ def __init__(
def handler_from_obj(artifact: artifact_t):
if isinstance(artifact, nn.Module):
return mlflow.pytorch
if isinstance(artifact, BaseEstimator):
if isinstance(artifact, (BaseThresholdModel, BaseTransformer)):
return mlflow.sklearn
return mlflow.pyfunc

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.12.4"
version = "0.13.0"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down
6 changes: 3 additions & 3 deletions tests/registry/_mlflow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from mlflow.entities.model_registry import ModelVersion
from mlflow.models.model import ModelInfo
from mlflow.store.entities import PagedList
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import StandardScaler
from torch import tensor

from numalogic.models.threshold import StdDevThreshold


def create_model():
x = torch.linspace(-math.pi, math.pi, 2000)
Expand All @@ -33,8 +34,7 @@ def create_model():


def model_sklearn():
params = {"n_estimators": 5, "random_state": 42}
return RandomForestRegressor(**params)
return StdDevThreshold()


def mock_log_state_dict(*_, **__):
Expand Down

0 comments on commit aa3b4bb

Please sign in to comment.