Skip to content

Commit

Permalink
Add update_leaf flag to create_leaf
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoDiFrancesco committed May 13, 2024
1 parent c4753f1 commit a08f922
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 12 deletions.
14 changes: 14 additions & 0 deletions src/classification/mondrian_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@ pub struct Node<F> {
pub right: Option<usize>,
pub stats: Stats<F>,
}
/// One-line, human-readable summary of a node.
///
/// Renders the node's `time` with three decimals, followed by the contents of
/// `min_list`, `max_list` and `stats.counts` in `Debug` list form.
impl<F: FType + fmt::Display> fmt::Display for Node<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Copy the arrays into plain vectors so they print as flat `[a, b, ...]` lists.
        let lower = self.min_list.to_vec();
        let upper = self.max_list.to_vec();
        let counts = self.stats.counts.to_vec();

        // `write!` already yields `fmt::Result`, so it is the function's value.
        write!(
            f,
            "Node<time={:.3}, min={:?}, max={:?}, counts={:?}>",
            self.time, lower, upper, counts,
        )
    }
}

impl<F: FType> Node<F> {
pub fn update_leaf(&mut self, x: &Array1<F>, y: usize) {
self.stats.add(x, y);
Expand Down
51 changes: 39 additions & 12 deletions src/classification/mondrian_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,16 @@ impl<F: FType + fmt::Display> MondrianTreeClassifier<F> {
// prefix, idx, node.left, node.right, node.parent, node.time, node.is_leaf, node.min_list.to_vec(), node.max_list.to_vec())?;
writeln!(
f,
"{}├─Node {}: left={:?}, right={:?}, parent={:?}, time={:.3}, min={:?}, max={:?}",
"{}├─Node {}: left={:?}, right={:?}, parent={:?}, time={:.3}, min={:?}, max={:?}, counts={:?}",
prefix,
idx,
node.left,
node.right,
node.parent,
node.time,
node.min_list.to_vec(),
node.max_list.to_vec()
node.max_list.to_vec(),
node.stats.counts.to_vec(),
)?;

self.recursive_repr(node.left, f, &(prefix.to_owned() + "│ "))?;
Expand All @@ -71,7 +72,14 @@ impl<F: FType> MondrianTreeClassifier<F> {
}
}

fn create_leaf(&mut self, x: &Array1<F>, y: usize, parent: Option<usize>, time: F) -> usize {
fn create_leaf(
&mut self,
x: &Array1<F>,
y: usize,
parent: Option<usize>,
time: F,
update_leaf: bool,
) -> usize {
let mut node = Node::<F> {
parent,
time, // F::from(1e9).unwrap(), // Very large value
Expand All @@ -84,8 +92,10 @@ impl<F: FType> MondrianTreeClassifier<F> {
right: None,
stats: Stats::new(self.n_labels, self.n_features),
};

node.update_leaf(x, y);
// TODO: check if 'update_leaf' is used
if update_leaf {
node.update_leaf(x, y);
}
self.nodes.push(node);
let node_idx = self.nodes.len() - 1;
node_idx
Expand Down Expand Up @@ -242,14 +252,31 @@ impl<F: FType> MondrianTreeClassifier<F> {

self.nodes.push(parent_node);
let parent_idx = self.nodes.len() - 1;
let sibling_idx = self.create_leaf(x, y, Some(parent_idx), split_time);
let sibling_idx = self.create_leaf(x, y, Some(parent_idx), split_time, true);

// println!(
// "sibling_idx: {}, sibling: {}",
// sibling_idx, self.nodes[sibling_idx]
// );

// From:
// ┌─Node 1: left=Some(0), right=Some(2), parent=None, time=0.000, min=[0.2, 0.2], max=[0.4, 0.3], counts=[1, 1, 0]
// │ ├─Node 0: left=None, right=None, parent=Some(1), time=1.228, min=[0.2, 0.2], max=[0.2, 0.2], counts=[1, 0, 0]
// │ ├─Node 2: left=None, right=None, parent=Some(1), time=1.228, min=[0.4, 0.3], max=[0.4, 0.3], counts=[0, 1, 0]
// To:
// ┌─Node 1: left=Some(0), right=Some(3), parent=None, time=0.000, min=[0.2, 0.2], max=[0.8, 0.8], counts=[1, 1, 0]
// │ ├─Node 0: left=None, right=None, parent=Some(1), time=1.228, min=[0.2, 0.2], max=[0.2, 0.2], counts=[1, 0, 0]
// │ ├─Node 3: left=Some(2), right=Some(4), parent=Some(1), time=1.228, min=[0.4, 0.3], max=[0.8, 0.8], counts=[0, 1, 1]
// │ │ ├─Node 2: left=None, right=None, parent=Some(3), time=1.738, min=[0.4, 0.3], max=[0.4, 0.3], counts=[0, 1, 0]
// │ │ ├─Node 4: left=None, right=None, parent=Some(3), time=1.738, min=[0.8, 0.8], max=[0.8, 0.8], counts=[0, 0, 1]

// Node 1. Grandpa: self.nodes[node_idx].parent
// Node 3. (new) Parent: parent_idx
// Node 2. Child: node_idx
// Node 4. (new) Sibling: sibling_idx

// Set the children appropriately
if x[feature] <= threshold {
// Grandpa: self.nodes[node_idx].parent
// (new) Parent: parent_idx
// Child: node_idx
// (new) Sibling: sibling_idx
self.nodes[parent_idx].left = Some(sibling_idx);
self.nodes[parent_idx].right = Some(node_idx);
} else {
Expand Down Expand Up @@ -313,10 +340,10 @@ impl<F: FType> MondrianTreeClassifier<F> {
/// Function in River/LightRiver: "learn_one()"
///
/// Updates the tree with a single sample `x` and its label `y`.
///
/// * If the tree has no root yet, a leaf is created from the sample (no parent,
///   time zero) and becomes the root.
/// * Otherwise the sample is passed to `go_downwards` starting at the current
///   root, and the index it returns is stored as the root.
pub fn partial_fit(&mut self, x: &Array1<F>, y: usize) {
self.root = match self.root {
// Empty tree: the first sample becomes the root leaf.
None => Some(self.create_leaf(x, y, None, F::zero())),
None => Some(self.create_leaf(x, y, None, F::zero(), true)),
// Existing tree: descend from the root; the returned index becomes the root.
Some(root_idx) => Some(self.go_downwards(root_idx, x, y)),
};
// println!("partial_fit() tree post {}", self);
// NOTE(review): this formats and prints the tree through its `Display` impl on
// every call — looks like debug output; confirm it is meant to stay enabled.
println!("partial_fit() tree post {}", self);
}

fn fit(&self) {
Expand Down

0 comments on commit a08f922

Please sign in to comment.