Skip to content

Commit

Permalink
Add unit test for predict_proba
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoDiFrancesco committed Apr 24, 2024
1 parent 107354a commit 4385fe8
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/classification/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub mod alias;
pub mod mondrian_forest;
mod mondrian_node;
pub mod mondrian_node;
pub mod mondrian_tree;
59 changes: 39 additions & 20 deletions src/classification/mondrian_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ impl<F: FType> Node<F> {
/// In nel215 code it is "Classifier"
#[derive(Clone)]
pub struct Stats<F> {
sums: Array2<F>,
sq_sums: Array2<F>,
counts: Array1<usize>,
pub sums: Array2<F>,
pub sq_sums: Array2<F>,
pub counts: Array1<usize>,
num_labels: usize,
}
impl<F: FType + fmt::Display> fmt::Display for Stats<F> {
Expand Down Expand Up @@ -93,16 +93,11 @@ impl<F: FType> Stats<F> {
num_labels,
}
}
pub fn create_result(&self, x: &Array1<F>, w: F) -> ClassifierOutput<F> {
let probabilities = self.predict_proba(x);
unimplemented!("Fix first predict_proba()");
let mut results = HashMap::new();
for (index, &prob) in probabilities.iter().enumerate() {
results.insert(ClassifierTarget::from(index.to_string()), prob * w);
}
ClassifierOutput::Probabilities(results)
pub fn create_result(&self, x: &Array1<F>, w: F) -> Array1<F> {
let probs = self.predict_proba(x);
probs * w
}
fn add(&mut self, x: &Array1<F>, label_idx: usize) {
pub fn add(&mut self, x: &Array1<F>, label_idx: usize) {
// Same as: self.sums[label] += x;
self.sums
.row_mut(label_idx)
Expand All @@ -126,7 +121,38 @@ impl<F: FType> Stats<F> {
// *self_count += other.counts[i];
// }
}
fn predict_proba(&self, x: &Array1<F>) -> Array1<F> {
/// Return probabilities of sample 'x' belonging to each class.
///
/// e.g. probs: [0.1, 0.2, 0.7]
///
/// TODO: Remove this example, I was testing if unit tests make sense, but as
/// shown below this does not show the error. The function is just too complex.
///
/// # Example
/// ```
/// use light_river::classification::alias::FType;
/// use light_river::classification::mondrian_node::Stats;
/// use ndarray::{Array1, Array2};
///
/// let mut stats = Stats::new(3, 2); // 3 classes and 2 features
/// stats.sums = Array2::from_shape_vec((3,2), vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0])
/// .expect("Failed to create Array2");
/// stats.sq_sums = Array2::from_shape_vec((3,2), vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0])
/// .expect("Failed to create Array2");;
/// stats.counts = Array1::from_vec(vec![4, 5]);
/// stats.add(&Array1::from_vec(vec![1.0, 2.0]), 0);
/// stats.add(&Array1::from_vec(vec![2.0, 3.0]), 1);
/// stats.add(&Array1::from_vec(vec![2.0, 4.0]), 1);
///
/// let x = Array1::from_vec(vec![1.5, 3.0]);
/// let probs = stats.predict_proba(&x);
/// let expected = vec![0.998075, 0.001924008, 0.0];
/// assert!(
/// (probs - Array1::from_vec(expected)).mapv(|a: f32| a.abs()).iter().all(|&x| x < 1e-4),
/// "Probabilities do not match expected values"
/// );
/// ```
pub fn predict_proba(&self, x: &Array1<F>) -> Array1<F> {
let mut probs = Array1::zeros(self.num_labels);
let mut sum_prob = F::zero();
println!("{self}");
Expand All @@ -138,13 +164,6 @@ impl<F: FType> Stats<F> {
.zip(self.counts.iter())
.enumerate()
{
// Shadow with bogous values count, sum, sq_sum, x
// let xx: Array1<F> = Array1::from_vec(vec![F::from_f32(1.5).unwrap(), F::from_f32(3.0).unwrap()]);
// let x = &xx;
// let count = 2;
// let sum = Array1::from_vec(vec![F::from_f32(1.0).unwrap(), F::from_f32(2.0).unwrap()]);
// let sq_sum: ArrayBase<ndarray::OwnedRepr<F>, Dim<[usize; 1]>> = Array1::from_vec(vec![F::from_f32(1.0).unwrap(), F::from_f32(2.0).unwrap()]);

let count_f = F::from_usize(count).unwrap();
let avg = &sum / count_f;
let var = (&sq_sum / count_f) - (&avg * &avg) + F::epsilon();
Expand Down
8 changes: 7 additions & 1 deletion src/classification/mondrian_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,17 @@ impl<F: FType> MondrianTree<F> {

// Step 4: Generate a result for the current node using its statistics.
let result = node.stats.create_result(x, p_not_separated_yet * p);

// Shadowing with bogous values
let result = Array1::from_vec(vec![
F::from_f32(0.7).unwrap(),
F::from_f32(0.2).unwrap(),
F::from_f32(0.1).unwrap(),
]);
println!(
"predict() - result: {:?}, p_not_separated_yet: {:?}, p: {:?}",
result, p_not_separated_yet, p
);

// if node.is_leaf() {
// let w = p_not_separated_yet * (F::one() - p);
// return result.merge(node.stats.create_result(x, w));
Expand Down

0 comments on commit 4385fe8

Please sign in to comment.