@failsafe
def log_xgboost_model(
    self,
    xgboost_model: "xgb.Booster",
    artifact_name: Optional[str] = None,
    **log_artifact_kwargs: Any,
) -> Artifact:
    """Log an XGBoost model as a JSON file to this client object.

    The model is saved with `save_model` to a temporary
    ``<artifact_name>.json`` file, and that file is then logged through
    `self.log_artifact`.

    Objects exposing a `get_booster` method (XGBoost's scikit-learn
    style estimators) are unwrapped to their underlying Booster before
    saving, so the logged artifact always holds a Booster.

    Parameters
    ----------
    xgboost_model : Booster
        An xgboost model object in the Booster format, or an object
        whose `get_booster()` returns one.
    artifact_name : str, optional
        The name of the artifact (the exported XGBoost model).
        Defaults to the class name of `xgboost_model` as passed in.
    log_artifact_kwargs : Any
        Additional kwargs to be passed directly to `self.log_artifact`.

    Returns
    -------
    rubicon.client.Artifact
        The new artifact.
    """
    import os

    if artifact_name is None:
        artifact_name = xgboost_model.__class__.__name__

    # A plain Booster has no `get_booster`; scikit-learn style wrappers do.
    # NOTE(review): xgboost documents that `get_booster` raises when the
    # estimator has not been fitted - confirm that is the desired failure.
    get_booster = getattr(xgboost_model, "get_booster", None)
    booster = get_booster() if callable(get_booster) else xgboost_model

    with tempfile.TemporaryDirectory() as temp_dir_name:
        model_location = os.path.join(temp_dir_name, f"{artifact_name}.json")
        booster.save_model(model_location)

        # Log while the temporary directory still exists; the file is
        # removed as soon as the `with` block exits.
        artifact = self.log_artifact(
            name=artifact_name,
            data_path=model_location,
            **log_artifact_kwargs,
        )

    return artifact
def test_get_data_deserialize_xgboost(
    make_classification_df, rubicon_local_filesystem_client_with_project
):
    """Test that a logged `xgboost` model round-trips through `get_data`."""
    _, project = rubicon_local_filesystem_client_with_project
    X, y = make_classification_df

    classifier = xgboost.XGBClassifier(n_estimators=2)
    classifier.fit(X, y)
    model = classifier.get_booster()

    artifact = project.log_xgboost_model(model)
    artifact_data = artifact.get_data(deserialize="xgboost")

    assert isinstance(artifact_data, model.__class__)
    # A freshly constructed, empty Booster would pass the type check alone,
    # so also confirm the fitted trees were restored.
    assert len(artifact_data.get_dump()) == len(model.get_dump())