diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4d8febf2..fd2525b6 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -78,6 +78,7 @@ jobs: python-version: 3.8 - name: Install dependencies run: | + python -m pip install --upgrade pip python -m pip install ".[dev]" python -m pip install --upgrade git+https://github.com/rstudio/vetiver-python@${{ github.sha }} - name: run Docker diff --git a/docs/source/advancedusage/custom_handler.md b/docs/source/advancedusage/custom_handler.md index 333bda48..9009f403 100644 --- a/docs/source/advancedusage/custom_handler.md +++ b/docs/source/advancedusage/custom_handler.md @@ -12,6 +12,7 @@ class CustomHandler(BaseHandler): super().__init__(model, ptype_data) model_type = staticmethod(lambda: newmodeltype) + pip_name = "scikit-learn" # pkg name on pip, used for tracking pkg versions def handler_predict(self, input_data, check_ptype: bool): """ diff --git a/vetiver/attach_pkgs.py b/vetiver/attach_pkgs.py index bf845b2a..14341d52 100644 --- a/vetiver/attach_pkgs.py +++ b/vetiver/attach_pkgs.py @@ -1,6 +1,7 @@ import tempfile import os -from vetiver import VetiverModel +from .vetiver_model import VetiverModel +from .meta import VetiverMeta def load_pkgs(model: VetiverModel = None, packages: list = None, path=""): @@ -19,8 +20,12 @@ def load_pkgs(model: VetiverModel = None, packages: list = None, path=""): required_pkgs = ["vetiver"] if packages: required_pkgs = list(set(required_pkgs + packages)) - if model.metadata.get("required_pkgs"): - required_pkgs = list(set(required_pkgs + model.metadata.get("required_pkgs"))) + + if isinstance(model.metadata, dict): + model.metadata = VetiverMeta.from_dict(model.metadata) + + if model.metadata.required_pkgs: + required_pkgs = list(set(required_pkgs + model.metadata.required_pkgs)) tmp = tempfile.NamedTemporaryFile(suffix=".in", delete=False) tmp.close() diff --git a/vetiver/handlers/base.py b/vetiver/handlers/base.py index 2ceda71d..82772977 100644 --- a/vetiver/handlers/base.py +++ b/vetiver/handlers/base.py @@ -3,7 +3,7 @@ from contextlib import suppress from ..prototype import vetiver_create_prototype -from ..meta import _model_meta +from ..meta import VetiverMeta class InvalidModelError(Exception): @@ -43,7 +43,7 @@ def create_handler(model, prototype_data): >>> model = vetiver.mock.get_mock_model() >>> handler = vetiver.create_handler(model, X) >>> handler.describe() - "Scikit-learn model" + 'A scikit-learn DummyRegressor model' """ raise InvalidModelError( @@ -79,19 +79,20 @@ def __init__(self, model, prototype_data): def describe(self): """Create description for model""" - desc = f"{self.model.__class__} model" + + pip_name = self.pip_name if hasattr(self, "pip_name") else "" + obj_name = type(self.model).__qualname__ + + desc = f"A {pip_name} {obj_name} model" + return desc - def create_meta( - user: list = None, - version: str = None, - url: str = None, - required_pkgs: list = [], - ): + def create_meta(self, metadata): """Create metadata for a model""" - meta = _model_meta(user, version, url, required_pkgs) - return meta + pip_name = self.pip_name if hasattr(self, "pip_name") else None + + return VetiverMeta.from_dict(metadata, pip_name) def construct_prototype(self): """Create data prototype for a model diff --git a/vetiver/handlers/sklearn.py b/vetiver/handlers/sklearn.py index d1449872..774f2b3b 100644 --- a/vetiver/handlers/sklearn.py +++ b/vetiver/handlers/sklearn.py @@ -1,7 +1,6 @@ import pandas as pd import sklearn -from ..meta import _model_meta from .base import BaseHandler @@ -15,23 +14,7 @@ class SKLearnHandler(BaseHandler): """ model_class = staticmethod(lambda: sklearn.base.BaseEstimator) - - def describe(self): - """Create description for sklearn model""" - desc = f"Scikit-learn {self.model.__class__} model" - return desc - - def create_meta( - user: list = None, - version: str = None, - url: str = None, - required_pkgs: list = [], - ): - """Create metadata for sklearn model""" - required_pkgs = required_pkgs + ["scikit-learn"] - meta = _model_meta(user, version, url, required_pkgs) - - return meta + pip_name = "scikit-learn" def handler_predict(self, input_data, check_prototype): """Generates method for /predict endpoint in VetiverAPI diff --git a/vetiver/handlers/statsmodels.py b/vetiver/handlers/statsmodels.py index 8675485e..6888338b 100644 --- a/vetiver/handlers/statsmodels.py +++ b/vetiver/handlers/statsmodels.py @@ -1,6 +1,5 @@ import pandas as pd -from ..meta import _model_meta from .base import BaseHandler sm_exists = True @@ -20,23 +19,8 @@ class StatsmodelsHandler(BaseHandler): """ model_class = staticmethod(lambda: statsmodels.base.wrapper.ResultsWrapper) - - def describe(self): - """Create description for statsmodels model""" - desc = f"Statsmodels {self.model.__class__} model." - return desc - - def create_meta( - user: list = None, - version: str = None, - url: str = None, - required_pkgs: list = [], - ): - """Create metadata for statsmodel""" - required_pkgs = required_pkgs + ["statsmodels"] - meta = _model_meta(user, version, url, required_pkgs) - - return meta + if sm_exists: + pip_name = "statsmodels" def handler_predict(self, input_data, check_prototype): """Generates method for /predict endpoint in VetiverAPI diff --git a/vetiver/handlers/torch.py b/vetiver/handlers/torch.py index 788fb601..13b37888 100644 --- a/vetiver/handlers/torch.py +++ b/vetiver/handlers/torch.py @@ -1,6 +1,5 @@ import numpy as np -from ..meta import _model_meta from .base import BaseHandler torch_exists = True @@ -20,23 +19,8 @@ class TorchHandler(BaseHandler): """ model_class = staticmethod(lambda: torch.nn.Module) - - def describe(self): - """Create description for torch model""" - desc = f"Pytorch model of type {type(self.model)}" - return desc - - def create_meta( - user: list = None, - version: str = None, - url: str = None, - required_pkgs: list = [], - ): - """Create metadata for torch model""" - required_pkgs = required_pkgs + ["torch"] - meta = _model_meta(user, version, url, required_pkgs) - - return meta + if torch_exists: + pip_name = "torch" def handler_predict(self, input_data, check_prototype): """Generates method for /predict endpoint in VetiverAPI diff --git a/vetiver/handlers/xgboost.py b/vetiver/handlers/xgboost.py index f719c9a1..f8c5b2e3 100644 --- a/vetiver/handlers/xgboost.py +++ b/vetiver/handlers/xgboost.py @@ -1,6 +1,5 @@ import pandas as pd -from ..meta import _model_meta from .base import BaseHandler xgb_exists = True @@ -20,23 +19,8 @@ class XGBoostHandler(BaseHandler): """ model_class = staticmethod(lambda: xgboost.Booster) - - def describe(self): - """Create description for xgboost model""" - desc = f"XGBoost {self.model.__class__} model." - return desc - - def create_meta( - user: list = None, - version: str = None, - url: str = None, - required_pkgs: list = [], - ): - """Create metadata for xgboost""" - required_pkgs = required_pkgs + ["xgboost"] - meta = _model_meta(user, version, url, required_pkgs) - - return meta + if xgb_exists: + pip_name = "xgboost" def handler_predict(self, input_data, check_prototype): """Generates method for /predict endpoint in VetiverAPI diff --git a/vetiver/meta.py b/vetiver/meta.py index bfcdb6d7..d2a10f93 100644 --- a/vetiver/meta.py +++ b/vetiver/meta.py @@ -1,23 +1,33 @@ -def _model_meta( - user: dict = None, version: str = None, url: str = None, required_pkgs: list = None -): - """Populate relevant metadata for VetiverModel - - Args - ---- - user: dict - Extra user-defined information - version: str - Model version, generally populated from pins - url: str - Discoverable URL for API - required_pkgs: list - Packages necessary to make predictions - """ - meta = { - "user": user, - "version": version, - "url": url, - "required_pkgs": required_pkgs, - } - return meta +from dataclasses import dataclass, asdict, field +from typing import Mapping + + +@dataclass +class VetiverMeta: + """Metadata in a VetiverModel""" + + user: "dict | None" = field(default_factory=dict) + version: "str | None" = None + url: "str | None" = None + required_pkgs: "list | None" = field(default_factory=list) + + def to_dict(self) -> Mapping: + data = asdict(self) + + return data + + @classmethod + def from_dict(cls, metadata, pip_name=None) -> "VetiverMeta": + + metadata = {} if metadata is None else metadata + + user = metadata.get("user", metadata) + version = metadata.get("version", None) + url = metadata.get("url", None) + required_pkgs = metadata.get("required_pkgs", []) + + if pip_name: + if not list(filter(lambda x: pip_name in x, required_pkgs)): + required_pkgs = required_pkgs + [f"{pip_name}"] + + return cls(user, version, url, required_pkgs) diff --git a/vetiver/pin_read_write.py b/vetiver/pin_read_write.py index 7532da48..42c8c216 100644 --- a/vetiver/pin_read_write.py +++ b/vetiver/pin_read_write.py @@ -1,4 +1,5 @@ from .vetiver_model import VetiverModel +from .meta import VetiverMeta from .utils import inform import warnings import logging @@ -54,6 +55,10 @@ def vetiver_pin_write(board, model: VetiverModel, versioned: bool = True): # convert older model's ptype to prototype if hasattr(model, "ptype"): model.prototype = model.ptype + delattr(model, "ptype") + # metadata is dict + if isinstance(model.metadata, dict): + model.metadata = VetiverMeta.from_dict(model.metadata) board.pin_write( model.model, @@ -61,8 +66,11 @@ def vetiver_pin_write(board, model: VetiverModel, versioned: bool = True): type="joblib", description=model.description, metadata={ - "required_pkgs": model.metadata.get("required_pkgs"), - "prototype": None if model.prototype is None else model.prototype().json(), + "user": model.metadata.user, + "vetiver_meta": { + "required_pkgs": model.metadata.required_pkgs, + "prototype": None if not model.prototype else model.prototype().json(), + }, }, versioned=versioned, ) diff --git a/vetiver/server.py b/vetiver/server.py index e621d404..9d148916 100644 --- a/vetiver/server.py +++ b/vetiver/server.py @@ -12,6 +12,7 @@ from .utils import _jupyter_nb from .vetiver_model import VetiverModel +from .meta import VetiverMeta class VetiverAPI: @@ -82,11 +83,14 @@ def docs_redirect(): return RedirectResponse(redirect) - if self.model.metadata.get("url") is not None: + if isinstance(self.model.metadata, dict): + self.model.metadata = VetiverMeta.from_dict(self.model.metadata) + + if self.model.metadata.url is not None: @app.get("/pin-url") def pin_url(): - return repr(self.model.metadata.get("url")) + return repr(self.model.metadata.url) @app.get("/ping", include_in_schema=True) async def ping(): diff --git a/vetiver/tests/test_build_vetiver_model.py b/vetiver/tests/test_build_vetiver_model.py index 2d8df52c..2832ea0c 100644 --- a/vetiver/tests/test_build_vetiver_model.py +++ b/vetiver/tests/test_build_vetiver_model.py @@ -1,6 +1,7 @@ import sklearn import vetiver as vt +from vetiver.meta import VetiverMeta from vetiver.mock import get_mock_data, get_mock_model import pandas as pd @@ -73,9 +74,8 @@ def test_vetiver_model_basemodel_prototype(): model=model, prototype_data=m, model_name="model", - versioned=None, + versioned=False, description=None, - metadata=None, ) assert vt4.model == model @@ -99,16 +99,21 @@ def test_vetiver_model_no_prototype(): def test_vetiver_model_use_ptype(): vt5 = vt.VetiverModel( model=model, - ptype_data=X_df, + prototype_data=None, model_name="model", versioned=None, description=None, - metadata=None, + metadata={"test": 123}, ) assert vt5.model == model - assert isinstance(vt5.prototype.construct(), pydantic.BaseModel) - assert list(vt5.prototype.__fields__.values())[0].type_ == int + assert vt5.prototype is None + assert vt5.metadata == VetiverMeta( + user={"test": 123}, + version=None, + url=None, + required_pkgs=["scikit-learn"], + ) def test_vetiver_model_from_pin(): @@ -119,12 +124,18 @@ def test_vetiver_model_from_pin(): model_name="model", versioned=None, description=None, - metadata=None, + metadata={"test": 123}, ) + board = pins.board_temp(allow_pickle_read=True) vt.vetiver_pin_write(board=board, model=v) v2 = vt.VetiverModel.from_pin(board, "model") + assert isinstance(v2, vt.VetiverModel) assert isinstance(v2.model, sklearn.base.BaseEstimator) assert isinstance(v2.prototype.construct(), pydantic.BaseModel) + assert v2.metadata.user == {"test": 123} + assert v2.metadata.version is not None + assert v2.metadata.required_pkgs == ["scikit-learn"] + board.pin_delete("model") diff --git a/vetiver/tests/test_custom_handler.py b/vetiver/tests/test_custom_handler.py index 873c21d0..f08ac551 100644 --- a/vetiver/tests/test_custom_handler.py +++ b/vetiver/tests/test_custom_handler.py @@ -35,10 +35,10 @@ def test_custom_vetiver_model(): prototype_data=X, model_name="my_model", versioned=None, - description="A regression model for testing purposes", ) - assert v.description == "A regression model for testing purposes" + assert v.description == "A DummyRegressor model" + assert not v.metadata.required_pkgs assert isinstance(v.model, sklearn.dummy.DummyRegressor) assert isinstance(v.prototype.construct(), pydantic.BaseModel) diff --git a/vetiver/tests/test_pytorch.py b/vetiver/tests/test_pytorch.py index e57c4676..11c7ab4f 100644 --- a/vetiver/tests/test_pytorch.py +++ b/vetiver/tests/test_pytorch.py @@ -53,6 +53,7 @@ def test_vetiver_build(): ) assert vt2.model == torch_model + assert vt2.metadata.required_pkgs == ["torch"] def test_torch_predict_ptype(): diff --git a/vetiver/tests/test_sklearn.py b/vetiver/tests/test_sklearn.py index 29fe9fdf..0d03c026 100644 --- a/vetiver/tests/test_sklearn.py +++ b/vetiver/tests/test_sklearn.py @@ -20,6 +20,19 @@ def _start_application(save_prototype: bool = True): return app +def test_build_sklearn(): + X, y = mock.get_mock_data() + model = mock.get_mock_model().fit(X, y) + v = VetiverModel( + model=model, + ptype_data=X, + model_name="my_model", + description="A regression model for testing purposes", + ) + + assert v.metadata.required_pkgs == ["scikit-learn"] + + def test_predict_endpoint_ptype(): np.random.seed(500) client = TestClient(_start_application().app) diff --git a/vetiver/tests/test_xgboost.py b/vetiver/tests/test_xgboost.py index c708237a..425898ab 100644 --- a/vetiver/tests/test_xgboost.py +++ b/vetiver/tests/test_xgboost.py @@ -45,6 +45,13 @@ def vetiver_client_check_ptype_false(xgb_model): # With check_prototype=True return client +def test_model(xgb_model): + v = xgb_model + + assert v.metadata.required_pkgs == ["xgboost"] + assert not v.metadata.user + + def test_vetiver_build(vetiver_client): data = mtcars.head(1).drop(columns="mpg") diff --git a/vetiver/vetiver_model.py b/vetiver/vetiver_model.py index 843f607b..f01e7a3f 100644 --- a/vetiver/vetiver_model.py +++ b/vetiver/vetiver_model.py @@ -1,7 +1,6 @@ import json from warnings import warn from vetiver.handlers.base import create_handler -from .meta import _model_meta class NoModelAvailableError(Exception): @@ -62,7 +61,7 @@ class VetiverModel: >>> model = mock.get_mock_model().fit(X, y) >>> v = VetiverModel(model = model, model_name = "my_model", prototype_data = X) >>> v.description - "Scikit-learn model" + 'A scikit-learn DummyRegressor model' """ def __init__( @@ -91,12 +90,8 @@ def __init__( self.model_name = model_name self.description = description if description else translator.describe() self.versioned = versioned - self.metadata = ( - metadata - if metadata - else translator.create_meta(metadata, required_pkgs=["vetiver"]) - ) self.handler_predict = translator.handler_predict + self.metadata = translator.create_meta(metadata) @classmethod def from_pin(cls, board, name: str, version: str = None): @@ -104,23 +99,31 @@ def from_pin(cls, board, name: str, version: str = None): model = board.pin_read(name, version) meta = board.pin_meta(name, version) - if meta.user.get("ptype"): - get_prototype = meta.user.get("ptype") - elif meta.user.get("prototype"): - get_prototype = meta.user.get("prototype") + if "vetiver_meta" in meta.user: + get_prototype = meta.user.get("vetiver_meta").get("prototype", None) + required_pkgs = meta.user.get("vetiver_meta").get("required_pkgs", None) + meta.user.pop("vetiver_meta") else: - get_prototype = None + # ptype = meta.user.get("ptype", None) + + get_prototype = meta.user.get("ptype") + # elif meta.user.get("prototype"): + # get_prototype = meta.user.get("prototype") + # else: + # get_prototype = None + + required_pkgs = meta.user.get("required_pkgs") return cls( model=model, model_name=name, description=meta.description, - metadata=_model_meta( - user=meta.user, - version=meta.version.version, - url=meta.local.get("url"), # None all the time, besides Connect - required_pkgs=meta.user.get("required_pkgs"), - ), + metadata={ + "user": meta.user.get("user"), + "version": meta.version.version, + "url": meta.local.get("url"), # None all the time, besides Connect, + "required_pkgs": required_pkgs, + }, prototype_data=json.loads(get_prototype) if get_prototype else None, versioned=True, )