Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
AdilZouitine committed Sep 18, 2023
1 parent fbfa458 commit 398392b
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 45 deletions.
20 changes: 18 additions & 2 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,23 @@ impl From<bool> for ClassifierTarget {
ClassifierTarget::Bool(b)
}
}
impl From<&bool> for ClassifierTarget {
fn from(b: &bool) -> Self {
ClassifierTarget::Bool(*b)
}
}

impl From<&i32> for ClassifierTarget {
fn from(i: &i32) -> Self {
ClassifierTarget::Int(*i)
}
}

impl From<&String> for ClassifierTarget {
fn from(s: &String) -> Self {
ClassifierTarget::String(s.clone())
}
}

pub trait IntoClassifierTargetIter {
fn into_classifier_target_iter(self) -> Box<dyn Iterator<Item = ClassifierTarget>>;
Expand All @@ -160,8 +177,7 @@ where
/// use std::collections::HashMap;
/// use light_river::common::{ClassifierTarget, ClassifierTargetProbabilities};
/// use num::Float;
///
/// type ClassifierTargetProbabilities<F: Float> = HashMap<ClassifierTarget, F>;
//
///
/// let mut probs: ClassifierTargetProbabilities<f32> = HashMap::new();
/// probs.insert(ClassifierTarget::Bool(true), 0.7);
Expand Down
81 changes: 38 additions & 43 deletions src/metrics/confusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::{
};

use crate::common::{ClassifierOutput, ClassifierTarget};

use num::{Float, FromPrimitive};

/// Confusion Matrix for binary and multi-class classification.
Expand All @@ -18,27 +19,28 @@ use num::{Float, FromPrimitive};
///
/// ```
/// use light_river::metrics::confusion::ConfusionMatrix;
/// use light_river::common::ClassifierTarget;
/// use light_river::common::{ClassifierTarget, ClassifierOutput};
///
/// let y_true = vec!["cat", "ant", "cat", "cat", "ant", "bird"];
/// let y_pred = vec!["ant", "ant", "cat", "cat", "ant", "cat"];
/// let y_pred = vec![
/// ClassifierOutput::Prediction(ClassifierTarget::from("ant")),
/// ClassifierOutput::Prediction(ClassifierTarget::from("ant")),
/// ClassifierOutput::Prediction(ClassifierTarget::from("cat")),
/// ClassifierOutput::Prediction(ClassifierTarget::from("cat")),
/// ClassifierOutput::Prediction(ClassifierTarget::from("ant")),
/// ClassifierOutput::Prediction(ClassifierTarget::from("cat")),
/// ];
/// let y_pred_stream = y_pred.iter();
/// let y_true: Vec<String> = vec!["cat".to_string(), "ant".to_string(), "cat".to_string(), "cat".to_string(), "ant".to_string(), "bird".to_string()];
/// let y_true_stream = ClassifierTarget::from_iter(y_true.into_iter());
///
/// let mut cm: ConfusionMatrix<f64> = ConfusionMatrix::new();
///
/// for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
/// cm.update(yt, yp); // Assuming an update method
/// for (yt, yp) in y_true_stream.zip(y_pred_stream) {
/// cm.update( &yp, &yt, Some(1.0)); // Assuming an update method
/// }
///
/// // Representation of the matrix. This will depend on your actual implementation
/// // Here's just a placeholder. Make sure to adjust based on your actual method and output.
/// assert_eq!(cm.to_string(), "
/// ant bird cat
/// ant 2 0 0
/// bird 0 0 1
/// cat 1 0 2
/// ");
///
/// assert_eq!(cm.get(ClassifierTarget::String("bird"),ClassifierTarget::String("cat")), 1.0); // Assuming a get method
/// assert_eq!(*cm.get(&ClassifierTarget::from("bird")).get(&ClassifierTarget::from("cat")).unwrap_or(&0.0), 1.0);
/// ```
///
/// # Notes
Expand Down Expand Up @@ -96,25 +98,25 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> C
y_true: &ClassifierTarget,
sample_weight: F,
) {
let label = y_pred.get_predicition();
let label_pred = y_pred.get_predicition();
let y = y_true.clone();
let y_col = y.clone();
let label_row = label.clone();
let y_row = y.clone();
let label_col = label_pred.clone();

self.data
.entry(label)
.or_insert_with(HashMap::new)
.entry(y)
.or_insert_with(HashMap::new)
.entry(label_pred)
.and_modify(|x| *x += sample_weight)
.or_insert(sample_weight);

self.total_weight += sample_weight;
self.sum_row
.entry(y_col)
.entry(y_row)
.and_modify(|x| *x += sample_weight)
.or_insert(sample_weight);
self.sum_col
.entry(label_row)
.entry(label_col)
.and_modify(|x| *x += sample_weight)
.or_insert(sample_weight);
}
Expand All @@ -137,6 +139,7 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> C
self._update(y_pred, y_true, -sample_weight);
}
pub fn get(&self, label: &ClassifierTarget) -> HashMap<ClassifierTarget, F> {
// return rows of the label in the confusion matrix
self.data.get(label).unwrap_or(&HashMap::new()).clone()
}
pub fn support(&self, label: &ClassifierTarget) -> F {
Expand Down Expand Up @@ -237,44 +240,36 @@ mod tests {
use super::*;
#[test]
fn test_confusion_matrix() {
let y_true = vec![
ClassifierOutput::Prediction(ClassifierTarget::from("cat")),
let y_pred = vec![
ClassifierOutput::Prediction(ClassifierTarget::from("ant")),
ClassifierOutput::Prediction(ClassifierTarget::from("ant")),
ClassifierOutput::Prediction(ClassifierTarget::from("cat")),
ClassifierOutput::Prediction(ClassifierTarget::from("cat")),
ClassifierOutput::Prediction(ClassifierTarget::from("ant")),
ClassifierOutput::Prediction(ClassifierTarget::from("bird")),
ClassifierOutput::Prediction(ClassifierTarget::from("cat")),
];
let y_pred = vec![
"ant".to_string(),
let y_pred_stream = y_pred.iter();
let y_true: Vec<String> = vec![
"cat".to_string(),
"ant".to_string(),
"cat".to_string(),
"cat".to_string(),
"ant".to_string(),
"cat".to_string(),
"bird".to_string(),
];
let y_pred = ClassifierTarget::from_iter(y_pred.into_iter());
let y_true_stream = ClassifierTarget::from_iter(y_true.into_iter());

let mut cm: ConfusionMatrix<f64> = ConfusionMatrix::new();

for (yt, yp) in y_true.iter().zip(y_pred) {
cm.update(yt, &yp, Some(1.0)); // Assuming an update method
for (yt, yp) in y_true_stream.zip(y_pred_stream) {
cm.update(&yp, &yt, Some(1.0)); // Assuming an update method
}

// assert_eq!(
// cm.to_string(),
// "
// ant bird cat
// ant 2 0 0
// bird 0 0 1
// cat 1 0 2
// "
// );
println!("{:?}", cm);
assert_eq!(
*cm.get(&ClassifierTarget::String("bird".to_string()))
.get(&ClassifierTarget::String("cat".to_string()))
*cm.get(&ClassifierTarget::from("bird"))
.get(&ClassifierTarget::from("cat"))
.unwrap_or(&0.0),
1.0
); // Assuming a get method
);
}
}

0 comments on commit 398392b

Please sign in to comment.