diff --git a/examples/classification/synthetic.rs b/examples/classification/synthetic.rs index 675a590..49a6d36 100644 --- a/examples/classification/synthetic.rs +++ b/examples/classification/synthetic.rs @@ -89,6 +89,10 @@ fn main() { // println!("=M=1 partial_fit {x}"); mf.partial_fit(&x, y); + + // if idx == 163 { + // break; + // } } let elapsed_time = now.elapsed(); diff --git a/src/classification/mondrian_tree.rs b/src/classification/mondrian_tree.rs index e20bb33..c8dd7eb 100644 --- a/src/classification/mondrian_tree.rs +++ b/src/classification/mondrian_tree.rs @@ -345,14 +345,20 @@ impl MondrianTreeClassifier { self.nodes[parent_idx].stats = self.nodes[node_idx].stats.clone(); self.nodes[node_idx].parent = Some(parent_idx); - self.nodes[node_idx].time = split_time; - // From Python: reset child - self.nodes[node_idx].range_min = Array1::from_elem(self.n_features, F::infinity()); - self.nodes[node_idx].range_max = Array1::from_elem(self.n_features, -F::infinity()); - self.nodes[node_idx].stats = Stats::new(self.n_labels, self.n_features); - + // self.nodes[node_idx].time = split_time; + + // I'm 0% sure if this "if" is required. + if self.nodes[node_idx].is_leaf { + // { + // From River: reset child + self.nodes[node_idx].range_min = + Array1::from_elem(self.n_features, F::infinity()); + self.nodes[node_idx].range_max = + Array1::from_elem(self.n_features, -F::infinity()); + self.nodes[node_idx].stats = Stats::new(self.n_labels, self.n_features); + } // self.update_downwards(parent_idx); - // From Python: added "update_leaf" after "update_downwards" + // From River: added "update_leaf" after "update_downwards" self.nodes[parent_idx].update_leaf(x, y); return parent_idx; } @@ -440,6 +446,7 @@ impl MondrianTreeClassifier { let eta = dist_min.sum() + dist_max.sum(); F::one() - (-d * eta).exp() }; + debug_assert!(!p.is_nan(), "Found probability of splitting NaN. This is probably because range_max and range_min are [inf, inf]"); // Generate a result for the current node using its statistics. let res = node.stats.create_result(x, p_not_separated_yet * p); diff --git a/src/datasets/synthetic.rs b/src/datasets/synthetic.rs index 9a4be41..5418dd9 100644 --- a/src/datasets/synthetic.rs +++ b/src/datasets/synthetic.rs @@ -12,7 +12,7 @@ pub struct Synthetic; impl Synthetic { pub fn load_data() -> IterCsv { let url = "https://marcodifrancesco.com/assets/img/LightRiver/syntetic_dataset.csv"; - let file_name = "syntetic_dataset_v2.csv"; + let file_name = "syntetic_dataset_v2.1.csv"; if !Path::new(file_name).exists() { utils::download_csv_file(url, file_name); }