Skip to content

Commit

Permalink
cat_fix var_hess
Browse files (browse the repository at this point in the history)
  • Loading branch information
deadsoul44 committed Nov 14, 2024
1 parent aa85cb2 commit 2c6884c
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions src/splitter.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -1092,12 +1092,12 @@ fn best_feature_split_var_hess(
) {
let split_info = unsafe { split_info_slice.get_mut(feat_idx) };
split_info.split_gain = -1.0;
split_info.left_cats = HashSet::new();
split_info.right_cats = HashSet::new();

let mut max_gain: Option<f32> = None;
let mut generalization: Option<f32>;
let mut left_cats: HashSet<usize> = HashSet::new();
let mut right_cats: HashSet<usize> = HashSet::new();
let mut is_cat = false;
let mut all_cats: Vec<usize> = Vec::new();

let evaluate_fn = eval_callables(false, create_missing_branch);

Expand All @@ -1108,12 +1108,11 @@ fn best_feature_split_var_hess(

if let Some(c_index) = cat_index {
if c_index.contains(&feature) {
is_cat = true;
sort_cat_bins_by_stat(&mut hist, false);
right_cats = HashSet::from_iter(
hist.iter()
.map(|b| unsafe { b.get().as_ref().unwrap().cut_value } as usize),
);
all_cats = hist
.iter()
.map(|b| unsafe { b.get().as_ref().unwrap().cut_value } as usize)
.collect();
}
}

Expand Down Expand Up @@ -1148,18 +1147,9 @@ fn best_feature_split_var_hess(
let mut cuml_hessian_valid = [f32::ZERO; 5];
let mut cuml_counts_valid = [0_usize; 5];

let mut cat: Option<usize> = None;

for bin in hist {
let b = unsafe { bin.get().as_ref().unwrap() };

if is_cat && cat.is_some() {
left_cats.insert(cat.unwrap());
right_cats.remove(&cat.unwrap());
}

cat = Some(b.cut_value as usize);

let left_gradient_train = cuml_gradient_train;
let left_hessian_train = cuml_hessian_train;
let left_counts_train = cuml_counts_train;
Expand Down Expand Up @@ -1300,6 +1290,17 @@ fn best_feature_split_var_hess(
if (max_gain.is_none() || split_gain > max_gain.unwrap()) && (generalization.is_some() || node.num == 0) {
max_gain = Some(split_gain);

let mut left_cats: HashSet<usize> = HashSet::new();
let mut right_cats: HashSet<usize> = all_cats.iter().copied().collect();

for c in all_cats.iter() {
if *c == b.cut_value as usize {
break;
}
left_cats.insert(*c);
right_cats.remove(c);
}

split_info.split_gain = split_gain;
split_info.split_feature = feature;
split_info.split_value = b.cut_value;
Expand All @@ -1308,8 +1309,8 @@ fn best_feature_split_var_hess(
split_info.right_node = right_node_info;
split_info.missing_node = missing_info;
split_info.generalization = generalization;
split_info.left_cats = left_cats.clone();
split_info.right_cats = right_cats.clone();
split_info.left_cats = left_cats;
split_info.right_cats = right_cats;
}
}
}
Expand Down

0 comments on commit 2c6884c

Please sign in to comment.