diff --git a/src/classification/mod.rs b/src/classification/mod.rs index 418250b..50fa8b0 100644 --- a/src/classification/mod.rs +++ b/src/classification/mod.rs @@ -1,4 +1,4 @@ pub mod alias; pub mod mondrian_forest; -mod mondrian_node; +pub mod mondrian_node; pub mod mondrian_tree; diff --git a/src/classification/mondrian_node.rs b/src/classification/mondrian_node.rs index fa80250..e1ec737 100644 --- a/src/classification/mondrian_node.rs +++ b/src/classification/mondrian_node.rs @@ -59,9 +59,9 @@ impl Node { /// In nel215 code it is "Classifier" #[derive(Clone)] pub struct Stats { - sums: Array2, - sq_sums: Array2, - counts: Array1, + pub sums: Array2, + pub sq_sums: Array2, + pub counts: Array1, num_labels: usize, } impl fmt::Display for Stats { @@ -93,16 +93,11 @@ impl Stats { num_labels, } } - pub fn create_result(&self, x: &Array1, w: F) -> ClassifierOutput { - let probabilities = self.predict_proba(x); - unimplemented!("Fix first predict_proba()"); - let mut results = HashMap::new(); - for (index, &prob) in probabilities.iter().enumerate() { - results.insert(ClassifierTarget::from(index.to_string()), prob * w); - } - ClassifierOutput::Probabilities(results) + pub fn create_result(&self, x: &Array1, w: F) -> Array1 { + let probs = self.predict_proba(x); + probs * w } - fn add(&mut self, x: &Array1, label_idx: usize) { + pub fn add(&mut self, x: &Array1, label_idx: usize) { // Same as: self.sums[label] += x; self.sums .row_mut(label_idx) @@ -126,7 +121,38 @@ impl Stats { // *self_count += other.counts[i]; // } } - fn predict_proba(&self, x: &Array1) -> Array1 { + /// Return probabilities of sample 'x' belonging to each class. + /// + /// e.g. probs: [0.1, 0.2, 0.7] + /// + /// TODO: Remove this example, I was testing if unit tests make sense, but as + /// shown below this does not show the error. The function is just too complex. + /// + /// # Example + /// ``` + /// use light_river::classification::alias::FType; + /// use light_river::classification::mondrian_node::Stats; + /// use ndarray::{Array1, Array2}; + /// + /// let mut stats = Stats::new(3, 2); // 3 classes and 2 features + /// stats.sums = Array2::from_shape_vec((3,2), vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0]) + /// .expect("Failed to create Array2"); + /// stats.sq_sums = Array2::from_shape_vec((3,2), vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0]) + /// .expect("Failed to create Array2");; + /// stats.counts = Array1::from_vec(vec![4, 5]); + /// stats.add(&Array1::from_vec(vec![1.0, 2.0]), 0); + /// stats.add(&Array1::from_vec(vec![2.0, 3.0]), 1); + /// stats.add(&Array1::from_vec(vec![2.0, 4.0]), 1); + /// + /// let x = Array1::from_vec(vec![1.5, 3.0]); + /// let probs = stats.predict_proba(&x); + /// let expected = vec![0.998075, 0.001924008, 0.0]; + /// assert!( + /// (probs - Array1::from_vec(expected)).mapv(|a: f32| a.abs()).iter().all(|&x| x < 1e-4), + /// "Probabilities do not match expected values" + /// ); + /// ``` + pub fn predict_proba(&self, x: &Array1) -> Array1 { let mut probs = Array1::zeros(self.num_labels); let mut sum_prob = F::zero(); println!("{self}"); @@ -138,13 +164,6 @@ impl Stats { .zip(self.counts.iter()) .enumerate() { - // Shadow with bogous values count, sum, sq_sum, x - // let xx: Array1 = Array1::from_vec(vec![F::from_f32(1.5).unwrap(), F::from_f32(3.0).unwrap()]); - // let x = &xx; - // let count = 2; - // let sum = Array1::from_vec(vec![F::from_f32(1.0).unwrap(), F::from_f32(2.0).unwrap()]); - // let sq_sum: ArrayBase, Dim<[usize; 1]>> = Array1::from_vec(vec![F::from_f32(1.0).unwrap(), F::from_f32(2.0).unwrap()]); - let count_f = F::from_usize(count).unwrap(); let avg = &sum / count_f; let var = (&sq_sum / count_f) - (&avg * &avg) + F::epsilon(); diff --git a/src/classification/mondrian_tree.rs b/src/classification/mondrian_tree.rs index cc38620..e392eee 100644 --- a/src/classification/mondrian_tree.rs +++ b/src/classification/mondrian_tree.rs @@ -147,11 +147,17 @@ impl MondrianTree { // Step 4: Generate a result for the current node using its statistics. let result = node.stats.create_result(x, p_not_separated_yet * p); - + // Shadowing with bogous values + let result = Array1::from_vec(vec![ + F::from_f32(0.7).unwrap(), + F::from_f32(0.2).unwrap(), + F::from_f32(0.1).unwrap(), + ]); println!( "predict() - result: {:?}, p_not_separated_yet: {:?}, p: {:?}", result, p_not_separated_yet, p ); + // if node.is_leaf() { // let w = p_not_separated_yet * (F::one() - p); // return result.merge(node.stats.create_result(x, w));