Skip to content

Commit

Permalink
Add create leafs when reaching a leaf
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoDiFrancesco committed May 24, 2024
1 parent 4d9ef48 commit 0217db2
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 132 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ cargo run --release --example credit_card
### 📊 Classification

```sh
RUSTFLAGS=-Awarnings cargo run --release --example synthetic
RUSTFLAGS=-Awarnings cargo run --example synthetic
```

### 🛒 Recsys
Expand Down
12 changes: 6 additions & 6 deletions examples/classification/synthetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ fn main() {
if idx != 0 {
let score = mf.score(&x, y);
score_total += score;
// println!(
// "Accuracy: {} / {} = {}",
// score_total,
// dataset_size - 1,
// score_total / idx.to_f32().unwrap()
// );
println!(
"Accuracy: {} / {} = {}",
score_total,
dataset_size - 1,
score_total / idx.to_f32().unwrap()
);
}

// println!("=M=1 partial_fit {x}");
Expand Down
2 changes: 1 addition & 1 deletion src/classification/mondrian_forest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl<F: FType> MondrianForestClassifier<F> {
let mut tot_probs = Array1::<F>::zeros(self.n_labels);
for tree in &self.trees {
let probs = tree.predict_proba(x);
assert!(
debug_assert!(
!probs.iter().any(|&x| x.is_nan()),
"Probability should not be NaN. Found: {:?}.",
probs.to_vec()
Expand Down
32 changes: 15 additions & 17 deletions src/classification/mondrian_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::classification::alias::FType;

use ndarray::{Array1, Array2};

use num::{Float, FromPrimitive};
use num::{Float, FromPrimitive, ToPrimitive};

use std::fmt;

Expand All @@ -16,8 +16,8 @@ pub struct Node<F> {
pub parent: Option<usize>,
pub time: F, // Time: how much I increased the size of the box
pub is_leaf: bool,
pub min_list: Array1<F>, // Lists representing the minimum and maximum values of the data points contained in the current node
pub max_list: Array1<F>,
pub range_min: Array1<F>, // Lists representing the minimum and maximum values of the data points contained in the current node
pub range_max: Array1<F>,
pub feature: usize, // Feature in which a split occurs
pub threshold: F, // Threshold in which the split occures
pub left: Option<usize>,
Expand All @@ -30,8 +30,8 @@ impl<F: FType + fmt::Display> fmt::Display for Node<F> {
f,
"Node<time={:.3}, min={:?}, max={:?}, counts={:?}>",
self.time,
self.min_list.to_vec(),
self.max_list.to_vec(),
self.range_min.to_vec(),
self.range_max.to_vec(),
self.stats.counts.to_vec(),
)?;
Ok(())
Expand Down Expand Up @@ -121,7 +121,6 @@ impl<F: FType> Stats<F> {
/// Return probabilities of sample 'x' belonging to each class.
pub fn predict_proba(&self, x: &Array1<F>) -> Array1<F> {
let mut probs = Array1::zeros(self.n_labels);
let mut sum_prob = F::zero();

// println!("predict_proba() - start {}", self);

Expand All @@ -146,24 +145,23 @@ impl<F: FType> Stats<F> {
// epsilon added since exponent.exp() could be zero if exponent is very small
let mut prob = (exponent.exp() + epsilon) / z;
if count <= 0 {
assert!(prob.is_nan(), "Probabaility should be NaN. Found: {prob}.");
debug_assert!(prob.is_nan(), "Probabaility should be NaN. Found: {prob}.");
prob = F::zero();
}
sum_prob += prob;
probs[index] = prob;
}

// Check at least one probability is non-zero. Otherwise we have division by zero.
assert!(
!probs.iter().all(|&x| x == F::zero()),
"At least one probability should not be zero. Found: {:?}.",
probs.to_vec()
);

if probs.iter().all(|&x| x == F::zero()) {
// [0, 0, 0] -> [0.33, 0.33, 0.33]
probs = probs
.iter()
.map(|_| F::one() / F::from_f32(probs.len().to_f32().unwrap()).unwrap())
.collect();
}
let probs_sum = probs.sum();
for prob in probs.iter_mut() {
*prob /= sum_prob;
*prob /= probs_sum;
}
// println!("predict_proba() post - probs: {:?}", probs.to_vec());
probs
}
}
Loading

0 comments on commit 0217db2

Please sign in to comment.