-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[python-package][PySpark] Expose Training and Validation Metrics
- Loading branch information
a.cherkaoui
committed
Dec 29, 2024
1 parent
bec2d32
commit 3d7161a
Showing
3 changed files
with
65 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
"""Xgboost training summary integration submodule.""" | ||
|
||
from dataclasses import dataclass, field | ||
from typing import Dict, List | ||
|
||
|
||
@dataclass | ||
class XGBoostTrainingSummary: | ||
""" | ||
A class that holds the training and validation objective history | ||
of an XGBoost model during its training process. | ||
""" | ||
|
||
train_objective_history: Dict[str, List[float]] = field(default_factory=dict) | ||
validation_objective_history: Dict[str, List[float]] = field(default_factory=dict) | ||
|
||
@staticmethod | ||
def from_metrics( | ||
metrics: Dict[str, Dict[str, List[float]]] | ||
) -> "XGBoostTrainingSummary": | ||
""" | ||
Create an XGBoostTrainingSummary instance from a nested dictionary of metrics. | ||
Parameters | ||
---------- | ||
metrics : dict of str to dict of str to list of float | ||
A dictionary containing training and validation metrics. | ||
Example format: | ||
{ | ||
"training": {"logloss": [0.1, 0.08]}, | ||
"validation": {"logloss": [0.12, 0.1]} | ||
} | ||
Returns | ||
------- | ||
A new instance of XGBoostTrainingSummary. | ||
""" | ||
train_objective_history = metrics.get("training", {}) | ||
validation_objective_history = metrics.get("validation", {}) | ||
return XGBoostTrainingSummary( | ||
train_objective_history, validation_objective_history | ||
) |