From 82d157b44c35101eb67ab9f2d5857680232debe8 Mon Sep 17 00:00:00 2001 From: stephenpardy Date: Sun, 30 Jun 2024 10:51:10 -0500 Subject: [PATCH 1/3] Start hacking on xgboost saving/loading --- rubicon_ml/client/artifact.py | 8 +++++++- rubicon_ml/client/mixin.py | 36 +++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/rubicon_ml/client/artifact.py b/rubicon_ml/client/artifact.py index c4582777..b6994974 100644 --- a/rubicon_ml/client/artifact.py +++ b/rubicon_ml/client/artifact.py @@ -71,7 +71,7 @@ def _get_data(self): @failsafe def get_data( self, - deserialize: Optional[Literal["h2o", "pickle"]] = None, + deserialize: Optional[Literal["h2o", "pickle", "xgboost"]] = None, unpickle: bool = False, # TODO: deprecate & move to `deserialize` ): """Loads the data associated with this artifact and @@ -84,6 +84,7 @@ def get_data( * None to disable deseralization and return the raw data. * "h2o" to use `h2o.load_model` to load the data. * "pickle" to use pickles to load the data. + * "xgboost" to use xgboost's JSON loader to load the data as a fitted model. Defaults to None. unpickle : bool, optional Flag indicating whether or not to unpickle artifact data. @@ -114,6 +115,11 @@ def get_data( ) elif deserialize == "pickle": data = pickle.loads(data) + elif deserialize == "xgboost": + import xgboost + model_data = data + data = xgboost.Booster() + data.load_model(model_data) return data diff --git a/rubicon_ml/client/mixin.py b/rubicon_ml/client/mixin.py index 48c29bc5..5315103b 100644 --- a/rubicon_ml/client/mixin.py +++ b/rubicon_ml/client/mixin.py @@ -281,6 +281,42 @@ def log_h2o_model( return artifact + @failsafe + def log_xgboost_model(self, xgboost_model, artifact_name: Optional[str] = None, **log_artifact_kwargs: Any) -> Artifact: + """Log an XGBoost model as a JSON file to this client object. + + Parameters + ---------- + xgboost_model + An xgboost model object. + artifact_name : str, optional + The name of the artifact (the exported XGBoost model). + log_artifact_kwargs : Any + Additional kwargs to be passed directly to `self.log_artifact`. + + Returns + ------- + rubicon.client.Artifact + The new artifact. + """ + if artifact_name is None: + artifact_name = xgboost_model.__class__.__name__ + + # TODO: handle sklearn + booster = xgboost_model + + with tempfile.TemporaryDirectory() as temp_dir_name: + model_location = f"{temp_dir_name}/{artifact_name}.json" + booster.save_model(model_location) + + artifact = self.log_artifact( + name=artifact_name, + data_path=model_location, + **log_artifact_kwargs, + ) + + return artifact + @failsafe def log_pip_requirements(self, artifact_name: Optional[str] = None) -> Artifact: """Log the pip requirements as an artifact to this client object. From 30b72eb81b58807ad6211c70218d1c227efcb567 Mon Sep 17 00:00:00 2001 From: stephenpardy Date: Mon, 1 Jul 2024 10:59:05 -0500 Subject: [PATCH 2/3] Tests and updates to xgboost reading --- rubicon_ml/client/artifact.py | 19 ++++++++++++------ rubicon_ml/client/mixin.py | 24 +++++++++++++++++------ tests/unit/client/test_artifact_client.py | 18 +++++++++++++++++ 3 files changed, 49 insertions(+), 12 deletions(-) diff --git a/rubicon_ml/client/artifact.py b/rubicon_ml/client/artifact.py index b6994974..8ffcbdde 100644 --- a/rubicon_ml/client/artifact.py +++ b/rubicon_ml/client/artifact.py @@ -103,7 +103,19 @@ def get_data( for repo in self.repositories or []: try: - data = repo.get_artifact_data(project_name, self.id, experiment_id=experiment_id) + if deserialize == "xgboost": + # xgboost can only handle string file name locations + import xgboost + + artifact_data_path = repo._get_artifact_data_path( + project_name, experiment_id, self.id + ) + data = xgboost.Booster() + data.load_model(artifact_data_path) + else: + data = repo.get_artifact_data( + project_name, self.id, experiment_id=experiment_id + ) except Exception as err: return_err = err else: @@ -115,11 +127,6 @@ def get_data( ) elif deserialize == "pickle": data = pickle.loads(data) - elif deserialize == "xgboost": - import xgboost - model_data = data - data = xgboost.Booster() - data.load_model(model_data) return data diff --git a/rubicon_ml/client/mixin.py b/rubicon_ml/client/mixin.py index 5315103b..c46ddac1 100644 --- a/rubicon_ml/client/mixin.py +++ b/rubicon_ml/client/mixin.py @@ -282,7 +282,12 @@ def log_h2o_model( return artifact @failsafe - def log_xgboost_model(self, xgboost_model, artifact_name: Optional[str] = None, **log_artifact_kwargs: Any) -> Artifact: + def log_xgboost_model( + self, + xgboost_model, + artifact_name: Optional[str] = None, + **log_artifact_kwargs: Any, + ) -> Artifact: """Log an XGBoost model as a JSON file to this client object. Parameters @@ -293,7 +298,7 @@ def log_xgboost_model(self, xgboost_model, artifact_name: Optional[str] = None, The name of the artifact (the exported XGBoost model). log_artifact_kwargs : Any Additional kwargs to be passed directly to `self.log_artifact`. - + Returns ------- rubicon.client.Artifact @@ -308,7 +313,7 @@ def log_xgboost_model(self, xgboost_model, artifact_name: Optional[str] = None, with tempfile.TemporaryDirectory() as temp_dir_name: model_location = f"{temp_dir_name}/{artifact_name}.json" booster.save_model(model_location) - + artifact = self.log_artifact( name=artifact_name, data_path=model_location, @@ -342,7 +347,10 @@ def log_pip_requirements(self, artifact_name: Optional[str] = None) -> Artifact: @failsafe def artifacts( - self, name: Optional[str] = None, tags: Optional[List[str]] = None, qtype: str = "or" + self, + name: Optional[str] = None, + tags: Optional[List[str]] = None, + qtype: str = "or", ) -> List[Artifact]: """Get the artifacts logged to this client object. @@ -416,7 +424,8 @@ def artifact(self, name: Optional[str] = None, id: Optional[str] = None) -> Arti for repo in self.repositories: try: artifact = client.Artifact( - repo.get_artifact_metadata(project_name, id, experiment_id), self + repo.get_artifact_metadata(project_name, id, experiment_id), + self, ) except Exception as err: return_err = err @@ -544,7 +553,10 @@ def log_dataframe( @failsafe def dataframes( - self, name: Optional[str] = None, tags: Optional[List[str]] = None, qtype: str = "or" + self, + name: Optional[str] = None, + tags: Optional[List[str]] = None, + qtype: str = "or", ) -> List[Dataframe]: """Get the dataframes logged to this client object. diff --git a/tests/unit/client/test_artifact_client.py b/tests/unit/client/test_artifact_client.py index b3d3dbfe..10135cfd 100644 --- a/tests/unit/client/test_artifact_client.py +++ b/tests/unit/client/test_artifact_client.py @@ -6,6 +6,7 @@ import h2o import pandas as pd import pytest +import xgboost from h2o import H2OFrame from h2o.estimators.random_forest import H2ORandomForestEstimator @@ -186,6 +187,23 @@ def test_get_data_deserialize_h2o( assert artifact_data.__class__ == h2o_model.__class__ +def test_get_data_deserialize_xgboost( + make_classification_df, rubicon_local_filesystem_client_with_project +): + """Test logging `xgboost` model data.""" + _, project = rubicon_local_filesystem_client_with_project + X, y = make_classification_df + + model = xgboost.XGBClassifier(n_estimators=2) + model.fit(X, y) + model = model.get_booster() + + artifact = project.log_xgboost_model(model) + artifact_data = artifact.get_data(deserialize="xgboost") + + assert artifact_data.__class__ == model.__class__ + + def test_download_data_unzip(project_client): """Test downloading and unzipping artifact data.""" project = project_client From fc67ee861289d3002a83051ce63e1cd9098365d5 Mon Sep 17 00:00:00 2001 From: stephenpardy Date: Mon, 15 Jul 2024 16:37:28 -0500 Subject: [PATCH 3/3] Types --- rubicon_ml/client/mixin.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/rubicon_ml/client/mixin.py b/rubicon_ml/client/mixin.py index c46ddac1..0c811711 100644 --- a/rubicon_ml/client/mixin.py +++ b/rubicon_ml/client/mixin.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: import dask.dataframe as dd import pandas as pd + import xgboost as xgb from rubicon_ml.client import Artifact, Dataframe from rubicon_ml.domain import DOMAIN_TYPES @@ -284,16 +285,18 @@ def log_h2o_model( @failsafe def log_xgboost_model( self, - xgboost_model, + xgboost_model: "xgb.Booster", artifact_name: Optional[str] = None, **log_artifact_kwargs: Any, ) -> Artifact: """Log an XGBoost model as a JSON file to this client object. + Please note that we do not currently support logging directly from the SKLearn interface. + Parameters ---------- - xgboost_model - An xgboost model object. + xgboost_model: Booster + An xgboost model object in the Booster format artifact_name : str, optional The name of the artifact (the exported XGBoost model). log_artifact_kwargs : Any