Skip to content

Commit

Permalink
Update readme with classification run instructions
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoDiFrancesco committed May 6, 2024
1 parent 85030ad commit c4753f1
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 20 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ target/
/.vscode/
# Local configuration
.cargo/config.toml
/.venv*/
/.venv*/
generate_data_synthetic.py
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ cargo run --release --example credit_card

### 📊 Classification

🏗️ We plan to implement Aggregated Mondrian Forests.
```sh
RUSTFLAGS=-Awarnings cargo run --release --example synthetic
```

### 🛒 Recsys

Expand Down
13 changes: 10 additions & 3 deletions examples/classification/synthetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ fn get_labels(transactions: IterCsv<f32, File>) -> Vec<String> {
}

fn main() {
let now = Instant::now();
let n_trees: usize = 1;

let transactions_f = Synthetic::load_data();
Expand All @@ -47,6 +46,8 @@ fn main() {
MondrianForestClassifier::new(n_trees, features.len(), labels.len());
let mut score_total = 0.0;

let now = Instant::now();

let transactions = Synthetic::load_data();
for (idx, transaction) in transactions.enumerate() {
let data = transaction.unwrap();
Expand All @@ -71,13 +72,19 @@ fn main() {
// println!("=M=3 score: {:?}", score);
score_total += score;

// println!(
// "{score_total} / {idx} = {}",
// score_total / idx.to_f32().unwrap()
// );
}
if idx == 100_000 - 1 {
println!(
"{score_total} / {idx} = {}",
"Accuracy: {score_total} / {idx} = {}",
score_total / idx.to_f32().unwrap()
);
}

println!("=M=1 partial_fit {x_ord}");
// println!("=M=1 partial_fit {x_ord}");
mf.partial_fit(&x_ord, y);
}

Expand Down
12 changes: 0 additions & 12 deletions generate_data.py

This file was deleted.

6 changes: 3 additions & 3 deletions src/classification/mondrian_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ impl<F: FType> MondrianTreeClassifier<F> {
let exp_dist = Exp::new(lambda.to_f32().unwrap()).unwrap();
let exp_sample = F::from_f32(exp_dist.sample(&mut self.rng)).unwrap();
// DEBUG: shadowing with Exp expected value
let exp_sample = F::one() / lambda;
// let exp_sample = F::one() / lambda;
exp_sample
};
let split_time = self.compute_split_time(time, exp_sample, node_idx, y, extensions.sum());
Expand All @@ -202,7 +202,7 @@ impl<F: FType> MondrianTreeClassifier<F> {
.collect::<Array1<F>>();
let e_sample = F::from_f32(self.rng.gen::<f32>()).unwrap() * extensions.sum();
// DEBUG: shadowing with expected value
let e_sample = F::from_f32(0.5).unwrap() * extensions.sum();
// let e_sample = F::from_f32(0.5).unwrap() * extensions.sum();
cumsum.iter().position(|&val| val > e_sample).unwrap()
};

Expand All @@ -219,7 +219,7 @@ impl<F: FType> MondrianTreeClassifier<F> {
};
let threshold = F::from_f32(self.rng.gen_range(lower_bound..upper_bound)).unwrap();
// DEBUG: split in the middle
let threshold = F::from_f32((lower_bound + upper_bound) / 2.0).unwrap();
// let threshold = F::from_f32((lower_bound + upper_bound) / 2.0).unwrap();

let mut min_list = node_min_list.clone();
let mut max_list = node_max_list.clone();
Expand Down

0 comments on commit c4753f1

Please sign in to comment.