Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
AdilZouitine committed Dec 10, 2023
1 parent df60a92 commit 30c3795
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 32 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ reqwest = { version = "0.11.4", features = ["blocking"] }
zip = "0.6.4"
rand = "0.8.5"
time = "0.3.29"
half = "2.3.1"

[dev-dependencies]
criterion = { version = "0.4", features = ["html_reports"] }
Expand Down
22 changes: 18 additions & 4 deletions examples/anomaly_detection/credit_card.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
use light_river::anomaly::half_space_tree::HalfSpaceTree;
use light_river::common::ClassifierOutput;
use light_river::common::ClassifierTarget;
use light_river::datasets::credit_card::CreditCard;
use light_river::metrics::rocauc::ROCAUC;
use light_river::metrics::traits::ClassificationMetric;
use light_river::stream::data_stream::DataStream;
use light_river::stream::iter_csv::IterCsv;
use std::fs::File;
use std::time::Instant;
Expand All @@ -11,17 +16,26 @@ fn main() {
let window_size: u32 = 1000;
let n_trees: u32 = 50;
let height: u32 = 6;

let pos_val_metric = ClassifierTarget::from("1".to_string());
let pos_val_tree = pos_val_metric.clone();
let mut roc_auc: ROCAUC<f32> = ROCAUC::new(Some(10), pos_val_metric.clone());
// INITIALIZATION
let mut hst: HalfSpaceTree<f32> = HalfSpaceTree::new(window_size, n_trees, height, None);
let mut hst: HalfSpaceTree<f32> =
HalfSpaceTree::new(window_size, n_trees, height, None, Some(pos_val_tree));

// LOOP
let transactions: IterCsv<f32, File> = CreditCard::load_credit_card_transactions().unwrap();
for transaction in transactions {
let observation = transaction.unwrap().get_observation();
let _ = hst.update(&observation, true, true);
let data = transaction.unwrap();
let observation = data.get_observation();
let label = data.to_classifier_target("Class").unwrap();
let score = hst.update(&observation, true, true).unwrap();
// println!("Label: {:?}", label);
// println!("Score: {:?}", score);
roc_auc.update(&score, &label, Some(1.));
}

let elapsed_time = now.elapsed();
println!("Took {}ms", elapsed_time.as_millis());
println!("ROCAUC: {:.2}%", roc_auc.get() * (100.0 as f32));
}
37 changes: 35 additions & 2 deletions src/anomaly/half_space_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,15 @@ pub struct HalfSpaceTree<F: Float + FromPrimitive + AddAssign + SubAssign + MulA
n_nodes: u32,
trees: Option<Trees<F>>,
first_learn: bool,
pos_val: Option<ClassifierTarget>,
}
impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> HalfSpaceTree<F> {
pub fn new(
window_size: u32,
n_trees: u32,
height: u32,
features: Option<Vec<String>>,
pos_val: Option<ClassifierTarget>,
// rng: ThreadRng,
) -> Self {
// let mut rng = rand::thread_rng();
Expand All @@ -126,6 +128,7 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> H
n_nodes: n_nodes,
trees: trees,
first_learn: false,
pos_val: pos_val,
}
}

