Skip to content

Commit

Permalink
[python-package][PySpark] Expose Training and Validation Metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
a.cherkaoui committed Dec 29, 2024
1 parent bec2d32 commit ceb0ffa
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
13 changes: 7 additions & 6 deletions python-package/xgboost/spark/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,8 +704,8 @@ def _pyspark_model_cls(cls) -> Type["_SparkXGBModel"]:
"""
raise NotImplementedError()

def _create_pyspark_model(self, xgb_model: XGBModel) -> "_SparkXGBModel":
return self._pyspark_model_cls()(xgb_model)
def _create_pyspark_model(
    self,
    xgb_model: XGBModel,
    evals_result: Optional[Dict[str, Dict[str, List[float]]]] = None,
) -> "_SparkXGBModel":
    """Wrap a fitted booster in the matching Spark model class.

    Parameters
    ----------
    xgb_model :
        The trained sklearn-style :class:`XGBModel` to wrap.
    evals_result :
        Evaluation metrics collected during training, keyed by dataset
        name and then by metric name (matching the shape declared by the
        Spark model ``__init__`` methods).  ``None`` (the default) is
        normalized to an empty dict so the default is never a shared
        mutable object.
    """
    # A literal `{}` default would be one dict shared across all calls;
    # use None as the sentinel and substitute a fresh dict per call.
    return self._pyspark_model_cls()(
        xgb_model=xgb_model,
        evals_result=evals_result if evals_result is not None else {},
    )

def _convert_to_sklearn_model(self, booster: bytearray, config: str) -> XGBModel:
xgb_sklearn_params = self._gen_xgb_params_dict(
Expand Down Expand Up @@ -1148,7 +1148,7 @@ def _train_booster(
if dvalid is not None:
dval = [(dtrain, "training"), (dvalid, "validation")]
else:
dval = None
dval = [(dtrain, "training")]
booster = worker_train(
params=booster_params,
dtrain=dtrain,
Expand All @@ -1159,6 +1159,7 @@ def _train_booster(
context.barrier()

if context.partitionId() == 0:
yield pd.DataFrame({"data": [json.dumps(dict(evals_result))]})
config = booster.save_config()
yield pd.DataFrame({"data": [config]})
booster_json = booster.save_raw("json").decode("utf-8")
Expand All @@ -1179,7 +1180,7 @@ def _run_job() -> Tuple[str, str]:
rdd_with_resource = self._try_stage_level_scheduling(rdd)
ret = rdd_with_resource.collect()
data = [v[0] for v in ret]
return data[0], "".join(data[1:])
return json.loads(data[0]), data[1], "".join(data[2:])

get_logger(_LOG_TAG).info(
"Running xgboost-%s on %s workers with"
Expand All @@ -1192,13 +1193,13 @@ def _run_job() -> Tuple[str, str]:
train_call_kwargs_params,
dmatrix_kwargs,
)
(config, booster) = _run_job()
(evals_result, config, booster) = _run_job()
get_logger(_LOG_TAG).info("Finished xgboost training!")

result_xgb_model = self._convert_to_sklearn_model(
bytearray(booster, "utf-8"), config
)
spark_model = self._create_pyspark_model(result_xgb_model)
spark_model = self._create_pyspark_model(xgb_model=result_xgb_model, evals_result=evals_result)
# According to pyspark ML convention, the model uid should be the same
# with estimator uid.
spark_model._resetUid(self.uid)
Expand Down
16 changes: 14 additions & 2 deletions python-package/xgboost/spark/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
# pylint: disable=unused-argument, too-many-locals

import warnings
from typing import Any, List, Optional, Type, Union
from typing import Any, List, Optional, Type, Union, Dict

import numpy as np
from pyspark import keyword_only
from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol

from ..collective import Config
from ..sklearn import XGBClassifier, XGBRanker, XGBRegressor
from ..sklearn import XGBClassifier, XGBRanker, XGBRegressor, XGBModel
from .core import ( # type: ignore
_ClassificationModel,
_SparkXGBEstimator,
Expand Down Expand Up @@ -252,6 +252,10 @@ class SparkXGBRegressorModel(_SparkXGBModel):
.. Note:: This API is experimental.
"""

def __init__(
    self,
    xgb_model: XGBModel,
    evals_result: Optional[Dict[str, Dict[str, List[float]]]] = None,
) -> None:
    """Create a Spark regressor model wrapping a fitted ``XGBModel``.

    Parameters
    ----------
    xgb_model :
        The trained sklearn-style regressor to wrap.
    evals_result :
        Training/validation metrics keyed by dataset name then metric
        name.  Defaults to ``None``, stored as a fresh empty dict.
    """
    super().__init__(xgb_model)
    # `= {}` as a default argument is a single dict shared by every
    # instance constructed without metrics; normalize None instead.
    self.evals_result = evals_result if evals_result is not None else {}

@classmethod
def _xgb_cls(cls) -> Type[XGBRegressor]:
return XGBRegressor
Expand Down Expand Up @@ -448,6 +452,10 @@ class SparkXGBClassifierModel(_ClassificationModel):
.. Note:: This API is experimental.
"""

def __init__(
    self,
    xgb_model: XGBModel,
    evals_result: Optional[Dict[str, Dict[str, List[float]]]] = None,
) -> None:
    """Create a Spark classifier model wrapping a fitted ``XGBModel``.

    Parameters
    ----------
    xgb_model :
        The trained sklearn-style classifier to wrap.
    evals_result :
        Training/validation metrics keyed by dataset name then metric
        name.  Defaults to ``None``, stored as a fresh empty dict.
    """
    super().__init__(xgb_model)
    # `= {}` as a default argument is a single dict shared by every
    # instance constructed without metrics; normalize None instead.
    self.evals_result = evals_result if evals_result is not None else {}

@classmethod
def _xgb_cls(cls) -> Type[XGBClassifier]:
return XGBClassifier
Expand Down Expand Up @@ -635,6 +643,10 @@ class SparkXGBRankerModel(_SparkXGBModel):
.. Note:: This API is experimental.
"""

def __init__(
    self,
    xgb_model: XGBModel,
    evals_result: Optional[Dict[str, Dict[str, List[float]]]] = None,
) -> None:
    """Create a Spark ranker model wrapping a fitted ``XGBModel``.

    Parameters
    ----------
    xgb_model :
        The trained sklearn-style ranker to wrap.
    evals_result :
        Training/validation metrics keyed by dataset name then metric
        name.  Defaults to ``None``, stored as a fresh empty dict.
    """
    super().__init__(xgb_model)
    # `= {}` as a default argument is a single dict shared by every
    # instance constructed without metrics; normalize None instead.
    self.evals_result = evals_result if evals_result is not None else {}

@classmethod
def _xgb_cls(cls) -> Type[XGBRanker]:
return XGBRanker
Expand Down

0 comments on commit ceb0ffa

Please sign in to comment.