Skip to content

Commit

Permalink
don't clone features
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxHalford committed Oct 5, 2023
1 parent 3e388cc commit a35d649
Showing 1 changed file with 20 additions and 25 deletions.
45 changes: 20 additions & 25 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ struct HST {
}

impl HST {
fn new(height: u32, features: Vec<String>, rng: &mut ThreadRng) -> Self {
fn new(height: u32, features: &Vec<String>, rng: &mut ThreadRng) -> Self {
// TODO: padding
// TODO: handle non [0, 1] features
// TODO: weighted sampling of features
Expand Down Expand Up @@ -123,7 +123,7 @@ fn main() {

let mut trees: Vec<HST> = Vec::with_capacity(n_trees as usize);
for _ in 0..n_trees {
trees.push(HST::new(height, features.clone(), &mut rng));
trees.push(HST::new(height, &features, &mut rng));
}

// LOOP
Expand Down Expand Up @@ -200,30 +200,25 @@ fn main() {
let threshold = tree.threshold[node as usize];

// Get the value of the current feature
let value = match line.get_x().get(feature) {
Some(Data::Scalar(value)) => Some(value),
Some(Data::String(_)) => panic!("String feature not supported yet"),
None => None,
};

node = match value {
Some(value) => {
node = match line.get_x().get(feature) {
Some(Data::Scalar(value)) => {
// Update the mass of the current node
if *value < threshold {
left_child(node)
} else {
right_child(node)
}
}
// If the feature is missing, go down both branches and select the node with the
// the biggest l_mass
Some(Data::String(_)) => panic!("String feature not supported yet"),
None => {
// If the feature is missing, go down both branches and select the node with the
// the biggest l_mass
if tree.l_mass[left_child(node) as usize]
< tree.l_mass[right_child(node) as usize]
> tree.l_mass[right_child(node) as usize]
{
right_child(node)
} else {
left_child(node)
} else {
right_child(node)
}
}
};
Expand All @@ -233,16 +228,16 @@ fn main() {
// Pivot if the window is full
counter += 1;
if counter == window_size {
for tree in trees.iter_mut() {
for node in 0..tree.l_mass.len() {
tree.r_mass[node] = tree.l_mass[node];
tree.l_mass[node] = 0.0;
}
}
// trees.iter_mut().for_each(|tree| {
// tree.r_mass.copy_from_slice(&tree.l_mass);
// tree.l_mass.fill(0.0);
// });
// for tree in trees.iter_mut() {
// for node in 0..tree.l_mass.len() {
// tree.r_mass[node] = tree.l_mass[node];
// tree.l_mass[node] = 0.0;
// }
// }
trees.iter_mut().for_each(|tree| {
tree.r_mass.copy_from_slice(&tree.l_mass);
tree.l_mass.fill(0.0);
});
counter = 0;
}

Expand Down

0 comments on commit a35d649

Please sign in to comment.