diff --git a/src/splitter.rs b/src/splitter.rs index b3147e1..ebf147c 100644 --- a/src/splitter.rs +++ b/src/splitter.rs @@ -1092,12 +1092,12 @@ fn best_feature_split_var_hess( ) { let split_info = unsafe { split_info_slice.get_mut(feat_idx) }; split_info.split_gain = -1.0; + split_info.left_cats = HashSet::new(); + split_info.right_cats = HashSet::new(); let mut max_gain: Option = None; let mut generalization: Option; - let mut left_cats: HashSet = HashSet::new(); - let mut right_cats: HashSet = HashSet::new(); - let mut is_cat = false; + let mut all_cats: Vec = Vec::new(); let evaluate_fn = eval_callables(false, create_missing_branch); @@ -1108,12 +1108,11 @@ fn best_feature_split_var_hess( if let Some(c_index) = cat_index { if c_index.contains(&feature) { - is_cat = true; sort_cat_bins_by_stat(&mut hist, false); - right_cats = HashSet::from_iter( - hist.iter() - .map(|b| unsafe { b.get().as_ref().unwrap().cut_value } as usize), - ); + all_cats = hist + .iter() + .map(|b| unsafe { b.get().as_ref().unwrap().cut_value } as usize) + .collect(); } } @@ -1148,18 +1147,9 @@ fn best_feature_split_var_hess( let mut cuml_hessian_valid = [f32::ZERO; 5]; let mut cuml_counts_valid = [0_usize; 5]; - let mut cat: Option = None; - for bin in hist { let b = unsafe { bin.get().as_ref().unwrap() }; - if is_cat && cat.is_some() { - left_cats.insert(cat.unwrap()); - right_cats.remove(&cat.unwrap()); - } - - cat = Some(b.cut_value as usize); - let left_gradient_train = cuml_gradient_train; let left_hessian_train = cuml_hessian_train; let left_counts_train = cuml_counts_train; @@ -1300,6 +1290,17 @@ fn best_feature_split_var_hess( if (max_gain.is_none() || split_gain > max_gain.unwrap()) && (generalization.is_some() || node.num == 0) { max_gain = Some(split_gain); + let mut left_cats: HashSet = HashSet::new(); + let mut right_cats: HashSet = all_cats.iter().copied().collect(); + + for c in all_cats.iter() { + if *c == b.cut_value as usize { + break; + } + left_cats.insert(*c); + right_cats.remove(c); + } + split_info.split_gain = split_gain; split_info.split_feature = feature; split_info.split_value = b.cut_value; @@ -1308,8 +1309,8 @@ fn best_feature_split_var_hess( split_info.right_node = right_node_info; split_info.missing_node = missing_info; split_info.generalization = generalization; - split_info.left_cats = left_cats.clone(); - split_info.right_cats = right_cats.clone(); + split_info.left_cats = left_cats; + split_info.right_cats = right_cats; } } }