Skip to content

Commit

Permalink
Add debug statement for overwriting variance aware estimation
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoDiFrancesco committed May 29, 2024
1 parent a5bd895 commit 3544c28
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 84 deletions.
9 changes: 4 additions & 5 deletions examples/classification/synthetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ fn main() {
};
let y = labels.clone().iter().position(|l| l == &y).unwrap();

println!("=M=1 x:{}, idx: {}", x, idx);
// println!("=M=1 x:{}, idx: {}", x, idx);

// Skip first sample since tree has still no node
if idx != 0 {
Expand All @@ -87,12 +87,11 @@ fn main() {
);
}

// println!("=M=1 partial_fit {x}");
mf.partial_fit(&x, y);

// if idx == 166 {
// if idx == 527 {
// break;
// }

mf.partial_fit(&x, y);
}

let elapsed_time = now.elapsed();
Expand Down
6 changes: 2 additions & 4 deletions src/classification/mondrian_forest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@ impl<F: FType> MondrianForestClassifier<F> {
MondrianForestClassifier::<F> { trees, n_labels }
}

/// Note: In Nel215 codebase should work on multiple records, here it's
/// working only on one.
///
/// Function in River/LightRiver: "learn_one()"
/// Function in River is "learn_one()"
pub fn partial_fit(&mut self, x: &Array1<F>, y: usize) {
for tree in &mut self.trees {
tree.partial_fit(x, y);
Expand Down Expand Up @@ -54,6 +51,7 @@ impl<F: FType> MondrianForestClassifier<F> {
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.map(|(idx, _)| idx)
.unwrap();
// println!("probs: {}, pred_idx: {}, y (correct): {}, is_correct: {}", probs, pred_idx, y, pred_idx == y);
if pred_idx == y {
F::one()
} else {
Expand Down
23 changes: 17 additions & 6 deletions src/classification/mondrian_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,21 @@ impl<F: FType> Stats<F> {
probs * w
}
/// Folds one labelled sample into the running per-class statistics.
///
/// For class row `y` this adds `x` element-wise to `sums`, adds the
/// element-wise square of `x` to `sq_sums`, and bumps `counts[y]` by one.
///
/// NOTE(review): assumes `y` is a valid class index and that `x` has one
/// entry per column of `sums` / `sq_sums` — confirm against callers.
pub fn add(&mut self, x: &Array1<F>, y: usize) {
    // sums[y] += x
    let mut class_sums = self.sums.row_mut(y);
    class_sums.zip_mut_with(x, |acc, &value| *acc += value);

    // sq_sums[y] += x * x
    // e.g. x: [1.059 0.580] contributes x*x: [1.122 0.337]
    let mut class_sq_sums = self.sq_sums.row_mut(y);
    class_sq_sums.zip_mut_with(x, |acc, &value| *acc += value * value);

    // One more observation of class `y`.
    self.counts[y] += 1;
}
fn merge(&self, s: &Stats<F>) -> Stats<F> {
// NOTE: nel215 returns a new Stats object, we are only changing the node values here
Stats {
sums: self.sums.clone() + &s.sums,
sq_sums: self.sq_sums.clone() + &s.sq_sums,
Expand All @@ -124,13 +126,18 @@ impl<F: FType> Stats<F> {

// println!("predict_proba() - start {}", self);

for (index, ((sum, sq_sum), &count)) in self
// println!("var aware est - counts: {}", self.counts);

// Iterate over each label
for (idx, ((sum, sq_sum), &count)) in self
.sums
.outer_iter()
.zip(self.sq_sums.outer_iter())
.zip(self.counts.iter())
.enumerate()
{
// println!(" - idx: {idx}, count: {count}, sum: {sum}, sq_sum: {sq_sum}");

let epsilon = F::epsilon();
let count_f = F::from_usize(count).unwrap();
let avg = &sum / count_f;
Expand All @@ -145,10 +152,13 @@ impl<F: FType> Stats<F> {
// epsilon added since exponent.exp() could be zero if exponent is very small
let mut prob = (exponent.exp() + epsilon) / z;
if count <= 0 {
debug_assert!(prob.is_nan(), "Probabaility should be NaN. Found: {prob}.");
// prob is NaN
prob = F::zero();
}
probs[index] = prob;
probs[idx] = prob;

// DEBUG: stop using variance aware estimation
probs[idx] = count_f;
}

if probs.iter().all(|&x| x == F::zero()) {
Expand All @@ -162,6 +172,7 @@ impl<F: FType> Stats<F> {
for prob in probs.iter_mut() {
*prob /= probs_sum;
}
// println!(" - probs out: {}", probs);
probs
}
}
106 changes: 38 additions & 68 deletions src/classification/mondrian_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ impl<F: FType + fmt::Display> fmt::Display for MondrianTreeClassifier<F> {
}

impl<F: FType + fmt::Display> MondrianTreeClassifier<F> {
/// Helper method to recursively format node details.
fn recursive_repr(
&self,
node_idx: Option<usize>,
Expand All @@ -54,6 +53,7 @@ impl<F: FType + fmt::Display> MondrianTreeClassifier<F> {
writeln!(
f,
"{}{}Node {}: time={:.3}, min={:?}, max={:?}, thrs={:.2}, f={}, counts={}",
// "{}{}Node {}: time={:.3}, min={:?}, max={:?}, thrs={:.2}, f={}, counts={}, \nsums={}, \nsq_sums={}",
// "{}{}Node {}: left={:?}, right={:?}, parent={:?}, time={:.3}, min={:?}, max={:?}, thrs={:.2}, f={}, counts={}",
prefix,
node_prefix,
Expand All @@ -67,6 +67,8 @@ impl<F: FType + fmt::Display> MondrianTreeClassifier<F> {
node.threshold,
feature,
node.stats.counts,
// node.stats.sums,
// node.stats.sq_sums,
// node.is_leaf,
)?;

Expand Down Expand Up @@ -97,18 +99,11 @@ impl<F: FType> MondrianTreeClassifier<F> {
}
}

fn create_node(
&mut self,
x: &Array1<F>,
y: usize,
parent: Option<usize>,
time: F,
is_leaf: bool,
) -> usize {
fn create_leaf(&mut self, x: &Array1<F>, y: usize, parent: Option<usize>, time: F) -> usize {
let mut node = Node::<F> {
parent,
time, // F::from(1e9).unwrap(), // Very large value
is_leaf,
is_leaf: true,
range_min: x.clone(),
range_max: x.clone(),
feature: usize::MAX,
Expand All @@ -123,7 +118,7 @@ impl<F: FType> MondrianTreeClassifier<F> {
node_idx
}

fn create_node_empty(&mut self, parent: Option<usize>, time: F) -> usize {
fn create_empty_node(&mut self, parent: Option<usize>, time: F) -> usize {
let node = Node::<F> {
parent,
time, // F::from(1e9).unwrap(), // Very large value
Expand Down Expand Up @@ -391,10 +386,10 @@ impl<F: FType> MondrianTreeClassifier<F> {
extensions_sum: F,
) -> F {
if self.nodes[node_idx].is_dirac(y) {
println!(
"compute_split_time() - node: {node_idx} - extensions_sum: {:?} - same class",
extensions_sum
);
// println!(
// "compute_split_time() - node: {node_idx} - extensions_sum: {:?} - same class",
// extensions_sum
// );
return F::zero();
}

Expand All @@ -403,10 +398,10 @@ impl<F: FType> MondrianTreeClassifier<F> {

// From River: If the node is a leaf we must split it
if self.nodes[node_idx].is_leaf {
println!(
"compute_split_time() - node: {node_idx} - extensions_sum: {:?} - split is_leaf",
extensions_sum
);
// println!(
// "compute_split_time() - node: {node_idx} - extensions_sum: {:?} - split is_leaf",
// extensions_sum
// );
return split_time;
}

Expand All @@ -416,45 +411,25 @@ impl<F: FType> MondrianTreeClassifier<F> {
let child_time = self.nodes[child_idx].time;
// 2. We check if splitting time occurs before child creation time
if split_time < child_time {
println!(
"compute_split_time() - node: {node_idx} - extensions_sum: {:?} - split mid tree",
extensions_sum
);
// println!(
// "compute_split_time() - node: {node_idx} - extensions_sum: {:?} - split mid tree",
// extensions_sum
// );
return split_time;
}
println!("compute_split_time() - node: {node_idx} - extensions_sum: {:?} - not increased enough to split (mid node)", extensions_sum);
// println!("compute_split_time() - node: {node_idx} - extensions_sum: {:?} - not increased enough to split (mid node)", extensions_sum);
} else {
println!(
"compute_split_time() - node: {node_idx} - extensions_sum: {:?} - not outside box",
extensions_sum
);
// println!(
// "compute_split_time() - node: {node_idx} - extensions_sum: {:?} - not outside box",
// extensions_sum
// );
}

F::zero()
}

fn go_downwards(&mut self, node_idx: usize, x: &Array1<F>, y: usize) -> usize {
let time = self.nodes[node_idx].time;
// Set 0 if any value is Inf
// TODO: remove it if not useful
// let node_range_min = if self.nodes[node_idx]
// .range_min
// .iter()
// .any(|&x| !x.is_infinite())
// {
// self.nodes[node_idx].range_min.clone()
// } else {
// Array1::zeros(self.nodes[node_idx].range_min.len())
// };
// let node_range_max = if self.nodes[node_idx]
// .range_max
// .iter()
// .any(|&x| x.is_infinite())
// {
// self.nodes[node_idx].range_max.clone()
// } else {
// Array1::zeros(self.nodes[node_idx].range_max.len())
// };
let node_range_min = &self.nodes[node_idx].range_min;
let node_range_max = &self.nodes[node_idx].range_max;
let extensions = {
Expand Down Expand Up @@ -516,9 +491,10 @@ impl<F: FType> MondrianTreeClassifier<F> {
range_max.zip_mut_with(x, |a, &b| *a = F::max(*a, b));

if self.nodes[node_idx].is_leaf {
println!("go_downwards() - split_time > 0 (is leaf)");
let leaf_full = self.create_node(x, y, Some(node_idx), split_time, true);
let leaf_empty = self.create_node_empty(Some(node_idx), split_time);
// Add two leaves.
// println!("go_downwards() - split_time > 0 (is leaf)");
let leaf_full = self.create_leaf(x, y, Some(node_idx), split_time);
let leaf_empty = self.create_empty_node(Some(node_idx), split_time);
// if x[feature] <= threshold {
if is_right_extension {
self.nodes[node_idx].left = Some(leaf_empty);
Expand All @@ -535,7 +511,8 @@ impl<F: FType> MondrianTreeClassifier<F> {
self.update_downwards(node_idx);
return node_idx;
} else {
println!("go_downwards() - split_time > 0 (not leaf)");
// Add node along the path.
// println!("go_downwards() - split_time > 0 (not leaf)");
let parent_node = Node {
parent: self.nodes[node_idx].parent,
time: self.nodes[node_idx].time,
Expand All @@ -551,14 +528,11 @@ impl<F: FType> MondrianTreeClassifier<F> {
self.nodes.push(parent_node);
let parent_idx = self.nodes.len() - 1;

// === Changed "create_node_empty" to "create_node"
// TODO: check is_leaf, sometimes is true, sometimes false??
let sibling_idx = self.create_node(x, y, Some(parent_idx), split_time, true);

println!(
"grandpa: {:?}, parent: {:?}, child: {:?}, sibling: {:?}",
self.nodes[node_idx].parent, parent_idx, node_idx, sibling_idx
);
let sibling_idx = self.create_leaf(x, y, Some(parent_idx), split_time);
// println!(
// "grandpa: {:?}, parent: {:?}, child: {:?}, sibling: {:?}",
// self.nodes[node_idx].parent, parent_idx, node_idx, sibling_idx
// );
// Node 1. Grandpa: self.nodes[node_idx].parent
// └─Node 3. (new) Parent: parent_idx
// ├─Node 2. Child: node_idx
Expand All @@ -570,20 +544,18 @@ impl<F: FType> MondrianTreeClassifier<F> {
self.nodes[parent_idx].left = Some(sibling_idx);
self.nodes[parent_idx].right = Some(node_idx);
}
// 'stats' copied from River
self.nodes[parent_idx].stats = self.nodes[node_idx].stats.clone();
self.nodes[node_idx].parent = Some(parent_idx);
self.nodes[node_idx].time = split_time;

// This if is required to not break 'child_inside_parent' test. Even though
// This 'if' is required to not break 'child_inside_parent' test. Even though
// it's probably correct I'll comment it until we get a 1:1 with River.
// if self.nodes[node_idx].is_leaf {
self.nodes[node_idx].range_min = Array1::from_elem(self.n_features, F::infinity());
self.nodes[node_idx].range_max = Array1::from_elem(self.n_features, -F::infinity());
self.nodes[node_idx].stats = Stats::new(self.n_labels, self.n_features);
// }
// self.update_downwards(parent_idx);
// From River: added "update_leaf" after "update_downwards"
self.nodes[parent_idx].update_leaf(x, y);
return parent_idx;
}
Expand All @@ -610,8 +582,6 @@ impl<F: FType> MondrianTreeClassifier<F> {
let node = &mut self.nodes[node_idx];
node.right = node_right_new;
};
// "update_downwards" was not in Nel215 implementation, added because of Python implementation
// Later changed from "update_downwards" to "update_leaf"
// self.update_downwards(node_idx);
self.nodes[node_idx].update_leaf(x, y);
}
Expand Down Expand Up @@ -641,10 +611,10 @@ impl<F: FType> MondrianTreeClassifier<F> {
/// Function in River/LightRiver: "learn_one()"
pub fn partial_fit(&mut self, x: &Array1<F>, y: usize) {
self.root = match self.root {
None => Some(self.create_node(x, y, None, F::zero(), true)),
None => Some(self.create_leaf(x, y, None, F::zero())),
Some(root_idx) => Some(self.go_downwards(root_idx, x, y)),
};
println!("partial_fit() tree post {}===========", self);
// println!("partial_fit() tree post {}===========", self);
}

fn fit(&self) {
Expand All @@ -671,7 +641,7 @@ impl<F: FType> MondrianTreeClassifier<F> {
let eta = dist_min.sum() + dist_max.sum();
F::one() - (-d * eta).exp()
};
debug_assert!(!p.is_nan(), "Found probability of splitting NaN. This is probably because range_max and range_min are [inf, inf]");
debug_assert!(!p.is_nan(), "Found probability of splitting NaN. This is probably because range_max and range_min are [inf, inf].");

// Generate a result for the current node using its statistics.
let res = node.stats.create_result(x, p_not_separated_yet * p);
Expand Down
2 changes: 1 addition & 1 deletion src/datasets/synthetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub struct Synthetic;
impl Synthetic {
pub fn load_data() -> IterCsv<f32, File> {
let url = "https://marcodifrancesco.com/assets/img/LightRiver/syntetic_dataset.csv";
let file_name = "syntetic_dataset_v2.1.csv";
let file_name = "syntetic_dataset_v2.csv";
if !Path::new(file_name).exists() {
utils::download_csv_file(url, file_name);
}
Expand Down

0 comments on commit 3544c28

Please sign in to comment.