Update function names from nel215 to River
MarcoDiFrancesco committed May 3, 2024
1 parent a9ca4bc commit ccc9b1d
Showing 5 changed files with 128 additions and 125 deletions.
4 changes: 3 additions & 1 deletion examples/classification/keystroke.rs
@@ -69,12 +69,14 @@ fn main() {
ClassifierTarget::String(y) => y,
_ => unimplemented!(),
};
+ let y = labels.clone().iter().position(|l| l == &y).unwrap();

let x_ord = Array1::<f32>::from_vec(features.iter().map(|k| x[k]).collect());
// DEBUG: remove it
// let x_ord = x_ord.slice(s![0..2]).to_owned();

println!("=M=1 partial_fit");
- mf.partial_fit(&x_ord, &y);
+ mf.partial_fit(&x_ord, y);

println!("=M=2 predict_proba");
let score = mf.predict_proba(&x_ord);
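For reference, a minimal sketch (with placeholder label values and a hypothetical helper) of the label-to-index conversion this commit adds before calling partial_fit with a usize target. Note that labels.iter() would avoid the labels.clone() in the added line.

```rust
// Hypothetical helper illustrating the added line:
// `labels.iter().position(|l| l == &y)` maps a string label to its class index.
fn label_index(labels: &[String], y: &str) -> usize {
    labels
        .iter()
        .position(|l| l.as_str() == y)
        .expect("label must appear in the label list")
}

fn main() {
    // Placeholder labels; the keystroke example builds this list from the dataset.
    let labels = vec!["s002".to_string(), "s003".to_string(), "s004".to_string()];
    let y = label_index(&labels, "s003");
    assert_eq!(y, 1);
    // The index is then passed by value: mf.partial_fit(&x_ord, y);
}
```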
6 changes: 4 additions & 2 deletions examples/classification/synthetic.rs
@@ -65,13 +65,15 @@ fn main() {
ClassifierTarget::String(y) => y,
_ => unimplemented!(),
};
+ let y = labels.clone().iter().position(|l| l == &y).unwrap();

let x_ord = Array1::<f32>::from_vec(features.iter().map(|k| x[k]).collect());

// Skip the first sample since the tree still has no node
if idx != 0 {
// let probs = mf.predict_proba(&x_ord);
// println!("=M=2 probs: {:?}", probs.to_vec());
- let score = mf.score(&x_ord, &y);
+ let score = mf.score(&x_ord, y);
// println!("=M=3 score: {:?}", score);
score_total += score;

@@ -84,7 +86,7 @@ fn main() {
// panic!("stop");
// }
println!("=M=1 partial_fit {x_ord}");
- mf.partial_fit(&x_ord, &y);
+ mf.partial_fit(&x_ord, y);
}

let elapsed_time = now.elapsed();
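The synthetic.rs loop follows a test-then-train (prequential) pattern: score each incoming sample before fitting on it, skipping the very first sample because the tree has no node yet. A self-contained sketch with a stub model (not the repository's MondrianForest) is shown below.

```rust
// Stub classifier standing in for MondrianForest; it only mimics the interface
// used in the example (score + partial_fit with a usize label).
struct StubModel;

impl StubModel {
    fn score(&mut self, _x: &[f32], _y: usize) -> f32 {
        // Pretend the prediction is always correct (real code compares argmax to _y).
        1.0
    }
    fn partial_fit(&mut self, _x: &[f32], _y: usize) {}
}

fn main() {
    // Toy stream of (features, class index) pairs.
    let stream: Vec<(Vec<f32>, usize)> = vec![
        (vec![0.1, 0.2], 0),
        (vec![0.9, 0.8], 1),
        (vec![0.2, 0.1], 0),
    ];
    let mut mf = StubModel;
    let mut score_total = 0.0_f32;
    for (idx, (x, y)) in stream.iter().enumerate() {
        // Test first (skipping the first sample), then train on the same sample.
        if idx != 0 {
            score_total += mf.score(x, *y);
        }
        mf.partial_fit(x, *y);
    }
    let accuracy = score_total / (stream.len() - 1) as f32;
    println!("running accuracy: {accuracy}");
}
```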
7 changes: 3 additions & 4 deletions src/classification/mondrian_forest.rs
@@ -42,7 +42,7 @@ impl<F: FType> MondrianForest<F> {
/// working only on one.
///
/// Function in River/LightRiver: "learn_one()"
- pub fn partial_fit(&mut self, x: &Array1<F>, y: &String) {
+ pub fn partial_fit(&mut self, x: &Array1<F>, y: usize) {
for tree in &mut self.trees {
tree.partial_fit(x, y);
}
@@ -78,16 +78,15 @@ impl<F: FType> MondrianForest<F> {
total_probs
}

- pub fn score(&mut self, x: &Array1<F>, y: &String) -> F {
+ pub fn score(&mut self, x: &Array1<F>, y: usize) -> F {
let probs = self.predict_proba(x);
- let y_idx = self.labels.iter().position(|l| l == y).unwrap();
let pred_idx = probs
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.map(|(idx, _)| idx)
.unwrap();
- if pred_idx == y_idx {
+ if pred_idx == y {
F::one()
} else {
F::zero()
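With the label passed as an index, score() reduces to an argmax over the class probabilities followed by an equality check. A minimal standalone sketch of that pattern (it panics on NaN probabilities, like the partial_cmp().unwrap() in the diff):

```rust
// Return 1.0 if the most probable class equals the integer label `y`, else 0.0.
fn score(probs: &[f32], y: usize) -> f32 {
    let pred_idx = probs
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
        .map(|(idx, _)| idx)
        .unwrap();
    if pred_idx == y {
        1.0
    } else {
        0.0
    }
}

fn main() {
    let probs = [0.1_f32, 0.7, 0.2];
    assert_eq!(score(&probs, 1), 1.0);
    assert_eq!(score(&probs, 2), 0.0);
}
```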
35 changes: 16 additions & 19 deletions src/classification/mondrian_node.rs
@@ -26,24 +26,24 @@ use std::{clone, cmp, mem, usize};
pub struct Node<F> {
// Change 'Rc' to 'Weak'
pub parent: Option<usize>, // Option<Rc<RefCell<Node<F>>>>,
- pub tau: F, // Time parameter: updated during 'node creation' or 'node update'
+ pub time: F, // Time: how much I increased the size of the box
pub is_leaf: bool,
pub min_list: Array1<F>, // Lists representing the minimum and maximum values of the data points contained in the current node
pub max_list: Array1<F>,
- pub delta: usize, // Dimension in which a split occurs
- pub xi: F, // Split point along the dimension specified by delta
- pub left: Option<usize>, // Option<Rc<RefCell<Node<F>>>>,
- pub right: Option<usize>, // Option<Rc<RefCell<Node<F>>>>,
+ pub feature: usize, // Feature in which a split occurs
+ pub threshold: F, // Threshold at which the split occurs
+ pub left: Option<usize>,
+ pub right: Option<usize>,
pub stats: Stats<F>,
}
impl<F: FType> Node<F> {
- pub fn update_leaf(&mut self, x: &Array1<F>, label_idx: usize) {
- self.stats.add(x, label_idx);
+ pub fn update_leaf(&mut self, x: &Array1<F>, y: usize) {
+ self.stats.add(x, y);
}
pub fn update_internal(&self, left_s: &Stats<F>, right_s: &Stats<F>) -> Stats<F> {
left_s.merge(right_s)
}
- pub fn get_parent_tau(&self) -> F {
+ pub fn get_parent_time(&self) -> F {
panic!("Implemented in 'mondrian_tree' instead of 'mondrian_node'")
}
/// Check if all the labels are the same in the node.
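The commit only renames the split fields (delta to feature, xi to threshold); the routing rule below (go left if x[feature] <= threshold) is the usual decision-tree convention and is shown here as an assumption, not as code from this repository.

```rust
// Hypothetical split node using the renamed fields.
struct SplitNode {
    feature: usize,       // feature (dimension) the node splits on
    threshold: f32,       // split point along that feature
    left: Option<usize>,  // index of the left child, if any
    right: Option<usize>, // index of the right child, if any
}

// Assumed routing rule for a sample `x`.
fn route(node: &SplitNode, x: &[f32]) -> Option<usize> {
    if x[node.feature] <= node.threshold {
        node.left
    } else {
        node.right
    }
}

fn main() {
    let node = SplitNode { feature: 0, threshold: 0.5, left: Some(1), right: Some(2) };
    assert_eq!(route(&node, &[0.3, 0.9]), Some(1));
    assert_eq!(route(&node, &[0.7, 0.9]), Some(2));
}
```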
@@ -100,19 +100,17 @@ impl<F: FType> Stats<F> {
let probs = self.predict_proba(x);
probs * w
}
- pub fn add(&mut self, x: &Array1<F>, label_idx: usize) {
+ pub fn add(&mut self, x: &Array1<F>, y: usize) {
// Same as: self.sums[label] += x;
- self.sums
- .row_mut(label_idx)
- .zip_mut_with(&x, |a, &b| *a += b);
+ self.sums.row_mut(y).zip_mut_with(&x, |a, &b| *a += b);

- // Same as: self.sq_sums[label_idx] += x*x;
+ // Same as: self.sq_sums[y] += x*x;
// e.g. x: [1.059 0.580] -> x*x: [1.122 0.337]
self.sq_sums
- .row_mut(label_idx)
+ .row_mut(y)
.zip_mut_with(&x, |a, &b| *a += b * b);

- self.counts[label_idx] += 1;
+ self.counts[y] += 1;
}
fn merge(&self, s: &Stats<F>) -> Stats<F> {
// NOTE: nel215 returns a new Stats object, we are only changing the node values here
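Stats::add updates three per-class accumulators: row y of the feature sums, row y of the squared sums, and the sample count for class y. A simplified, self-contained sketch (StatsSketch is a stand-in for the repository's Stats<F>):

```rust
use ndarray::{Array1, Array2};

// Simplified per-class accumulators; shapes are (n_classes, n_features) for the
// matrices and (n_classes,) for the counts.
struct StatsSketch {
    sums: Array2<f32>,
    sq_sums: Array2<f32>,
    counts: Array1<usize>,
}

impl StatsSketch {
    fn new(n_classes: usize, n_features: usize) -> Self {
        StatsSketch {
            sums: Array2::zeros((n_classes, n_features)),
            sq_sums: Array2::zeros((n_classes, n_features)),
            counts: Array1::zeros(n_classes),
        }
    }

    fn add(&mut self, x: &Array1<f32>, y: usize) {
        // Same as: sums[y] += x;
        self.sums.row_mut(y).zip_mut_with(x, |a, &b| *a += b);
        // Same as: sq_sums[y] += x * x;
        self.sq_sums.row_mut(y).zip_mut_with(x, |a, &b| *a += b * b);
        self.counts[y] += 1;
    }
}

fn main() {
    let mut stats = StatsSketch::new(2, 2);
    stats.add(&Array1::from_vec(vec![1.0, 2.0]), 1);
    stats.add(&Array1::from_vec(vec![3.0, 4.0]), 1);
    assert_eq!(stats.counts[1], 2);
    assert_eq!(stats.sums[[1, 0]], 4.0);     // 1.0 + 3.0
    assert_eq!(stats.sq_sums[[1, 1]], 20.0); // 2.0^2 + 4.0^2
}
```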
@@ -172,8 +170,7 @@ impl<F: FType> Stats<F> {
.enumerate()
{
// println!("predict_proba() - mid - index: {:?}, sum: {:?}, sq_sum: {:?}, count: {:?}", index, sum.to_vec(), sq_sum.to_vec(), count);
- // let epsilon = F::epsilon(); // Don't use this variable, write 'F::epsilon' where needed.
- let epsilon = F::from_f32(1e-9).unwrap();
+ let epsilon = F::epsilon(); // F::from_f32(1e-9).unwrap();
let count_f = F::from_usize(count).unwrap();
let avg = &sum / count_f;
let var = (&sq_sum / count_f) - (&avg * &avg) + epsilon;
Expand All @@ -182,9 +179,9 @@ impl<F: FType> Stats<F> {
let pi = F::from_f32(std::f32::consts::PI).unwrap() * F::from_f32(2.0).unwrap();
let z = pi.powi(x.len() as i32) * sigma.mapv(|s| s * s).sum().sqrt();
// Same as dot product
- let dot_delta = (&(x - &avg) * &(x - &avg)).sum();
+ let dot_feature = (&(x - &avg) * &(x - &avg)).sum();
let dot_sigma = (&sigma * &sigma).sum();
- let exponent = -F::from_f32(0.5).unwrap() * dot_delta / dot_sigma;
+ let exponent = -F::from_f32(0.5).unwrap() * dot_feature / dot_sigma;
// epsilon added since exponent.exp() could be zero if exponent is very small
let mut prob = (exponent.exp() + epsilon) / z;
if count <= 0 {
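For reference, a hedged reconstruction of the per-class score these predict_proba lines compute; the definition of sigma is collapsed in the diff and is assumed here to be the per-feature standard deviation sqrt(var). The commit itself only replaces the hard-coded 1e-9 epsilon with F::epsilon() and renames dot_delta to dot_feature.

```latex
% Assumed form, with d = number of features and c the class index:
p_c(x) = \frac{\exp\!\left(-\tfrac{1}{2}\,
           \frac{\sum_i (x_i - \mu_{c,i})^2}{\sum_i \sigma_{c,i}^2}\right) + \varepsilon}
         {(2\pi)^d \, \sqrt{\sum_i \sigma_{c,i}^2}},
\qquad
\mu_{c,i} = \frac{\text{sums}_{c,i}}{\text{counts}_c},
\qquad
\sigma_{c,i}^2 = \frac{\text{sq\_sums}_{c,i}}{\text{counts}_c} - \mu_{c,i}^2 + \varepsilon
```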
