Skip to content

Commit

Permalink
Add XGBoost JSON Reader/writer (#461)
Browse files Browse the repository at this point in the history
* Start hacking on xgboost saving/loading

* Tests and updates to xgboost reading

* Types
  • Loading branch information
stephenpardy authored Jul 16, 2024
1 parent cc3f5c1 commit 878d691
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 2 deletions.
17 changes: 15 additions & 2 deletions rubicon_ml/client/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _get_data(self):
@failsafe
def get_data(
self,
deserialize: Optional[Literal["h2o", "pickle"]] = None,
deserialize: Optional[Literal["h2o", "pickle", "xgboost"]] = None,
unpickle: bool = False, # TODO: deprecate & move to `deserialize`
):
"""Loads the data associated with this artifact and
Expand All @@ -84,6 +84,7 @@ def get_data(
* None to disable deserialization and return the raw data.
* "h2o" to use `h2o.load_model` to load the data.
* "pickle" to use pickles to load the data.
* "xgboost" to use xgboost's JSON loader to load the data as a fitted model.
Defaults to None.
unpickle : bool, optional
Flag indicating whether or not to unpickle artifact data.
Expand All @@ -102,7 +103,19 @@ def get_data(

for repo in self.repositories or []:
try:
data = repo.get_artifact_data(project_name, self.id, experiment_id=experiment_id)
if deserialize == "xgboost":
# xgboost can only handle string file name locations
import xgboost

artifact_data_path = repo._get_artifact_data_path(
project_name, experiment_id, self.id
)
data = xgboost.Booster()
data.load_model(artifact_data_path)
else:
data = repo.get_artifact_data(
project_name, self.id, experiment_id=experiment_id
)
except Exception as err:
return_err = err
else:
Expand Down
44 changes: 44 additions & 0 deletions rubicon_ml/client/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import dask.dataframe as dd
import pandas as pd
import polars as pl
import xgboost as xgb

from rubicon_ml.client import Artifact, Dataframe
from rubicon_ml.domain import DOMAIN_TYPES
Expand Down Expand Up @@ -282,6 +283,49 @@ def log_h2o_model(

return artifact

@failsafe
def log_xgboost_model(
    self,
    xgboost_model: "xgb.Booster",
    artifact_name: Optional[str] = None,
    **log_artifact_kwargs: Any,
) -> Artifact:
    """Export an XGBoost model to a JSON file and log it as an artifact.

    The model is written to a ``.json`` file inside a temporary directory
    and that file is handed to `self.log_artifact`. Logging directly from
    the SKLearn interface is not currently supported.

    Parameters
    ----------
    xgboost_model : Booster
        An xgboost model object in the Booster format.
    artifact_name : str, optional
        The name of the artifact (the exported XGBoost model).
        Defaults to the class name of `xgboost_model`.
    log_artifact_kwargs : Any
        Additional kwargs to be passed directly to `self.log_artifact`.

    Returns
    -------
    rubicon.client.Artifact
        The new artifact.
    """
    # Only an explicit `None` triggers the class-name default.
    name = xgboost_model.__class__.__name__ if artifact_name is None else artifact_name

    # TODO: handle sklearn -- for now the given object is saved as-is.
    with tempfile.TemporaryDirectory() as scratch_dir:
        json_path = f"{scratch_dir}/{name}.json"
        xgboost_model.save_model(json_path)

        # Log while the temporary directory still exists; it is removed as
        # soon as the `with` block is left.
        return self.log_artifact(
            name=name,
            data_path=json_path,
            **log_artifact_kwargs,
        )

@failsafe
def log_pip_requirements(self, artifact_name: Optional[str] = None) -> Artifact:
"""Log the pip requirements as an artifact to this client object.
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/client/test_artifact_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import h2o
import pandas as pd
import pytest
import xgboost
from h2o import H2OFrame
from h2o.estimators.random_forest import H2ORandomForestEstimator

Expand Down Expand Up @@ -186,6 +187,23 @@ def test_get_data_deserialize_h2o(
assert artifact_data.__class__ == h2o_model.__class__


def test_get_data_deserialize_xgboost(
    make_classification_df, rubicon_local_filesystem_client_with_project
):
    """Test that a logged `xgboost` booster is loaded back as the same class."""
    features, labels = make_classification_df
    _client, project = rubicon_local_filesystem_client_with_project

    # Fit a tiny classifier and pull out its underlying booster, since the
    # booster is what gets passed to `log_xgboost_model`.
    classifier = xgboost.XGBClassifier(n_estimators=2)
    classifier.fit(features, labels)
    booster = classifier.get_booster()

    logged_artifact = project.log_xgboost_model(booster)
    restored = logged_artifact.get_data(deserialize="xgboost")

    assert restored.__class__ == booster.__class__


def test_download_data_unzip(project_client):
"""Test downloading and unzipping artifact data."""
project = project_client
Expand Down

0 comments on commit 878d691

Please sign in to comment.