From a35d649aa2f3c733cc7e4aeabc7ee71fd4d1179b Mon Sep 17 00:00:00 2001 From: Max Halford Date: Thu, 5 Oct 2023 10:59:01 +0200 Subject: [PATCH] don't clone features --- src/main.rs | 45 ++++++++++++++++++++------------------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/src/main.rs b/src/main.rs index 5d2c981..ddd9845 100644 --- a/src/main.rs +++ b/src/main.rs @@ -32,7 +32,7 @@ struct HST { } impl HST { - fn new(height: u32, features: Vec, rng: &mut ThreadRng) -> Self { + fn new(height: u32, features: &Vec, rng: &mut ThreadRng) -> Self { // TODO: padding // TODO: handle non [0, 1] features // TODO: weighted sampling of features @@ -123,7 +123,7 @@ fn main() { let mut trees: Vec = Vec::with_capacity(n_trees as usize); for _ in 0..n_trees { - trees.push(HST::new(height, features.clone(), &mut rng)); + trees.push(HST::new(height, &features, &mut rng)); } // LOOP @@ -200,14 +200,8 @@ fn main() { let threshold = tree.threshold[node as usize]; // Get the value of the current feature - let value = match line.get_x().get(feature) { - Some(Data::Scalar(value)) => Some(value), - Some(Data::String(_)) => panic!("String feature not supported yet"), - None => None, - }; - - node = match value { - Some(value) => { + node = match line.get_x().get(feature) { + Some(Data::Scalar(value)) => { // Update the mass of the current node if *value < threshold { left_child(node) @@ -215,15 +209,16 @@ fn main() { right_child(node) } } - // If the feature is missing, go down both branches and select the node with the - // the biggest l_mass + Some(Data::String(_)) => panic!("String feature not supported yet"), None => { + // If the feature is missing, go down both branches and select the node with the + // the biggest l_mass if tree.l_mass[left_child(node) as usize] - < tree.l_mass[right_child(node) as usize] + > tree.l_mass[right_child(node) as usize] { - right_child(node) - } else { left_child(node) + } else { + right_child(node) } } }; @@ -233,16 +228,16 @@ fn main() { // Pivot if the window is full counter += 1; if counter == window_size { - for tree in trees.iter_mut() { - for node in 0..tree.l_mass.len() { - tree.r_mass[node] = tree.l_mass[node]; - tree.l_mass[node] = 0.0; - } - } - // trees.iter_mut().for_each(|tree| { - // tree.r_mass.copy_from_slice(&tree.l_mass); - // tree.l_mass.fill(0.0); - // }); + // for tree in trees.iter_mut() { + // for node in 0..tree.l_mass.len() { + // tree.r_mass[node] = tree.l_mass[node]; + // tree.l_mass[node] = 0.0; + // } + // } + trees.iter_mut().for_each(|tree| { + tree.r_mass.copy_from_slice(&tree.l_mass); + tree.l_mass.fill(0.0); + }); counter = 0; }