Skip to content

Commit

Permalink
Add assert to check for NaN probability
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoDiFrancesco committed May 24, 2024
1 parent 0217db2 commit 1e5a874
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 8 deletions.
4 changes: 4 additions & 0 deletions examples/classification/synthetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ fn main() {

// println!("=M=1 partial_fit {x}");
mf.partial_fit(&x, y);

// if idx == 163 {
// break;
// }
}

let elapsed_time = now.elapsed();
Expand Down
21 changes: 14 additions & 7 deletions src/classification/mondrian_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,14 +345,20 @@ impl<F: FType> MondrianTreeClassifier<F> {
self.nodes[parent_idx].stats = self.nodes[node_idx].stats.clone();
self.nodes[node_idx].parent = Some(parent_idx);

self.nodes[node_idx].time = split_time;
// From Python: reset child
self.nodes[node_idx].range_min = Array1::from_elem(self.n_features, F::infinity());
self.nodes[node_idx].range_max = Array1::from_elem(self.n_features, -F::infinity());
self.nodes[node_idx].stats = Stats::new(self.n_labels, self.n_features);

// self.nodes[node_idx].time = split_time;

// I'm 0% sure if this "if" is required.
if self.nodes[node_idx].is_leaf {
// {
// From River: reset child
self.nodes[node_idx].range_min =
Array1::from_elem(self.n_features, F::infinity());
self.nodes[node_idx].range_max =
Array1::from_elem(self.n_features, -F::infinity());
self.nodes[node_idx].stats = Stats::new(self.n_labels, self.n_features);
}
// self.update_downwards(parent_idx);
// From Python: added "update_leaf" after "update_downwards"
// From River: added "update_leaf" after "update_downwards"
self.nodes[parent_idx].update_leaf(x, y);
return parent_idx;
}
Expand Down Expand Up @@ -440,6 +446,7 @@ impl<F: FType> MondrianTreeClassifier<F> {
let eta = dist_min.sum() + dist_max.sum();
F::one() - (-d * eta).exp()
};
debug_assert!(!p.is_nan(), "Found probability of splitting NaN. This is probably because range_max and range_min are [inf, inf]");

// Generate a result for the current node using its statistics.
let res = node.stats.create_result(x, p_not_separated_yet * p);
Expand Down
2 changes: 1 addition & 1 deletion src/datasets/synthetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub struct Synthetic;
impl Synthetic {
pub fn load_data() -> IterCsv<f32, File> {
let url = "https://marcodifrancesco.com/assets/img/LightRiver/syntetic_dataset.csv";
let file_name = "syntetic_dataset_v2.csv";
let file_name = "syntetic_dataset_v2.1.csv";
if !Path::new(file_name).exists() {
utils::download_csv_file(url, file_name);
}
Expand Down

0 comments on commit 1e5a874

Please sign in to comment.