Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add XGBoost JSON Reader/writer #461

Merged
merged 7 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions rubicon_ml/client/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _get_data(self):
@failsafe
def get_data(
self,
deserialize: Optional[Literal["h2o", "pickle"]] = None,
deserialize: Optional[Literal["h2o", "pickle", "xgboost"]] = None,
unpickle: bool = False, # TODO: deprecate & move to `deserialize`
):
"""Loads the data associated with this artifact and
Expand All @@ -84,6 +84,7 @@ def get_data(
* None to disable deserialization and return the raw data.
* "h2o" to use `h2o.load_model` to load the data.
* "pickle" to use pickles to load the data.
* "xgboost" to use xgboost's JSON loader to load the data as a fitted model.
Defaults to None.
unpickle : bool, optional
Flag indicating whether or not to unpickle artifact data.
Expand All @@ -102,7 +103,19 @@ def get_data(

for repo in self.repositories or []:
try:
data = repo.get_artifact_data(project_name, self.id, experiment_id=experiment_id)
if deserialize == "xgboost":
# xgboost can only handle string file name locations
import xgboost

artifact_data_path = repo._get_artifact_data_path(
project_name, experiment_id, self.id
)
data = xgboost.Booster()
data.load_model(artifact_data_path)
else:
data = repo.get_artifact_data(
project_name, self.id, experiment_id=experiment_id
)
except Exception as err:
return_err = err
else:
Expand Down
44 changes: 44 additions & 0 deletions rubicon_ml/client/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import dask.dataframe as dd
import pandas as pd
import polars as pl
import xgboost as xgb

from rubicon_ml.client import Artifact, Dataframe
from rubicon_ml.domain import DOMAIN_TYPES
Expand Down Expand Up @@ -282,6 +283,49 @@ def log_h2o_model(

return artifact

@failsafe
def log_xgboost_model(
self,
xgboost_model: "xgb.Booster",
artifact_name: Optional[str] = None,
**log_artifact_kwargs: Any,
) -> Artifact:
"""Log an XGBoost model as a JSON file to this client object.

Please note that logging directly from the scikit-learn interface (`XGBClassifier`,
`XGBRegressor`) is not currently supported; call `model.get_booster()` first and log the result.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you mean by this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If users use the XGBRegressor or XGBClassifier classes, which implement the sklearn interface for xgboost, this will not work. They would have to call model.get_booster() first and log the result. An obvious addition is to check for that ourselves and call model.get_booster() when relevant.

I can add that in this PR if you want.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahh gotcha, nah this is fine for now


Parameters
----------
xgboost_model: Booster
An xgboost model object in the Booster format
artifact_name : str, optional
The name of the artifact (the exported XGBoost model).
log_artifact_kwargs : Any
Additional kwargs to be passed directly to `self.log_artifact`.

Returns
-------
rubicon.client.Artifact
The new artifact.
"""
if artifact_name is None:
artifact_name = xgboost_model.__class__.__name__

# TODO: handle sklearn
booster = xgboost_model

with tempfile.TemporaryDirectory() as temp_dir_name:
model_location = f"{temp_dir_name}/{artifact_name}.json"
booster.save_model(model_location)

artifact = self.log_artifact(
name=artifact_name,
data_path=model_location,
**log_artifact_kwargs,
)

return artifact

@failsafe
def log_pip_requirements(self, artifact_name: Optional[str] = None) -> Artifact:
"""Log the pip requirements as an artifact to this client object.
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/client/test_artifact_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import h2o
import pandas as pd
import pytest
import xgboost
from h2o import H2OFrame
from h2o.estimators.random_forest import H2ORandomForestEstimator

Expand Down Expand Up @@ -186,6 +187,23 @@ def test_get_data_deserialize_h2o(
assert artifact_data.__class__ == h2o_model.__class__


def test_get_data_deserialize_xgboost(
    make_classification_df, rubicon_local_filesystem_client_with_project
):
    """Round-trip an `xgboost` booster through `log_xgboost_model` and `get_data`."""
    features, labels = make_classification_df
    _client, project = rubicon_local_filesystem_client_with_project

    # Fit through the sklearn wrapper, then pull out the underlying booster,
    # since that is the object handed to `log_xgboost_model` below.
    classifier = xgboost.XGBClassifier(n_estimators=2)
    classifier.fit(features, labels)
    booster = classifier.get_booster()

    logged_artifact = project.log_xgboost_model(booster)
    loaded_model = logged_artifact.get_data(deserialize="xgboost")

    # The deserialized object must be the same class as what was logged.
    assert loaded_model.__class__ == booster.__class__


def test_download_data_unzip(project_client):
"""Test downloading and unzipping artifact data."""
project = project_client
Expand Down
Loading