diff --git a/.gitignore b/.gitignore
index 4f3398f..a3c97aa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
-# Mini river
Cargo.lock
-
-target/
\ No newline at end of file
+.DS_Store
+target/
+*.csv
+*.zip
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..5e0108a
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,31 @@
+## Benchmarking
+
+```sh
+cargo bench --bench credit_card
+```
+
+## Changelog
+
+### 2023-10-04
+
+We test with:
+
+- 50 trees
+- Tree height of 6
+- Window size of 1000
+
+The Python baseline runs in **~60 seconds** using Python 3.11 on MacOS. It uses the classic left/right class-based implementation.
+
+We coded a first array based implementation in Rust. It runs in **~6 seconds**. Each tree is a struct. Each struct contains one array for each node attribute. We wonder if we can do better by storing all attributes in a matrix.
+
+ROC AUC appears roughly similar between the Python and Rust implementations. Note that we didn't activate min-max scaling in both cases.
+
+### 2023-10-05
+
+- Using `with_capacity` on each `Vec` in `HST`, as well as the list of HSTs, we gain 1 second. We are now at **~5 seconds**.
+- We can't find a nice profiler. So for now we comment code and measure time.
+- Storing all attributes in a single array, instead of one array per tree, makes us reach **~3 seconds**.
+- We removed the CSV logic from the benchmark, which brings us under **~2.5 second**.
+- Fixing some algorithmic issues actually brings us to **~5 seconds** :(
+- We tried using rayon to parallelize over trees, but it didn't bring any improvement whatsoever. Maybe we used it wrong, but we believe its because our loop is too cheap to be worth the overhead of spawning threads -- or whatever it is rayon does.
+- There is an opportunity to do the scoring and update logic in one fell swoop. This is because of the nature of online anomaly detection. This would bring us to **~2.5 seconds**. We are not sure if this is a good design choice though, so we may revisit this later.
diff --git a/Cargo.toml b/Cargo.toml
index f339741..cc99797 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -6,7 +6,6 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
-
csv = "1.2.0"
num = "0.4.0"
tempfile = "3.4.0"
@@ -14,3 +13,22 @@ maplit = "1.0.2"
reqwest = { version = "0.11.4", features = ["blocking"] }
zip = "0.6.4"
rand = "0.8.5"
+time = "0.3.29"
+half = "2.3.1"
+
+[dev-dependencies]
+criterion = { version = "0.5", features = ["html_reports"] }
+
+[profile.dev]
+opt-level = 0
+
+[profile.release]
+opt-level = 3
+
+[[example]]
+name = "credit_card"
+path = "examples/anomaly_detection/credit_card.rs"
+
+[[bench]]
+name = "hst"
+harness = false
diff --git a/README.md b/README.md
index 6b6bcf8..e019ca0 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,58 @@
-# Mini river
\ No newline at end of file
+
🦀 LightRiver • fast and simple online machine learning
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+[![Discord](https://dcbadge.vercel.app/api/server/qNmrKEZMAn)](https://discord.gg/qNmrKEZMAn)
+
+
+
+
+
+LightRiver is an online machine learning library written in Rust. It is meant to be used in high-throughput environments, as well as TinyML systems.
+
+This library is complementary to [River](https://github.com/online-ml/river/). The latter provides a wide array of online methods, but is not ideal when it comes to performance. The idea is to take the algorithms that work best in River, and implement them in a way that is more performant. As such, LightRiver is not meant to be a general purpose library. It is meant to be a fast online machine learning library that provides a few algorithms that are known to work well in online settings. This is a akin to the way [scikit-learn](https://scikit-learn.org/) and [LightGBM](https://lightgbm.readthedocs.io/en/stable/) are complementary to each other.
+
+## 🧑💻 Usage
+
+### 🚨 Anomaly detection
+
+```sh
+cargo run --release --example credit_card
+```
+
+### 📈 Regression
+
+🏗️ We plan to implement Aggregated Mondrian Forests.
+
+### 📊 Classification
+
+🏗️ We plan to implement Aggregated Mondrian Forests.
+
+### 🛒 Recsys
+
+🏗️ [Vowpal Wabbit](https://vowpalwabbit.org/) is very good at recsys via contextual bandits. We don't plan to compete with it. Eventually we want to research a tree-based contextual bandit.
+
+## 🚀 Performance
+
+TODO: add a `benches` directory
+
+## 📝 License
+
+LightRiver is free and open-source software licensed under the [3-clause BSD license](LICENSE).
diff --git a/benches/hst.rs b/benches/hst.rs
new file mode 100644
index 0000000..a080ced
--- /dev/null
+++ b/benches/hst.rs
@@ -0,0 +1,59 @@
+use criterion::{criterion_group, criterion_main, Criterion, Throughput};
+use light_river::anomaly::half_space_tree::HalfSpaceTree;
+
+fn creation(c: &mut Criterion) {
+ let mut group = c.benchmark_group("creation");
+
+ let features: Vec = vec![
+ String::from("V1"),
+ String::from("V2"),
+ String::from("V3"),
+ String::from("V4"),
+ String::from("V5"),
+ String::from("V6"),
+ String::from("V7"),
+ String::from("V8"),
+ String::from("V9"),
+ String::from("V10"),
+ String::from("V11"),
+ String::from("V12"),
+ String::from("V13"),
+ String::from("V14"),
+ String::from("V15"),
+ String::from("V16"),
+ String::from("V17"),
+ String::from("V18"),
+ String::from("V19"),
+ String::from("V20"),
+ String::from("V21"),
+ String::from("V22"),
+ String::from("V23"),
+ String::from("V24"),
+ String::from("V25"),
+ String::from("V26"),
+ String::from("V27"),
+ String::from("V28"),
+ String::from("V29"),
+ String::from("V30"),
+ ];
+
+ for height in [2, 6, 10, 14].iter() {
+ for n_trees in [3, 30, 300].iter() {
+ let input = (*height, *n_trees);
+ // Calculate the throughput based on the provided formula
+ let throughput = ((2u32.pow(*height) - 1) * *n_trees) as u64;
+ group.throughput(Throughput::Elements(throughput));
+ group.bench_with_input(
+ format!("height={}-n_trees={}", height, n_trees),
+ &input,
+ |b, &input| {
+ b.iter(|| HalfSpaceTree::new(0, input.1, input.0, Some(features.clone())));
+ },
+ );
+ }
+ }
+ group.finish();
+}
+
+criterion_group!(benches, creation);
+criterion_main!(benches);
diff --git a/examples/anomaly_detection/credit_card.rs b/examples/anomaly_detection/credit_card.rs
new file mode 100644
index 0000000..b52c541
--- /dev/null
+++ b/examples/anomaly_detection/credit_card.rs
@@ -0,0 +1,41 @@
+use light_river::anomaly::half_space_tree::HalfSpaceTree;
+use light_river::common::ClassifierOutput;
+use light_river::common::ClassifierTarget;
+use light_river::datasets::credit_card::CreditCard;
+use light_river::metrics::rocauc::ROCAUC;
+use light_river::metrics::traits::ClassificationMetric;
+use light_river::stream::data_stream::DataStream;
+use light_river::stream::iter_csv::IterCsv;
+use std::fs::File;
+use std::time::Instant;
+
+fn main() {
+ let now = Instant::now();
+
+ // PARAMETERS
+ let window_size: u32 = 1000;
+ let n_trees: u32 = 50;
+ let height: u32 = 6;
+ let pos_val_metric = ClassifierTarget::from("1".to_string());
+ let pos_val_tree = pos_val_metric.clone();
+ let mut roc_auc: ROCAUC = ROCAUC::new(Some(10), pos_val_metric.clone());
+ // INITIALIZATION
+ let mut hst: HalfSpaceTree =
+ HalfSpaceTree::new(window_size, n_trees, height, None, Some(pos_val_tree));
+
+ // LOOP
+ let transactions: IterCsv = CreditCard::load_credit_card_transactions().unwrap();
+ for transaction in transactions {
+ let data = transaction.unwrap();
+ let observation = data.get_observation();
+ let label = data.to_classifier_target("Class").unwrap();
+ let score = hst.update(&observation, true, true).unwrap();
+ // println!("Label: {:?}", label);
+ // println!("Score: {:?}", score);
+ roc_auc.update(&score, &label, Some(1.));
+ }
+
+ let elapsed_time = now.elapsed();
+ println!("Took {}ms", elapsed_time.as_millis());
+ println!("ROCAUC: {:.2}%", roc_auc.get() * (100.0 as f32));
+}
diff --git a/measure_auc.py b/measure_auc.py
new file mode 100644
index 0000000..4478dc8
--- /dev/null
+++ b/measure_auc.py
@@ -0,0 +1,6 @@
+import pandas as pd
+from sklearn import metrics
+
+scores = pd.read_csv('scores.csv', names=['score'])['score']
+labels = pd.read_csv('creditcard.csv')['Class']
+print(f"{metrics.roc_auc_score(labels, -scores):.2%}")
diff --git a/python_baseline.py b/python_baseline.py
new file mode 100644
index 0000000..9f78906
--- /dev/null
+++ b/python_baseline.py
@@ -0,0 +1,21 @@
+from river import anomaly
+from river import datasets
+from time import time
+
+scores = []
+hst = anomaly.HalfSpaceTrees(
+ n_trees=50,
+ height=6,
+ window_size=1000,
+)
+dataset = [x for x, _ in datasets.CreditCard()]
+start = time()
+for x in dataset:
+ score = hst.score_one(x)
+ scores.append(score)
+ hst.learn_one(x)
+print(f"Time: {time() - start:.2f}s")
+
+with open('scores_py.csv', 'w') as f:
+ for score in scores:
+ f.write(f"{score}\n")
diff --git a/src/anomaly/half_space_tree.rs b/src/anomaly/half_space_tree.rs
new file mode 100644
index 0000000..62f08df
--- /dev/null
+++ b/src/anomaly/half_space_tree.rs
@@ -0,0 +1,292 @@
+// https://pastebin.com/ZLD6E5FT
+
+use rand::prelude::*;
+
+use num::{Float, FromPrimitive};
+use std::collections::HashMap;
+use std::convert::TryFrom;
+use std::mem;
+use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign};
+
+use crate::common::{ClassifierOutput, ClassifierTarget, Observation};
+
+// Return the index of a node's left child node.
+#[inline]
+fn left_child(node: u32) -> u32 {
+ node * 2 + 1
+}
+
+// Return the index of a node's right child node.
+#[inline]
+fn right_child(node: u32) -> u32 {
+ node * 2 + 2
+}
+
+#[derive(Clone)]
+struct Trees {
+ feature: Vec,
+ threshold: Vec,
+ l_mass: Vec,
+ r_mass: Vec,
+}
+
+impl Trees {
+ fn new(n_trees: u32, height: u32, features: &Vec, rng: &mut ThreadRng) -> Self {
+ // #nodes = 2 ^ height - 1
+ let n_nodes: usize = usize::try_from(n_trees * (u32::pow(2, height) - 1)).unwrap();
+ // #branches = 2 ^ (height - 1) - 1
+ let n_branches = usize::try_from(n_trees * (u32::pow(2, height - 1) - 1)).unwrap();
+
+ // Helper function to create and populate a Vec with a given capacity
+ fn init_vec(capacity: usize, default_value: T) -> Vec
+ where
+ T: Clone,
+ {
+ let mut vec = Vec::with_capacity(capacity);
+ vec.resize(capacity, default_value);
+ vec
+ }
+
+ // Allocate memory for the HST
+ let mut hst = Trees {
+ feature: init_vec(n_branches, String::from("")),
+ threshold: init_vec(n_branches, F::zero()),
+ l_mass: init_vec(n_nodes, F::zero()),
+ r_mass: init_vec(n_nodes, F::zero()),
+ };
+
+ // Randomly assign features and thresholds to each branch
+ for branch in 0..n_branches {
+ let feature = features.choose(rng).unwrap();
+ hst.feature[branch] = feature.clone();
+ let random_threshold: f64 = rng.gen();
+ hst.threshold[branch] = F::from_f64(random_threshold).unwrap(); // [0, 1]
+ }
+ hst
+ }
+}
+/// Half-space trees are an online variant of isolation forests.
+/// They work well when anomalies are spread out.
+/// However, they do not work well if anomalies are packed together in windows.
+/// By default, this implementation assumes that each feature has values that are comprised
+/// between 0 and 1.
+/// # Parameters
+///
+/// - `window_size`: The number of observations to consider when computing the score.
+/// - `n_trees`: The number of trees to use.
+/// - `height`: The height of each tree.
+/// - `features`: The list of features to use. If `None`, the features will be inferred from the first observation.
+///
+/// # Example
+///
+/// ```
+///
+///
+///
+/// ```
+pub struct HalfSpaceTree {
+ window_size: u32,
+ counter: u32,
+ n_trees: u32,
+ height: u32,
+ features: Option>,
+ rng: ThreadRng,
+ n_branches: u32,
+ n_nodes: u32,
+ trees: Option>,
+ first_learn: bool,
+ pos_val: Option,
+}
+impl HalfSpaceTree {
+ pub fn new(
+ window_size: u32,
+ n_trees: u32,
+ height: u32,
+ features: Option>,
+ pos_val: Option,
+ // rng: ThreadRng,
+ ) -> Self {
+ // let mut rng = rand::thread_rng();
+ let n_branches = u32::pow(2, height - 1) - 1;
+ let n_nodes = u32::pow(2, height) - 1;
+
+ let features_clone = features.clone();
+ let mut rng = rand::thread_rng();
+ let trees = if let Some(features) = features {
+ Some(Trees::new(n_trees, height, &features, &mut rng))
+ } else {
+ None
+ };
+ HalfSpaceTree {
+ window_size: window_size,
+ counter: 0,
+ n_trees: n_trees,
+ height: height,
+ features: features_clone,
+ rng: rng,
+ n_branches: n_branches,
+ n_nodes: n_nodes,
+ trees: trees,
+ first_learn: false,
+ pos_val: pos_val,
+ }
+ }
+
+ pub fn update(
+ &mut self,
+ observation: &Observation,
+ do_score: bool,
+ do_update: bool,
+ ) -> Option> {
+ // build trees during the first pass
+ if (!self.first_learn) && self.features.is_none() {
+ self.features = Some(observation.clone().into_keys().collect());
+ self.trees = Some(Trees::new(
+ self.n_trees,
+ self.height,
+ &self.features.as_ref().unwrap(),
+ &mut self.rng,
+ ));
+ self.first_learn = true;
+ }
+
+ let mut score: F = F::zero();
+
+ for tree in 0..self.n_trees {
+ let mut node: u32 = 0;
+ for depth in 0..self.height {
+ // Update the score
+ let hst = &mut self.trees.as_mut().unwrap();
+
+ // Flag for scoring
+ if do_score {
+ score += hst.r_mass[(tree * self.n_nodes + node) as usize]
+ * F::from_u32(u32::pow(2, depth)).unwrap();
+ }
+
+ if do_update {
+ // Update the l_mass
+ hst.l_mass[(tree * self.n_nodes + node) as usize] += F::one();
+ }
+
+ // Stop if the node is a leaf or stop early if the mass of the node is too small
+ if depth == self.height - 1 {
+ break;
+ }
+
+ // Get the feature and threshold of the current node so that we can determine
+ // whether to go left or right
+ let feature = &hst.feature[(tree * self.n_branches + node) as usize];
+ let threshold = hst.threshold[(tree * self.n_branches + node) as usize];
+
+ // Get the value of the current feature
+ // node = self.walk(observation, node, tree, threshold, feature, hst);
+ node = match observation.get(feature) {
+ Some(value) => {
+ // Update the mass of the current node
+ if *value < threshold {
+ left_child(node)
+ } else {
+ right_child(node)
+ }
+ }
+ None => {
+ // If the feature is missing, go down both branches and select the node with the
+ // the biggest l_mass
+ if hst.l_mass[(tree * self.n_nodes + left_child(node)) as usize]
+ > hst.l_mass[(tree * self.n_nodes + right_child(node)) as usize]
+ {
+ left_child(node)
+ } else {
+ right_child(node)
+ }
+ }
+ };
+ }
+ }
+ if do_update {
+ // Pivot if the window is full
+ let hst = &mut self.trees.as_mut().unwrap();
+ self.counter += 1;
+ if self.counter == self.window_size {
+ mem::swap(&mut hst.r_mass, &mut hst.l_mass);
+ hst.l_mass.fill(F::zero());
+ self.counter = 0;
+ }
+ }
+ if do_score {
+ score = F::one() - (score / self.max_score());
+
+ return Some(ClassifierOutput::Probabilities(HashMap::from([(
+ ClassifierTarget::from(
+ self.pos_val.clone().unwrap_or(ClassifierTarget::from(true)),
+ ),
+ score,
+ )])));
+ // return Some(score);
+ }
+ return None;
+ }
+ pub fn learn_one(&mut self, observation: &Observation) {
+ self.update(observation, false, true);
+ }
+ pub fn score_one(&mut self, observation: &Observation) -> Option> {
+ self.update(observation, true, false)
+ }
+ fn max_score(&self) -> F {
+ F::from(self.n_trees).unwrap()
+ * F::from(self.window_size).unwrap()
+ * (F::from(2.).unwrap().powi(self.height as i32 + 1) - F::one())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use crate::datasets::credit_card::CreditCard;
+ use crate::stream::iter_csv::IterCsv;
+ use std::fs::File;
+ #[test]
+ fn test_hst() {
+ // PARAMETERS
+ let window_size: u32 = 1000;
+ let n_trees: u32 = 50;
+ let height: u32 = 6;
+
+ // INITIALIZATION
+ let mut hst: HalfSpaceTree = HalfSpaceTree::new(
+ window_size,
+ n_trees,
+ height,
+ None,
+ Some(ClassifierTarget::from("1".to_string())),
+ );
+
+ // LOOP
+ let transactions: IterCsv = CreditCard::load_credit_card_transactions().unwrap();
+ for transaction in transactions {
+ let data = transaction.unwrap();
+ let observation = data.get_observation();
+ let label = data.get_y().unwrap().get("Class").unwrap();
+ let _ = hst.update(&observation, true, true);
+ }
+ }
+}
+
+mod tests {
+ use super::*;
+ #[test]
+ fn test_left_child() {
+ let node = 42;
+ let child = left_child(node);
+ assert_eq!(child, 85);
+ }
+
+ #[test]
+ fn test_right_child() {
+ let node = 42;
+ let child = right_child(node);
+ assert_eq!(child, 86);
+ }
+}
diff --git a/src/anomaly/mod.rs b/src/anomaly/mod.rs
new file mode 100644
index 0000000..4762556
--- /dev/null
+++ b/src/anomaly/mod.rs
@@ -0,0 +1 @@
+pub mod half_space_tree;
diff --git a/src/common.rs b/src/common.rs
index 5c90812..6e0a8f4 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -202,6 +202,7 @@ pub type ClassifierTargetProbabilities = HashMap;
/// });
/// let mut prediction = probs.get_predicition();
/// assert_eq!(prediction, ClassifierTarget::String("Cat".to_string()));
+#[derive(Debug)]
pub enum ClassifierOutput
{
Probabilities(ClassifierTargetProbabilities),
diff --git a/src/lib.rs b/src/lib.rs
index 1eef00d..3cbbd8c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,3 +1,4 @@
+pub mod anomaly;
pub mod common;
pub mod datasets;
pub mod metrics;
diff --git a/src/metrics/accuracy.rs b/src/metrics/accuracy.rs
index b6857b5..6d95b58 100644
--- a/src/metrics/accuracy.rs
+++ b/src/metrics/accuracy.rs
@@ -22,16 +22,16 @@ impl
{
fn update(
&mut self,
- y_true: &ClassifierOutput,
- y_pred: &ClassifierTarget,
+ y_true: &ClassifierTarget,
+ y_pred: &ClassifierOutput,
sample_weight: Option,
) {
self.cm.update(y_true, y_pred, sample_weight);
}
fn revert(
&mut self,
- y_true: &ClassifierOutput,
- y_pred: &ClassifierTarget,
+ y_true: &ClassifierTarget,
+ y_pred: &ClassifierOutput,
sample_weight: Option,
) {
self.cm.revert(y_true, y_pred, sample_weight);
diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs
index ee5646b..7c9e4f9 100644
--- a/src/metrics/mod.rs
+++ b/src/metrics/mod.rs
@@ -1,4 +1,4 @@
-pub mod accuracy;
+// pub mod accuracy;
pub mod confusion;
pub mod rocauc;
pub mod traits;
diff --git a/src/metrics/traits.rs b/src/metrics/traits.rs
index a771d61..c039e9c 100644
--- a/src/metrics/traits.rs
+++ b/src/metrics/traits.rs
@@ -9,14 +9,14 @@ pub trait ClassificationMetric<
{
fn update(
&mut self,
- y_true: &ClassifierOutput,
- y_pred: &ClassifierTarget,
+ y_true: &ClassifierTarget,
+ y_pred: &ClassifierOutput,
sample_weight: Option,
);
fn revert(
&mut self,
- y_true: &ClassifierOutput,
- y_pred: &ClassifierTarget,
+ y_true: &ClassifierTarget,
+ y_pred: &ClassifierOutput,
sample_weight: Option,
);
fn get(&self) -> F;
diff --git a/src/stream/data_stream.rs b/src/stream/data_stream.rs
index fce2d4b..e579295 100644
--- a/src/stream/data_stream.rs
+++ b/src/stream/data_stream.rs
@@ -1,5 +1,6 @@
use std::collections::{HashMap, HashSet};
+use crate::common::{ClassifierTarget, Observation};
use num::Float;
/// This enum allows you to choose whether to define a single target (Name) or multiple targets (MultipleNames).
@@ -42,16 +43,36 @@ impl Target {
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Data {
Scalar(F),
+ Int(i32),
+ Bool(bool),
String(String),
}
+impl Data {
+ pub fn to_float(&self) -> Result {
+ match self {
+ Data::Scalar(v) => Ok(*v),
+ Data::Int(v) => Ok(F::from(*v).unwrap()),
+ Data::Bool(v) => Ok(F::from(*v as i32).unwrap()),
+ Data::String(_) => Err("Cannot convert string to float"),
+ }
+ }
+
+ pub fn to_string(&self) -> String {
+ match self {
+ Data::Scalar(v) => v.to_string(),
+ Data::Int(v) => v.to_string(),
+ Data::Bool(v) => v.to_string(),
+ Data::String(v) => v.clone(),
+ }
+ }
+}
-/// "This enum defines whether your DataSteam only contains observations (X) or both observations and one or more targets (XY)
pub enum DataStream {
X(HashMap>),
XY(HashMap>, HashMap>),
}
-impl DataStream {
+impl DataStream {
pub fn get_x(&self) -> &HashMap> {
match self {
DataStream::X(x) => x,
@@ -59,10 +80,33 @@ impl DataStream {
}
}
+ pub fn to_classifier_target(&self, target_key: &str) -> Result {
+ match self {
+ DataStream::X(_) => Err("No y data"),
+ // Use data to float
+ DataStream::XY(_, y) => {
+ let y = y.get(target_key).unwrap();
+ Ok(ClassifierTarget::from(y.to_string()))
+ }
+ }
+ }
+
pub fn get_y(&self) -> Result<&HashMap>, &str> {
match self {
DataStream::X(_) => Err("No y data"),
DataStream::XY(_, y) => Ok(y),
}
}
+ pub fn get_observation(&self) -> Observation {
+ match self {
+ DataStream::X(x) | DataStream::XY(x, _) => {
+ x.iter()
+ .filter_map(|(k, v)| match v.to_float() {
+ Ok(f_value) => Some((k.clone(), f_value)),
+ Err(_) => None, // Ignore non-convertible data types
+ })
+ .collect()
+ }
+ }
+ }
}