diff --git a/Cargo.toml b/Cargo.toml index 447bc54..cc99797 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ reqwest = { version = "0.11.4", features = ["blocking"] } zip = "0.6.4" rand = "0.8.5" time = "0.3.29" +half = "2.3.1" [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } diff --git a/examples/anomaly_detection/credit_card.rs b/examples/anomaly_detection/credit_card.rs index f14436c..b52c541 100644 --- a/examples/anomaly_detection/credit_card.rs +++ b/examples/anomaly_detection/credit_card.rs @@ -1,5 +1,10 @@ use light_river::anomaly::half_space_tree::HalfSpaceTree; +use light_river::common::ClassifierOutput; +use light_river::common::ClassifierTarget; use light_river::datasets::credit_card::CreditCard; +use light_river::metrics::rocauc::ROCAUC; +use light_river::metrics::traits::ClassificationMetric; +use light_river::stream::data_stream::DataStream; use light_river::stream::iter_csv::IterCsv; use std::fs::File; use std::time::Instant; @@ -11,17 +16,26 @@ fn main() { let window_size: u32 = 1000; let n_trees: u32 = 50; let height: u32 = 6; - + let pos_val_metric = ClassifierTarget::from("1".to_string()); + let pos_val_tree = pos_val_metric.clone(); + let mut roc_auc: ROCAUC = ROCAUC::new(Some(10), pos_val_metric.clone()); // INITIALIZATION - let mut hst: HalfSpaceTree = HalfSpaceTree::new(window_size, n_trees, height, None); + let mut hst: HalfSpaceTree = + HalfSpaceTree::new(window_size, n_trees, height, None, Some(pos_val_tree)); // LOOP let transactions: IterCsv = CreditCard::load_credit_card_transactions().unwrap(); for transaction in transactions { - let observation = transaction.unwrap().get_observation(); - let _ = hst.update(&observation, true, true); + let data = transaction.unwrap(); + let observation = data.get_observation(); + let label = data.to_classifier_target("Class").unwrap(); + let score = hst.update(&observation, true, true).unwrap(); + // println!("Label: {:?}", label); + // println!("Score: {:?}", score); + roc_auc.update(&score, &label, Some(1.)); } let elapsed_time = now.elapsed(); println!("Took {}ms", elapsed_time.as_millis()); + println!("ROCAUC: {:.2}%", roc_auc.get() * (100.0 as f32)); } diff --git a/src/anomaly/half_space_tree.rs b/src/anomaly/half_space_tree.rs index 0bed296..62f08df 100644 --- a/src/anomaly/half_space_tree.rs +++ b/src/anomaly/half_space_tree.rs @@ -95,6 +95,7 @@ pub struct HalfSpaceTree>, first_learn: bool, + pos_val: Option, } impl HalfSpaceTree { pub fn new( @@ -102,6 +103,7 @@ impl H n_trees: u32, height: u32, features: Option>, + pos_val: Option, // rng: ThreadRng, ) -> Self { // let mut rng = rand::thread_rng(); @@ -126,6 +128,7 @@ impl H n_nodes: n_nodes, trees: trees, first_learn: false, + pos_val: pos_val, } } @@ -213,8 +216,11 @@ impl H } if do_score { score = F::one() - (score / self.max_score()); + return Some(ClassifierOutput::Probabilities(HashMap::from([( - ClassifierTarget::from(true), + ClassifierTarget::from( + self.pos_val.clone().unwrap_or(ClassifierTarget::from(true)), + ), score, )]))); // return Some(score); @@ -237,8 +243,35 @@ impl H #[cfg(test)] mod tests { use super::*; + + use crate::datasets::credit_card::CreditCard; + use crate::stream::iter_csv::IterCsv; + use std::fs::File; #[test] - fn test_hst() {} + fn test_hst() { + // PARAMETERS + let window_size: u32 = 1000; + let n_trees: u32 = 50; + let height: u32 = 6; + + // INITIALIZATION + let mut hst: HalfSpaceTree = HalfSpaceTree::new( + window_size, + n_trees, + height, + None, + Some(ClassifierTarget::from("1".to_string())), + ); + + // LOOP + let transactions: IterCsv = CreditCard::load_credit_card_transactions().unwrap(); + for transaction in transactions { + let data = transaction.unwrap(); + let observation = data.get_observation(); + let label = data.get_y().unwrap().get("Class").unwrap(); + let _ = hst.update(&observation, true, true); + } + } } mod tests { diff --git a/src/common.rs b/src/common.rs index 5c90812..6e0a8f4 100644 --- a/src/common.rs +++ b/src/common.rs @@ -202,6 +202,7 @@ pub type ClassifierTargetProbabilities = HashMap; /// }); /// let mut prediction = probs.get_predicition(); /// assert_eq!(prediction, ClassifierTarget::String("Cat".to_string())); +#[derive(Debug)] pub enum ClassifierOutput { Probabilities(ClassifierTargetProbabilities), diff --git a/src/metrics/accuracy.rs b/src/metrics/accuracy.rs index fe0d460..4dbbac5 100644 --- a/src/metrics/accuracy.rs +++ b/src/metrics/accuracy.rs @@ -23,8 +23,8 @@ impl { fn update( &mut self, - y_true: &ClassifierOutput, - y_pred: &ClassifierTarget, + y_true: &ClassifierTarget, + y_pred: &ClassifierOutput, sample_weight: Option, ) { let sample_weight = sample_weight.unwrap_or_else(|| F::one()); @@ -46,8 +46,8 @@ impl } fn revert( &mut self, - y_true: &ClassifierOutput, - y_pred: &ClassifierTarget, + y_true: &ClassifierTarget, + y_pred: &ClassifierOutput, sample_weight: Option, ) { let sample_weight = sample_weight.unwrap_or_else(|| F::one()); diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs index ee5646b..7c9e4f9 100644 --- a/src/metrics/mod.rs +++ b/src/metrics/mod.rs @@ -1,4 +1,4 @@ -pub mod accuracy; +// pub mod accuracy; pub mod confusion; pub mod rocauc; pub mod traits; diff --git a/src/metrics/rocauc.rs b/src/metrics/rocauc.rs index f991855..a34dda7 100644 --- a/src/metrics/rocauc.rs +++ b/src/metrics/rocauc.rs @@ -97,8 +97,8 @@ impl { fn update( &mut self, - y_pred: &ClassifierOutput, y_true: &ClassifierTarget, + y_pred: &ClassifierOutput, sample_weight: Option, ) { // Get the probability of the positive class @@ -118,8 +118,8 @@ impl fn revert( &mut self, - y_pred: &ClassifierOutput, y_true: &ClassifierTarget, + y_pred: &ClassifierOutput, sample_weight: Option, ) { let p_pred = y_pred.get_probabilities(); @@ -143,7 +143,6 @@ impl let true_negatives: F = cm.true_negatives(&self.pos_val); let false_positives: F = cm.false_positives(&self.pos_val); let false_negatives: F = cm.false_negatives(&self.pos_val); - // Handle the case of zero division let mut tpr: Option = None; if true_positives + false_negatives != F::zero() { @@ -186,12 +185,18 @@ mod tests { ClassifierOutput::Prediction(ClassifierTarget::from("bird")), ClassifierOutput::Prediction(ClassifierTarget::from("cat")), ]; - let y_true: Vec<&str> = vec!["cat", "cat", "dog", "cat"]; - - let mut metric = ROCAUC::new(Some(10), ClassifierTarget::from("cat")); + let y_true: Vec = vec![ + ClassifierTarget::from("cat".to_string()), + ClassifierTarget::from("cat".to_string()), + ClassifierTarget::from("dog".to_string()), + ClassifierTarget::from("cat".to_string()), + ]; + let pos_val = ClassifierTarget::from("cat"); + let mut metric: ROCAUC = ROCAUC::new(Some(10), pos_val); for (yt, yp) in y_true.iter().zip(y_pred.iter()) { - metric.update(yp, &ClassifierTarget::from(*yt), Some(1.0)); + metric.update(&yt, yp, Some(1.0)); + println!("{}", metric.get()); } println!("{}", metric.get()); } @@ -219,7 +224,7 @@ mod tests { let mut metric: ROCAUC = ROCAUC::new(Some(10), ClassifierTarget::from(true)); for (yt, yp) in y_true.iter().zip(y_pred.iter()) { - metric.update(yp, &ClassifierTarget::from(*yt), Some(1.0)); + metric.update(&ClassifierTarget::from(*yt), yp, Some(1.0)); } println!("ROCAUC: {:.2}%", metric.get() * (100.0 as f64)); diff --git a/src/metrics/traits.rs b/src/metrics/traits.rs index a771d61..c039e9c 100644 --- a/src/metrics/traits.rs +++ b/src/metrics/traits.rs @@ -9,14 +9,14 @@ pub trait ClassificationMetric< { fn update( &mut self, - y_true: &ClassifierOutput, - y_pred: &ClassifierTarget, + y_true: &ClassifierTarget, + y_pred: &ClassifierOutput, sample_weight: Option, ); fn revert( &mut self, - y_true: &ClassifierOutput, - y_pred: &ClassifierTarget, + y_true: &ClassifierTarget, + y_pred: &ClassifierOutput, sample_weight: Option, ); fn get(&self) -> F; diff --git a/src/stream/data_stream.rs b/src/stream/data_stream.rs index f11eb64..e579295 100644 --- a/src/stream/data_stream.rs +++ b/src/stream/data_stream.rs @@ -1,6 +1,6 @@ use std::collections::{HashMap, HashSet}; -use crate::common::Observation; +use crate::common::{ClassifierTarget, Observation}; use num::Float; /// This enum allows you to choose whether to define a single target (Name) or multiple targets (MultipleNames). @@ -43,16 +43,36 @@ impl Target { #[derive(Debug, Clone, Eq, PartialEq)] pub enum Data { Scalar(F), + Int(i32), + Bool(bool), String(String), } +impl Data { + pub fn to_float(&self) -> Result { + match self { + Data::Scalar(v) => Ok(*v), + Data::Int(v) => Ok(F::from(*v).unwrap()), + Data::Bool(v) => Ok(F::from(*v as i32).unwrap()), + Data::String(_) => Err("Cannot convert string to float"), + } + } + + pub fn to_string(&self) -> String { + match self { + Data::Scalar(v) => v.to_string(), + Data::Int(v) => v.to_string(), + Data::Bool(v) => v.to_string(), + Data::String(v) => v.clone(), + } + } +} -/// "This enum defines whether your DataSteam only contains observations (X) or both observations and one or more targets (XY) pub enum DataStream { X(HashMap>), XY(HashMap>, HashMap>), } -impl DataStream { +impl DataStream { pub fn get_x(&self) -> &HashMap> { match self { DataStream::X(x) => x, @@ -60,6 +80,17 @@ impl DataStream { } } + pub fn to_classifier_target(&self, target_key: &str) -> Result { + match self { + DataStream::X(_) => Err("No y data"), + // Use data to float + DataStream::XY(_, y) => { + let y = y.get(target_key).unwrap(); + Ok(ClassifierTarget::from(y.to_string())) + } + } + } + pub fn get_y(&self) -> Result<&HashMap>, &str> { match self { DataStream::X(_) => Err("No y data"), @@ -67,15 +98,15 @@ impl DataStream { } } pub fn get_observation(&self) -> Observation { - let observation = self.get_x(); - // Get only the value that are scalar - let observation: HashMap = observation - .iter() - .filter_map(|(k, v)| match v { - Data::Scalar(v) => Some((k.clone(), *v)), - _ => None, - }) - .collect(); - observation + match self { + DataStream::X(x) | DataStream::XY(x, _) => { + x.iter() + .filter_map(|(k, v)| match v.to_float() { + Ok(f_value) => Some((k.clone(), f_value)), + Err(_) => None, // Ignore non-convertible data types + }) + .collect() + } + } } }