From 3e05e15533d68d03bee18d0c9b7a634e916ff03f Mon Sep 17 00:00:00 2001 From: Max Halford Date: Mon, 6 Nov 2023 21:23:07 +0100 Subject: [PATCH] start benchmarks --- Cargo.toml | 6 +++- README.md | 2 +- benches/hst.rs | 59 ++++++++++++++++++++++++++++++++++ src/anomaly/half_space_tree.rs | 17 ++++++++++ 4 files changed, 82 insertions(+), 2 deletions(-) create mode 100644 benches/hst.rs diff --git a/Cargo.toml b/Cargo.toml index 8bf56c1..447bc54 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ rand = "0.8.5" time = "0.3.29" [dev-dependencies] -criterion = { version = "0.4", features = ["html_reports"] } +criterion = { version = "0.5", features = ["html_reports"] } [profile.dev] opt-level = 0 @@ -27,3 +27,7 @@ opt-level = 3 [[example]] name = "credit_card" path = "examples/anomaly_detection/credit_card.rs" + +[[bench]] +name = "hst" +harness = false diff --git a/README.md b/README.md index 3b239a8..e019ca0 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ -Reed is an online machine learning library written in Rust. It is meant to be used in high-throughput environments, as well as TinyML systems. +LightRiver is an online machine learning library written in Rust. It is meant to be used in high-throughput environments, as well as TinyML systems. This library is complementary to [River](https://github.com/online-ml/river/). The latter provides a wide array of online methods, but is not ideal when it comes to performance. The idea is to take the algorithms that work best in River, and implement them in a way that is more performant. As such, LightRiver is not meant to be a general purpose library. It is meant to be a fast online machine learning library that provides a few algorithms that are known to work well in online settings. This is a akin to the way [scikit-learn](https://scikit-learn.org/) and [LightGBM](https://lightgbm.readthedocs.io/en/stable/) are complementary to each other. diff --git a/benches/hst.rs b/benches/hst.rs new file mode 100644 index 0000000..a080ced --- /dev/null +++ b/benches/hst.rs @@ -0,0 +1,59 @@ +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; +use light_river::anomaly::half_space_tree::HalfSpaceTree; + +fn creation(c: &mut Criterion) { + let mut group = c.benchmark_group("creation"); + + let features: Vec = vec![ + String::from("V1"), + String::from("V2"), + String::from("V3"), + String::from("V4"), + String::from("V5"), + String::from("V6"), + String::from("V7"), + String::from("V8"), + String::from("V9"), + String::from("V10"), + String::from("V11"), + String::from("V12"), + String::from("V13"), + String::from("V14"), + String::from("V15"), + String::from("V16"), + String::from("V17"), + String::from("V18"), + String::from("V19"), + String::from("V20"), + String::from("V21"), + String::from("V22"), + String::from("V23"), + String::from("V24"), + String::from("V25"), + String::from("V26"), + String::from("V27"), + String::from("V28"), + String::from("V29"), + String::from("V30"), + ]; + + for height in [2, 6, 10, 14].iter() { + for n_trees in [3, 30, 300].iter() { + let input = (*height, *n_trees); + // Calculate the throughput based on the provided formula + let throughput = ((2u32.pow(*height) - 1) * *n_trees) as u64; + group.throughput(Throughput::Elements(throughput)); + group.bench_with_input( + format!("height={}-n_trees={}", height, n_trees), + &input, + |b, &input| { + b.iter(|| HalfSpaceTree::new(0, input.1, input.0, Some(features.clone()))); + }, + ); + } + } + group.finish(); +} + +criterion_group!(benches, creation); +criterion_main!(benches); diff --git a/src/anomaly/half_space_tree.rs b/src/anomaly/half_space_tree.rs index 21b60ac..550c9e6 100644 --- a/src/anomaly/half_space_tree.rs +++ b/src/anomaly/half_space_tree.rs @@ -200,3 +200,20 @@ impl HalfSpaceTree { self.update(observation, true, false) } } + +mod tests { + use super::*; + #[test] + fn test_left_child() { + let node = 42; + let child = left_child(node); + assert_eq!(child, 85); + } + + #[test] + fn test_right_child() { + let node = 42; + let child = right_child(node); + assert_eq!(child, 86); + } +}