Skip to content

Commit

Permalink
Merge branch 'ahnentafel' of https://github.com/online-ml/light-river
Browse files Browse the repository at this point in the history
…into ahnentafel
  • Loading branch information
MaxHalford committed Dec 12, 2023
2 parents 0580dcf + 30c3795 commit 7f32031
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 36 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ reqwest = { version = "0.11.4", features = ["blocking"] }
zip = "0.6.4"
rand = "0.8.5"
time = "0.3.29"
half = "2.3.1"

[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
Expand Down
22 changes: 18 additions & 4 deletions examples/anomaly_detection/credit_card.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
use light_river::anomaly::half_space_tree::HalfSpaceTree;
use light_river::common::ClassifierOutput;
use light_river::common::ClassifierTarget;
use light_river::datasets::credit_card::CreditCard;
use light_river::metrics::rocauc::ROCAUC;
use light_river::metrics::traits::ClassificationMetric;
use light_river::stream::data_stream::DataStream;
use light_river::stream::iter_csv::IterCsv;
use std::fs::File;
use std::time::Instant;
Expand All @@ -11,17 +16,26 @@ fn main() {
let window_size: u32 = 1000;
let n_trees: u32 = 50;
let height: u32 = 6;

let pos_val_metric = ClassifierTarget::from("1".to_string());
let pos_val_tree = pos_val_metric.clone();
let mut roc_auc: ROCAUC<f32> = ROCAUC::new(Some(10), pos_val_metric.clone());
// INITIALIZATION
let mut hst: HalfSpaceTree<f32> = HalfSpaceTree::new(window_size, n_trees, height, None);
let mut hst: HalfSpaceTree<f32> =
HalfSpaceTree::new(window_size, n_trees, height, None, Some(pos_val_tree));

// LOOP
let transactions: IterCsv<f32, File> = CreditCard::load_credit_card_transactions().unwrap();
for transaction in transactions {
let observation = transaction.unwrap().get_observation();
let _ = hst.update(&observation, true, true);
let data = transaction.unwrap();
let observation = data.get_observation();
let label = data.to_classifier_target("Class").unwrap();
let score = hst.update(&observation, true, true).unwrap();
// println!("Label: {:?}", label);
// println!("Score: {:?}", score);
roc_auc.update(&score, &label, Some(1.));
}

let elapsed_time = now.elapsed();
println!("Took {}ms", elapsed_time.as_millis());
println!("ROCAUC: {:.2}%", roc_auc.get() * (100.0 as f32));
}
37 changes: 35 additions & 2 deletions src/anomaly/half_space_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,15 @@ pub struct HalfSpaceTree<F: Float + FromPrimitive + AddAssign + SubAssign + MulA
n_nodes: u32,
trees: Option<Trees<F>>,
first_learn: bool,
pos_val: Option<ClassifierTarget>,
}
impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> HalfSpaceTree<F> {
pub fn new(
window_size: u32,
n_trees: u32,
height: u32,
features: Option<Vec<String>>,
pos_val: Option<ClassifierTarget>,
// rng: ThreadRng,
) -> Self {
// let mut rng = rand::thread_rng();
Expand All @@ -126,6 +128,7 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> H
n_nodes: n_nodes,

Check warning on line 128 in src/anomaly/half_space_tree.rs

View workflow job for this annotation

GitHub Actions / clippy

redundant field names in struct initialization

warning: redundant field names in struct initialization --> src/anomaly/half_space_tree.rs:128:13 | 128 | n_nodes: n_nodes, | ^^^^^^^^^^^^^^^^ help: replace it with: `n_nodes` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#redundant_field_names
trees: trees,

Check warning on line 129 in src/anomaly/half_space_tree.rs

View workflow job for this annotation

GitHub Actions / clippy

redundant field names in struct initialization

warning: redundant field names in struct initialization --> src/anomaly/half_space_tree.rs:129:13 | 129 | trees: trees, | ^^^^^^^^^^^^ help: replace it with: `trees` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#redundant_field_names
first_learn: false,
pos_val: pos_val,

Check warning on line 131 in src/anomaly/half_space_tree.rs

View workflow job for this annotation

GitHub Actions / clippy

redundant field names in struct initialization

warning: redundant field names in struct initialization --> src/anomaly/half_space_tree.rs:131:13 | 131 | pos_val: pos_val, | ^^^^^^^^^^^^^^^^ help: replace it with: `pos_val` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#redundant_field_names
}
}

