Skip to content

Commit

Permalink
Adapt _go_downwards from River
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoDiFrancesco committed May 3, 2024
1 parent 717161f commit a9ca4bc
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 40 deletions.
1 change: 1 addition & 0 deletions examples/classification/synthetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ fn main() {
let score = mf.score(&x_ord, &y);
// println!("=M=3 score: {:?}", score);
score_total += score;

println!(
"{score_total} / {idx} = {}",
score_total / idx.to_f32().unwrap()
Expand Down
19 changes: 10 additions & 9 deletions src/classification/mondrian_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub struct Node<F> {
pub is_leaf: bool,
pub min_list: Array1<F>, // Lists representing the minimum and maximum values of the data points contained in the current node
pub max_list: Array1<F>,
pub delta: usize, // Dimension in which a split occurs (?)
pub delta: usize, // Dimension in which a split occurs
pub xi: F, // Split point along the dimension specified by delta
pub left: Option<usize>, // Option<Rc<RefCell<Node<F>>>>,
pub right: Option<usize>, // Option<Rc<RefCell<Node<F>>>>,
Expand All @@ -41,19 +41,20 @@ impl<F: FType> Node<F> {
self.stats.add(x, label_idx);
}
/// Combine the statistics of two child nodes into a single `Stats` value.
///
/// Returns the result of `left_s.merge(right_s)`. The method does not read
/// or modify any field of `self`; the caller decides where to store the
/// returned value.
pub fn update_internal(&self, left_s: &Stats<F>, right_s: &Stats<F>) -> Stats<F> {
// Earlier draft that accepted optional child stats, kept for reference:
// match (left_s, right_s) {
// (Some(left), Some(right)) => left.merge(right),
// (None, Some(right)) => unimplemented!("uncomment the following"), // right.clone(),
// (Some(left), None) => unimplemented!("uncomment the following"), // left.clone(),
// (None, None) => unimplemented!(
// "Both left and right stats are None. Should I return simply 'self.stats'?"
// ),
// }
left_s.merge(right_s)
}
/// Placeholder that never returns a value.
///
/// # Panics
///
/// Always panics. As the panic message states, the parent-tau lookup is
/// implemented in `mondrian_tree` rather than in `mondrian_node`.
pub fn get_parent_tau(&self) -> F {
panic!("Implemented in 'mondrian_tree' instead of 'mondrian_node'")
}
/// Report whether every sample counted in this node carries label `y`.
///
/// The node is "dirac" when the total of `stats.counts` equals the count
/// stored at index `y`, i.e. no other label has a non-zero count.
///
/// Examples:
/// - `y = 2`, `stats.counts = [0, 1, 10]` -> `false`
/// - `y = 2`, `stats.counts = [0, 0, 10]` -> `true`
/// - `y = 1`, `stats.counts = [0, 0, 10]` -> `false`
///
/// Adapted from the River implementation.
///
/// # Panics
///
/// Panics if `y` is not a valid index into `stats.counts`.
pub fn is_dirac(&self, y: usize) -> bool {
    let counts = &self.stats.counts;
    // Total first, then the per-label lookup, as in the original expression.
    let total = counts.sum();
    total == counts[y]
}
}

/// Stats associated with one node
Expand Down
106 changes: 76 additions & 30 deletions src/classification/mondrian_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::classification::mondrian_node::{Node, Stats};
use crate::common::{ClassifierOutput, ClassifierTarget, Observation};
use crate::stream::data_stream::Data;
use core::iter::zip;
use ndarray::array;
use ndarray::{array, Array};
use ndarray::{Array1, Array2};
use ndarray::{ArrayBase, Dim, ScalarOperand, ViewRepr};
use num::pow::Pow;
Expand Down Expand Up @@ -87,13 +87,19 @@ impl<F: FType> MondrianTree<F> {
}
}

fn create_leaf(&mut self, x: &Array1<F>, label: &String, parent: Option<usize>) -> usize {
fn create_leaf(
&mut self,
x: &Array1<F>,
label_idx: usize,
parent: Option<usize>,
tau: F,
) -> usize {
let num_labels = self.labels.len();
let feature_dim = self.features.len();

let mut node = Node::<F> {
parent,
tau: F::from(1e9).unwrap(), // Very large value
tau, // F::from(1e9).unwrap(), // Very large value
is_leaf: true,
min_list: x.clone(),
max_list: x.clone(),
Expand All @@ -104,7 +110,6 @@ impl<F: FType> MondrianTree<F> {
stats: Stats::new(num_labels, feature_dim),
};

let label_idx = self.labels.clone().iter().position(|l| l == label).unwrap();
node.update_leaf(x, label_idx);
self.nodes.push(node);
let node_idx = self.nodes.len() - 1;
Expand Down Expand Up @@ -147,11 +152,47 @@ impl<F: FType> MondrianTree<F> {
);
}

// fn range_extension(&self, )
/// Decide whether the node at `node_idx` must be split for the incoming
/// sample and, if so, at which time.
///
/// Returns `tau + exp_sample` when a split is required, and `F::zero()`
/// otherwise (zero is the "no split" signal). A split is required when:
/// - the node is not pure with respect to `label_idx` (see `is_dirac`), and
/// - `extensions_sum` is strictly positive, and
/// - the node is a leaf, or the candidate time is earlier than the time
///   (`tau`) stored on its left child.
///
/// Each outcome is logged to stdout with `println!`.
///
/// Adapted from River's `_go_downwards`.
///
/// # Panics
///
/// Panics if `node_idx` is out of bounds for `self.nodes`, or if a
/// non-leaf node has no left child.
fn compute_split_time(
    &self,
    tau: F,
    exp_sample: F,
    node_idx: usize,
    label_idx: usize,
    extensions_sum: F,
) -> F {
    let node = &self.nodes[node_idx];

    // Pure node: every counted sample already has this label, never split.
    if node.is_dirac(label_idx) {
        println!("extend_mondrian_block()/_go_downwards() - node: {node_idx} - extensions_sum: {:?} - all same class", extensions_sum);
        return F::zero();
    }

    // Only a strictly positive extension can trigger a split.
    let outside_box = extensions_sum > F::zero();
    if !outside_box {
        println!("extend_mondrian_block()/_go_downwards() - node: {node_idx} - extensions_sum: {:?} - not outside box", extensions_sum);
        return F::zero();
    }

    let candidate_time = tau + exp_sample;

    // From River: a leaf must always be split.
    if node.is_leaf {
        println!("extend_mondrian_block()/_go_downwards() - node: {node_idx} - extensions_sum: {:?} - split is_leaf", extensions_sum);
        return candidate_time;
    }

    // Internal node: compare against the time stored on the left child
    // (per the original note, left and right children share the same time).
    let left_idx = node.left.unwrap();
    let child_time = self.nodes[left_idx].tau;
    if candidate_time < child_time {
        println!("extend_mondrian_block()/_go_downwards() - node: {node_idx} - extensions_sum: {:?} - split mid tree", extensions_sum);
        return candidate_time;
    }

    println!("extend_mondrian_block()/_go_downwards() - node: {node_idx} - extensions_sum: {:?} - not increased enough to split (mid node)", extensions_sum);
    F::zero()
}

fn extend_mondrian_block(&mut self, node_idx: usize, x: &Array1<F>, label: &String) -> usize {
// Collect necessary values for computations
let parent_tau = self.get_parent_tau(node_idx);
fn extend_mondrian_block(&mut self, node_idx: usize, x: &Array1<F>, label_idx: usize) -> usize {
// tau is 'node.time' in River
let tau = self.nodes[node_idx].tau;
// TODO: 'node_min_list' and 'node_max_list' should be accessible without cloning
Expand All @@ -160,8 +201,9 @@ impl<F: FType> MondrianTree<F> {

let e_min = (&node_min_list - x).mapv(|v| F::max(v, F::zero()));
let e_max = (x - &node_max_list).mapv(|v| F::max(v, F::zero()));
// Extensions sum: size of the box [x_size, y_size]
// Extensions sum: size of the box
let e_sum = &e_min + &e_max;

// TODO: epsilon is used in nel215 code, but not River. Check if it's useful.
// let lambda = e_sum.sum() + F::epsilon();
// In nel215 lambda is 'rate'
Expand All @@ -172,24 +214,22 @@ impl<F: FType> MondrianTree<F> {
// DEBUG: shadowing with Exp expected value
let exp_sample = F::one() / lambda;

let split_time_rust = tau - (parent_tau + exp_sample);
let split_time_river = tau + exp_sample;
println!("extend_mondrian_block()/_go_downwards() - node: {node_idx} - extensions_sum: {:?}, split_time_rust: {:?}, split_time_river: {:?}", e_sum.sum(), split_time_rust, split_time_river);

if parent_tau + exp_sample < tau {
// We split the current node: because the current node is a
// leaf, or because we add a new node along the path

let split_time = self.compute_split_time(tau, exp_sample, node_idx, label_idx, e_sum.sum());
// println!("extend_mondrian_block - post compute_split_time() - split_time: {:?}", split_time);
if split_time > F::zero() {
// We split the current node: if leaf we add children, otherwise we add a new node along the path
let cumsum = e_sum
.iter()
.scan(F::zero(), |acc, &x| {
*acc = *acc + x;
Some(*acc)
})
.collect::<Array1<F>>();
println!("e_sum: {:?}, cumsum: {:?}", e_sum.to_vec(), cumsum.to_vec());

let e_sample = F::from_f32(self.rng.gen::<f32>()).unwrap() * e_sum.sum();
// DEBUG: shadowing with expected value
// let e_sample = F::from_f32(0.5).unwrap() * e_sum.sum();
let e_sample = F::from_f32(0.5).unwrap() * e_sum.sum();
let delta = cumsum.iter().position(|&val| val > e_sample).unwrap_or(0);

let (lower_bound, upper_bound) = if x[delta] > node_min_list[delta] {
Expand All @@ -205,7 +245,7 @@ impl<F: FType> MondrianTree<F> {
};
let xi = F::from_f32(self.rng.gen_range(lower_bound..upper_bound)).unwrap();
// DEBUG: setting expected value
// let xi = F::from_f32((lower_bound + upper_bound) / 2.0).unwrap();
let xi = F::from_f32((lower_bound + upper_bound) / 2.0).unwrap();

let mut min_list = node_min_list;
let mut max_list = node_max_list;
Expand All @@ -215,7 +255,7 @@ impl<F: FType> MondrianTree<F> {
// Create and push new parent node
let parent_node = Node {
parent: self.nodes[node_idx].parent,
tau: parent_tau + exp_sample,
tau: self.nodes[node_idx].tau,
is_leaf: false,
min_list,
max_list,
Expand All @@ -228,7 +268,7 @@ impl<F: FType> MondrianTree<F> {

self.nodes.push(parent_node);
let parent_idx = self.nodes.len() - 1;
let sibling_idx = self.create_leaf(x, label, Some(parent_idx));
let sibling_idx = self.create_leaf(x, label_idx, Some(parent_idx), split_time);

// Set the children appropriately
if x[delta] <= xi {
Expand All @@ -244,6 +284,7 @@ impl<F: FType> MondrianTree<F> {
}

self.nodes[node_idx].parent = Some(parent_idx);
self.nodes[node_idx].tau = split_time;

self.update_internal(parent_idx);

Expand All @@ -252,29 +293,34 @@ impl<F: FType> MondrianTree<F> {
// No split, we just update the node and go to the next one

let node = &mut self.nodes[node_idx];
// println!("pre - node: {:?}, node range: ({:?}-{:?}), x: {:?}", node_idx, node.min_list.to_vec(), node.max_list.to_vec(), x.to_vec());
node.min_list.zip_mut_with(x, |a, b| *a = F::min(*a, *b));
node.max_list.zip_mut_with(x, |a, b| *a = F::max(*a, *b));
if !node.is_leaf {
// println!("post- node: {:?}, node range: ({:?}-{:?}), x: {:?}", node_idx, node.min_list.to_vec(), node.max_list.to_vec(), x.to_vec());

if node.is_leaf {
// println!("else - updating leaf");
node.update_leaf(x, label_idx);
} else {
// println!("else - updating non-leaf");
if x[node.delta] <= node.xi {
let node_left = node.left.unwrap();
let node_left_new = Some(self.extend_mondrian_block(node_left, x, label));
let node_left_new = Some(self.extend_mondrian_block(node_left, x, label_idx));
let node = &mut self.nodes[node_idx];
node.left = node_left_new;
} else {
let node_right = node.right.unwrap();
let node_right_new = Some(self.extend_mondrian_block(node_right, x, label));
let node_right_new = Some(self.extend_mondrian_block(node_right, x, label_idx));
let node = &mut self.nodes[node_idx];
node.right = node_right_new;
};
self.update_internal(node_idx);
} else {
let label_idx = self.labels.iter().position(|l| l == label).unwrap();
node.update_leaf(x, label_idx);
}
return node_idx;
}
}

/// Update 'node stats' by merging 'right child stats + left child stats'.
fn update_internal(&mut self, node_idx: usize) {
// In nel215 code update_internal is not called for the children, check if it's needed
let node = &self.nodes[node_idx];
Expand All @@ -291,9 +337,10 @@ impl<F: FType> MondrianTree<F> {
///
/// Function in River/LightRiver: "learn_one()"
pub fn partial_fit(&mut self, x: &Array1<F>, y: &String) {
let label_idx = self.labels.clone().iter().position(|l| l == y).unwrap();
self.root = match self.root {
None => Some(self.create_leaf(x, y, None)),
Some(root_idx) => Some(self.extend_mondrian_block(root_idx, x, y)),
None => Some(self.create_leaf(x, label_idx, None, F::zero())),
Some(root_idx) => Some(self.extend_mondrian_block(root_idx, x, label_idx)),
};
println!("partial_fit() tree post {}", self);
}
Expand All @@ -306,7 +353,6 @@ impl<F: FType> MondrianTree<F> {
///
/// Recursive function to predict probabilities.
fn predict(&self, x: &Array1<F>, node_idx: usize, p_not_separated_yet: F) -> Array1<F> {
// println!("predict() - tree {}", self);
let node = &self.nodes[node_idx];

// Step 1: Calculate the time delta from the parent node.
Expand Down
3 changes: 2 additions & 1 deletion src/datasets/synthetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use std::{fs::File, path::Path};
pub struct Synthetic;
impl Synthetic {
pub fn load_data() -> Result<IterCsv<f32, File>, Box<dyn std::error::Error>> {
let file_name = "syntetic_dataset_paper.csv";
// let file_name = "syntetic_dataset_paper.csv";
let file_name = "syntetic_dataset_int.csv";
let file = File::open(file_name)?;
let y_cols = Some(Target::Name("label".to_string()));
match IterCsv::<f32, File>::new(file, y_cols) {
Expand Down

0 comments on commit a9ca4bc

Please sign in to comment.