Skip to content

Commit

Permalink
use single arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxHalford committed Oct 5, 2023
1 parent 75e27b9 commit 98f527c
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 31 deletions.
2 changes: 2 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ ROC AUC appears roughly similar between the Python and Rust implementations. Not
Using `with_capacity` on each `Vec` in `HST`, as well as the list of HSTs, we gain 1 second. We are now at **~5 seconds**.

We haven't found a good profiler yet, so for now we comment out sections of code and measure the elapsed time.

Storing all attributes in a single array, instead of one array per tree, brings us down to **~3 seconds**.
60 changes: 29 additions & 31 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ struct HST {
}

impl HST {
fn new(height: u32, features: &Vec<String>, rng: &mut ThreadRng) -> Self {
fn new(n_trees: u32, height: u32, features: &Vec<String>, rng: &mut ThreadRng) -> Self {
// TODO: padding
// TODO: handle non [0, 1] features
// TODO: weighted sampling of features

// #nodes = 2 ^ height - 1
let n_nodes: usize = usize::try_from(u32::pow(2, height) - 1).unwrap();
let n_nodes: usize = usize::try_from(n_trees * u32::pow(2, height) - 1).unwrap();
// #branches = 2 ^ (height - 1) - 1
let n_branches = usize::try_from(u32::pow(2, height - 1) - 1).unwrap();
let n_branches = usize::try_from(n_trees * u32::pow(2, height - 1) - 1).unwrap();

// Helper function to create and populate a Vec with a given capacity
fn init_vec<T>(capacity: usize, default_value: T) -> Vec<T>
Expand Down Expand Up @@ -122,10 +122,9 @@ fn main() {
let start = SystemTime::now();
// INITIALIZATION

let mut trees: Vec<HST> = Vec::with_capacity(n_trees as usize);
for _ in 0..n_trees {
trees.push(HST::new(height, &features, &mut rng));
}
let mut hst = HST::new(n_trees, height, &features, &mut rng);
let n_nodes = u32::pow(2, height) - 1;
let n_branches = u32::pow(2, height - 1) - 1;

Check warning on line 127 in src/main.rs

View workflow job for this annotation

GitHub Actions / clippy

unused variable: `n_branches`

warning: unused variable: `n_branches` --> src/main.rs:127:9 | 127 | let n_branches = u32::pow(2, height - 1) - 1; | ^^^^^^^^^^ help: if this is intentional, prefix it with an underscore: `_n_branches` | = note: `#[warn(unused_variables)]` on by default

// LOOP

Expand All @@ -136,20 +135,20 @@ fn main() {

// SCORE
let mut score: f32 = 0.0;
for tree in trees.iter() {
let depth: u32 = 0;

for tree in 0..n_trees {
let offset: u32 = tree * n_nodes;
let mut node: u32 = 0;
loop {
score += tree.r_mass[node as usize] * u32::pow(2, depth) as f32;
for depth in 0..height {
score += hst.r_mass[(offset + node) as usize] * u32::pow(2, depth) as f32;
// Stop if the node is a leaf or if the mass of the node is too small
if (node >= tree.feature.len() as u32) || (tree.r_mass[node as usize] < size_limit)
{
if node >= n_nodes || (hst.r_mass[(offset + node) as usize] < size_limit) {
break;
}
// Get the feature and threshold of the current node so that we can determine
// whether to go left or right
let feature = &tree.feature[node as usize];
let threshold = tree.threshold[node as usize];
let feature = &hst.feature[(offset + node) as usize];
let threshold = hst.threshold[(offset + node) as usize];

// Get the value of the current feature
let value = match line.get_x().get(feature) {
Expand All @@ -170,8 +169,8 @@ fn main() {
// If the feature is missing, go down both branches and select the node with the
// biggest l_mass
None => {
if tree.l_mass[left_child(node) as usize]
< tree.l_mass[right_child(node) as usize]
if hst.l_mass[(offset + left_child(node)) as usize]
< hst.l_mass[(offset + right_child(node)) as usize]
{
right_child(node)
} else {
Expand All @@ -185,20 +184,21 @@ fn main() {
let _ = csv_writer.serialize(score);

// UPDATE
for tree in trees.iter_mut() {
for tree in 0..n_trees {
// Walk over the tree
let offset: u32 = tree * n_nodes;
let mut node: u32 = 0;
loop {
for _ in 0..height {
// Update the l_mass
tree.l_mass[node as usize] += 1.0;
hst.l_mass[0 as usize] += 1.0;

Check warning on line 193 in src/main.rs

View workflow job for this annotation

GitHub Actions / clippy

casting integer literal to `usize` is unnecessary

warning: casting integer literal to `usize` is unnecessary --> src/main.rs:193:28 | 193 | hst.l_mass[0 as usize] += 1.0; | ^^^^^^^^^^ help: try: `0_usize` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#unnecessary_cast = note: `#[warn(clippy::unnecessary_cast)]` on by default
// Stop if the node is a leaf
if node >= tree.feature.len() as u32 {
break;
}
// if node >= n_branches {
// break;
// }
// Get the feature and threshold of the current node so that we can determine
// whether to go left or right
let feature = &tree.feature[node as usize];
let threshold = tree.threshold[node as usize];
let feature = &hst.feature[0 as usize];

Check warning on line 200 in src/main.rs

View workflow job for this annotation

GitHub Actions / clippy

casting integer literal to `usize` is unnecessary

warning: casting integer literal to `usize` is unnecessary --> src/main.rs:200:44 | 200 | let feature = &hst.feature[0 as usize]; | ^^^^^^^^^^ help: try: `0_usize` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#unnecessary_cast
let threshold = hst.threshold[0 as usize];

Check warning on line 201 in src/main.rs

View workflow job for this annotation

GitHub Actions / clippy

casting integer literal to `usize` is unnecessary

warning: casting integer literal to `usize` is unnecessary --> src/main.rs:201:47 | 201 | let threshold = hst.threshold[0 as usize]; | ^^^^^^^^^^ help: try: `0_usize` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#unnecessary_cast

// Get the value of the current feature
node = match line.get_x().get(feature) {
Expand All @@ -214,8 +214,8 @@ fn main() {
None => {
// If the feature is missing, go down both branches and select the node with the
// biggest l_mass
if tree.l_mass[left_child(node) as usize]
> tree.l_mass[right_child(node) as usize]
if hst.l_mass[(offset + left_child(node)) as usize]
> hst.l_mass[(offset + right_child(node)) as usize]
{
left_child(node)
} else {
Expand All @@ -229,10 +229,8 @@ fn main() {
// Pivot if the window is full
counter += 1;
if counter == window_size {
trees.iter_mut().for_each(|tree| {
mem::swap(&mut tree.r_mass, &mut tree.l_mass);
tree.l_mass.fill(0.0);
});
mem::swap(&mut hst.r_mass, &mut hst.l_mass);
hst.l_mass.fill(0.0);
counter = 0;
}

Expand Down

0 comments on commit 98f527c

Please sign in to comment.