From 90f6216a41fa72f3f0910730c19a3fb41e73a246 Mon Sep 17 00:00:00 2001 From: "P. Sai Vinay" Date: Thu, 30 Sep 2021 21:28:17 +0530 Subject: [PATCH] Add number_samples to XGBoost Ml Models --- eland/ml/transformers/xgboost.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/eland/ml/transformers/xgboost.py b/eland/ml/transformers/xgboost.py index 9eb1d189..b170c24f 100644 --- a/eland/ml/transformers/xgboost.py +++ b/eland/ml/transformers/xgboost.py @@ -89,7 +89,11 @@ def extract_node_id(self, node_id: str, curr_tree: int) -> int: ) def build_leaf_node(self, row: pd.Series, curr_tree: int) -> TreeNode: - return TreeNode(node_idx=row["Node"], leaf_value=[float(row["Gain"])]) + return TreeNode( + node_idx=row["Node"], + leaf_value=[float(row["Gain"])], + number_samples=int(row["Cover"]), + ) def build_tree_node(self, row: pd.Series, curr_tree: int) -> TreeNode: node_index = row["Node"] @@ -103,6 +107,7 @@ def build_tree_node(self, row: pd.Series, curr_tree: int) -> TreeNode: right_child=self.extract_node_id(row["No"], curr_tree), threshold=float(row["Split"]), split_feature=self.get_feature_id(row["Feature"]), + number_samples=int(row["Cover"]), ) def build_tree(self, nodes: List[TreeNode]) -> Tree: @@ -238,7 +243,9 @@ def build_leaf_node(self, row: pd.Series, curr_tree: int) -> TreeNode: return super().build_leaf_node(row, curr_tree) leaf_val = [0.0] * self._num_classes leaf_val[curr_tree % self._num_classes] = float(row["Gain"]) - return TreeNode(node_idx=row["Node"], leaf_value=leaf_val) + return TreeNode( + node_idx=row["Node"], leaf_value=leaf_val, number_samples=int(row["Cover"]) + ) def determine_target_type(self) -> str: return "classification"