Expand Down Expand Up @@ -213,8 +216,11 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> H
}
if do_score {
score = F::one() - (score / self.max_score());

return Some(ClassifierOutput::Probabilities(HashMap::from([(
ClassifierTarget::from(true),
ClassifierTarget::from(
self.pos_val.clone().unwrap_or(ClassifierTarget::from(true)),
),
score,
)])));
// return Some(score);
Expand All @@ -237,8 +243,35 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> H
#[cfg(test)]
mod tests {
use super::*;

use crate::datasets::credit_card::CreditCard;
use crate::stream::iter_csv::IterCsv;
use std::fs::File;
#[test]
fn test_hst() {}
fn test_hst() {
// PARAMETERS
let window_size: u32 = 1000;
let n_trees: u32 = 50;
let height: u32 = 6;

// INITIALIZATION
let mut hst: HalfSpaceTree<f32> = HalfSpaceTree::new(
window_size,
n_trees,
height,
None,
Some(ClassifierTarget::from("1".to_string())),
);

// LOOP
let transactions: IterCsv<f32, File> = CreditCard::load_credit_card_transactions().unwrap();
for transaction in transactions {
let data = transaction.unwrap();
let observation = data.get_observation();
let label = data.get_y().unwrap().get("Class").unwrap();
let _ = hst.update(&observation, true, true);
}
}
}

mod tests {
Expand Down
1 change: 1 addition & 0 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ pub type ClassifierTargetProbabilities<F> = HashMap<ClassifierTarget, F>;
/// });
/// let mut prediction = probs.get_predicition();
/// assert_eq!(prediction, ClassifierTarget::String("Cat".to_string()));
#[derive(Debug)]
pub enum ClassifierOutput<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign>
{
Probabilities(ClassifierTargetProbabilities<F>),
Expand Down
8 changes: 4 additions & 4 deletions src/metrics/accuracy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign>
{
fn update(
&mut self,
y_true: &ClassifierOutput<F>,
y_pred: &ClassifierTarget,
y_true: &ClassifierTarget,
y_pred: &ClassifierOutput<F>,
sample_weight: Option<F>,
) {
let sample_weight = sample_weight.unwrap_or_else(|| F::one());
Expand All @@ -46,8 +46,8 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign>
}
fn revert(
&mut self,
y_true: &ClassifierOutput<F>,
y_pred: &ClassifierTarget,
y_true: &ClassifierTarget,
y_pred: &ClassifierOutput<F>,
sample_weight: Option<F>,
) {
let sample_weight = sample_weight.unwrap_or_else(|| F::one());
Expand Down
2 changes: 1 addition & 1 deletion src/metrics/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub mod accuracy;
// pub mod accuracy;
pub mod confusion;
pub mod rocauc;
pub mod traits;
21 changes: 13 additions & 8 deletions src/metrics/rocauc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign>
{
fn update(
&mut self,
y_pred: &ClassifierOutput<F>,
y_true: &ClassifierTarget,
y_pred: &ClassifierOutput<F>,
sample_weight: Option<F>,
) {
// Get the probability of the positive class
Expand All @@ -118,8 +118,8 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign>

fn revert(
&mut self,
y_pred: &ClassifierOutput<F>,
y_true: &ClassifierTarget,
y_pred: &ClassifierOutput<F>,
sample_weight: Option<F>,
) {
let p_pred = y_pred.get_probabilities();
Expand All @@ -143,7 +143,6 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign>
let true_negatives: F = cm.true_negatives(&self.pos_val);
let false_positives: F = cm.false_positives(&self.pos_val);
let false_negatives: F = cm.false_negatives(&self.pos_val);

// Handle the case of zero division
let mut tpr: Option<F> = None;
if true_positives + false_negatives != F::zero() {
Expand Down Expand Up @@ -186,12 +185,18 @@ mod tests {
ClassifierOutput::Prediction(ClassifierTarget::from("bird")),
ClassifierOutput::Prediction(ClassifierTarget::from("cat")),
];
let y_true: Vec<&str> = vec!["cat", "cat", "dog", "cat"];

let mut metric = ROCAUC::new(Some(10), ClassifierTarget::from("cat"));
let y_true: Vec<ClassifierTarget> = vec![
ClassifierTarget::from("cat".to_string()),
ClassifierTarget::from("cat".to_string()),
ClassifierTarget::from("dog".to_string()),
ClassifierTarget::from("cat".to_string()),
];
let pos_val = ClassifierTarget::from("cat");
let mut metric: ROCAUC<f32> = ROCAUC::new(Some(10), pos_val);

for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
metric.update(yp, &ClassifierTarget::from(*yt), Some(1.0));
metric.update(&yt, yp, Some(1.0));
println!("{}", metric.get());
}
println!("{}", metric.get());
}
Expand Down Expand Up @@ -219,7 +224,7 @@ mod tests {
let mut metric: ROCAUC<f64> = ROCAUC::new(Some(10), ClassifierTarget::from(true));

for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
metric.update(yp, &ClassifierTarget::from(*yt), Some(1.0));
metric.update(&ClassifierTarget::from(*yt), yp, Some(1.0));
}

println!("ROCAUC: {:.2}%", metric.get() * (100.0 as f64));
Expand Down
8 changes: 4 additions & 4 deletions src/metrics/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ pub trait ClassificationMetric<
{
fn update(
&mut self,
y_true: &ClassifierOutput<F>,
y_pred: &ClassifierTarget,
y_true: &ClassifierTarget,
y_pred: &ClassifierOutput<F>,
sample_weight: Option<F>,
);
fn revert(
&mut self,
y_true: &ClassifierOutput<F>,
y_pred: &ClassifierTarget,
y_true: &ClassifierTarget,
y_pred: &ClassifierOutput<F>,
sample_weight: Option<F>,
);
fn get(&self) -> F;
Expand Down
57 changes: 44 additions & 13 deletions src/stream/data_stream.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::{HashMap, HashSet};

use crate::common::Observation;
use crate::common::{ClassifierTarget, Observation};
use num::Float;

/// This enum allows you to choose whether to define a single target (Name) or multiple targets (MultipleNames).
Expand Down Expand Up @@ -43,39 +43,70 @@ impl Target {
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Data<F: Float + std::str::FromStr> {
Scalar(F),
Int(i32),
Bool(bool),
String(String),
}
impl<F: Float + std::fmt::Display + std::str::FromStr> Data<F> {
pub fn to_float(&self) -> Result<F, &str> {
match self {
Data::Scalar(v) => Ok(*v),
Data::Int(v) => Ok(F::from(*v).unwrap()),
Data::Bool(v) => Ok(F::from(*v as i32).unwrap()),
Data::String(_) => Err("Cannot convert string to float"),
}
}

pub fn to_string(&self) -> String {
match self {
Data::Scalar(v) => v.to_string(),
Data::Int(v) => v.to_string(),
Data::Bool(v) => v.to_string(),
Data::String(v) => v.clone(),
}
}

Check warning on line 67 in src/stream/data_stream.rs

View workflow job for this annotation

GitHub Actions / clippy

implementation of inherent method `to_string(&self) -> String` for type `stream::data_stream::Data<F>`

warning: implementation of inherent method `to_string(&self) -> String` for type `stream::data_stream::Data<F>` --> src/stream/data_stream.rs:60:5 | 60 | / pub fn to_string(&self) -> String { 61 | | match self { 62 | | Data::Scalar(v) => v.to_string(), 63 | | Data::Int(v) => v.to_string(), ... | 66 | | } 67 | | } | |_____^ | = help: implement trait `Display` for type `stream::data_stream::Data<F>` instead = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#inherent_to_string = note: `#[warn(clippy::inherent_to_string)]` on by default
}

/// "This enum defines whether your DataSteam only contains observations (X) or both observations and one or more targets (XY)
pub enum DataStream<F: Float + std::str::FromStr> {
X(HashMap<String, Data<F>>),
XY(HashMap<String, Data<F>>, HashMap<String, Data<F>>),
}

impl<F: Float + std::str::FromStr> DataStream<F> {
impl<F: Float + std::str::FromStr + std::fmt::Display> DataStream<F> {
pub fn get_x(&self) -> &HashMap<String, Data<F>> {
match self {
DataStream::X(x) => x,
DataStream::XY(x, _) => x,
}
}

pub fn to_classifier_target(&self, target_key: &str) -> Result<ClassifierTarget, &str> {
match self {
DataStream::X(_) => Err("No y data"),
// Use data to float
DataStream::XY(_, y) => {
let y = y.get(target_key).unwrap();
Ok(ClassifierTarget::from(y.to_string()))
}
}
}

pub fn get_y(&self) -> Result<&HashMap<String, Data<F>>, &str> {
match self {
DataStream::X(_) => Err("No y data"),
DataStream::XY(_, y) => Ok(y),
}
}
pub fn get_observation(&self) -> Observation<F> {
let observation = self.get_x();
// Get only the value that are scalar
let observation: HashMap<String, F> = observation
.iter()
.filter_map(|(k, v)| match v {
Data::Scalar(v) => Some((k.clone(), *v)),
_ => None,
})
.collect();
observation
match self {
DataStream::X(x) | DataStream::XY(x, _) => {
x.iter()
.filter_map(|(k, v)| match v.to_float() {
Ok(f_value) => Some((k.clone(), f_value)),
Err(_) => None, // Ignore non-convertible data types
})
.collect()
}
}
}
}

0 comments on commit 7f32031

Please sign in to comment.