Expand Down Expand Up @@ -213,8 +216,11 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> H
}
if do_score {
score = F::one() - (score / self.max_score());

return Some(ClassifierOutput::Probabilities(HashMap::from([(
ClassifierTarget::from(true),
ClassifierTarget::from(
self.pos_val.clone().unwrap_or(ClassifierTarget::from(true)),
),
score,
)])));
// return Some(score);
Expand All @@ -237,6 +243,33 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> H
#[cfg(test)]
mod tests {
use super::*;

use crate::datasets::credit_card::CreditCard;
use crate::stream::iter_csv::IterCsv;
use std::fs::File;
#[test]
fn test_hst() {}
fn test_hst() {
    // Smoke test: stream the credit-card dataset through a HalfSpaceTree and
    // check that every record carries a "Class" target and that scoring
    // always yields an output.

    // PARAMETERS
    let window_size: u32 = 1000;
    let n_trees: u32 = 50;
    let height: u32 = 6;

    // INITIALIZATION
    let mut hst: HalfSpaceTree<f32> = HalfSpaceTree::new(
        window_size,
        n_trees,
        height,
        None,
        Some(ClassifierTarget::from("1".to_string())),
    );

    // LOOP
    let transactions: IterCsv<f32, File> = CreditCard::load_credit_card_transactions().unwrap();
    for transaction in transactions {
        let data = transaction.unwrap();
        let observation = data.get_observation();

        // Previously an unused `label` binding whose only effect was to panic
        // when the target was absent; state that expectation explicitly.
        assert!(
            data.get_y().unwrap().contains_key("Class"),
            "every transaction should carry a \"Class\" target"
        );

        // Learn and score in one call; with scoring enabled an output is expected.
        let score = hst.update(&observation, true, true);
        assert!(
            score.is_some(),
            "update() should return a score when scoring is requested"
        );
    }
}
}
1 change: 1 addition & 0 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ pub type ClassifierTargetProbabilities<F> = HashMap<ClassifierTarget, F>;
/// });
/// let mut prediction = probs.get_predicition();
/// assert_eq!(prediction, ClassifierTarget::String("Cat".to_string()));
#[derive(Debug)]
pub enum ClassifierOutput<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign>
{
Probabilities(ClassifierTargetProbabilities<F>),
Expand Down
8 changes: 4 additions & 4 deletions src/metrics/accuracy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign>
{
fn update(
&mut self,
y_true: &ClassifierOutput<F>,
y_pred: &ClassifierTarget,
y_true: &ClassifierTarget,
y_pred: &ClassifierOutput<F>,
sample_weight: Option<F>,
) {
let sample_weight = sample_weight.unwrap_or_else(|| F::one());
Expand All @@ -46,8 +46,8 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign>
}
fn revert(
&mut self,
y_true: &ClassifierOutput<F>,
y_pred: &ClassifierTarget,
y_true: &ClassifierTarget,
y_pred: &ClassifierOutput<F>,
sample_weight: Option<F>,
) {
let sample_weight = sample_weight.unwrap_or_else(|| F::one());
Expand Down
2 changes: 1 addition & 1 deletion src/metrics/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub mod accuracy;
// pub mod accuracy;
pub mod confusion;
pub mod rocauc;
pub mod traits;
21 changes: 13 additions & 8 deletions src/metrics/rocauc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign>
{
fn update(
&mut self,
y_pred: &ClassifierOutput<F>,
y_true: &ClassifierTarget,
y_pred: &ClassifierOutput<F>,
sample_weight: Option<F>,
) {
// Get the probability of the positive class
Expand All @@ -118,8 +118,8 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign>

fn revert(
&mut self,
y_pred: &ClassifierOutput<F>,
y_true: &ClassifierTarget,
y_pred: &ClassifierOutput<F>,
sample_weight: Option<F>,
) {
let p_pred = y_pred.get_probabilities();
Expand All @@ -143,7 +143,6 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign>
let true_negatives: F = cm.true_negatives(&self.pos_val);
let false_positives: F = cm.false_positives(&self.pos_val);
let false_negatives: F = cm.false_negatives(&self.pos_val);

// Handle the case of zero division
let mut tpr: Option<F> = None;
if true_positives + false_negatives != F::zero() {
Expand Down Expand Up @@ -186,12 +185,18 @@ mod tests {
ClassifierOutput::Prediction(ClassifierTarget::from("bird")),
ClassifierOutput::Prediction(ClassifierTarget::from("cat")),
];
let y_true: Vec<&str> = vec!["cat", "cat", "dog", "cat"];

let mut metric = ROCAUC::new(Some(10), ClassifierTarget::from("cat"));
let y_true: Vec<ClassifierTarget> = vec![
ClassifierTarget::from("cat".to_string()),
ClassifierTarget::from("cat".to_string()),
ClassifierTarget::from("dog".to_string()),
ClassifierTarget::from("cat".to_string()),
];
let pos_val = ClassifierTarget::from("cat");
let mut metric: ROCAUC<f32> = ROCAUC::new(Some(10), pos_val);

for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
metric.update(yp, &ClassifierTarget::from(*yt), Some(1.0));
metric.update(&yt, yp, Some(1.0));
println!("{}", metric.get());
}
println!("{}", metric.get());
}
Expand Down Expand Up @@ -219,7 +224,7 @@ mod tests {
let mut metric: ROCAUC<f64> = ROCAUC::new(Some(10), ClassifierTarget::from(true));

for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
metric.update(yp, &ClassifierTarget::from(*yt), Some(1.0));
metric.update(&ClassifierTarget::from(*yt), yp, Some(1.0));
}

println!("ROCAUC: {:.2}%", metric.get() * (100.0 as f64));
Expand Down
8 changes: 4 additions & 4 deletions src/metrics/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ pub trait ClassificationMetric<
{
fn update(
&mut self,
y_true: &ClassifierOutput<F>,
y_pred: &ClassifierTarget,
y_true: &ClassifierTarget,
y_pred: &ClassifierOutput<F>,
sample_weight: Option<F>,
);
fn revert(
&mut self,
y_true: &ClassifierOutput<F>,
y_pred: &ClassifierTarget,
y_true: &ClassifierTarget,
y_pred: &ClassifierOutput<F>,
sample_weight: Option<F>,
);
fn get(&self) -> F;
Expand Down
17 changes: 8 additions & 9 deletions src/stream/data_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,15 @@ pub enum DataStream<F: Float + std::str::FromStr> {
X(HashMap<String, Data<F>>),
XY(HashMap<String, Data<F>>, HashMap<String, Data<F>>),
}

impl<F: Float + std::str::FromStr + std::fmt::Display> DataStream<F> {
/// Borrows the first map stored in this record (the `x` payload).
///
/// Both variants hold that map in their first position, so a single
/// or-pattern arm covers them.
pub fn get_x(&self) -> &HashMap<String, Data<F>> {
    match self {
        DataStream::X(features) | DataStream::XY(features, _) => features,
    }
}

pub fn to_classifier_target(&self, target_key: &str) -> Result<ClassifierTarget, &str> {
match self {
DataStream::X(_) => Err("No y data"),
Expand All @@ -82,15 +90,6 @@ impl<F: Float + std::str::FromStr + std::fmt::Display> DataStream<F> {
}
}
}
}

impl<F: Float + std::str::FromStr + std::fmt::Display> DataStream<F> {
/// Returns a reference to the first map held by this record (the `x` payload),
/// whichever variant it is.
pub fn get_x(&self) -> &HashMap<String, Data<F>> {
match self {
DataStream::X(x) => x,
DataStream::XY(x, _) => x,
}
}

pub fn get_y(&self) -> Result<&HashMap<String, Data<F>>, &str> {
match self {
Expand Down

0 comments on commit 30c3795

Please sign in to comment.