Skip to content

Commit

Permalink
Merge branch 'main' into ahnentafel
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxHalford authored Dec 14, 2023
2 parents 7f32031 + 729fbbc commit 566d803
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 92 deletions.
53 changes: 9 additions & 44 deletions src/metrics/accuracy.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign};

use crate::common::{ClassifierOutput, ClassifierTarget};
use crate::metrics::confusion::ConfusionMatrix;
use crate::metrics::traits::ClassificationMetric;
use num::{Float, FromPrimitive};

struct Accuracy<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> {
n_samples: F,
n_correct: F,
cm: ConfusionMatrix<F>,
}
impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> Accuracy<F> {
pub fn new() -> Self {
Self {
n_samples: F::zero(),
n_correct: F::zero(),
cm: ConfusionMatrix::new(),
}
}
}
Expand All @@ -27,56 +26,22 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign>
y_pred: &ClassifierOutput<F>,
sample_weight: Option<F>,
) {
let sample_weight = sample_weight.unwrap_or_else(|| F::one());
let y_true = match y_true {
ClassifierOutput::Prediction(y_true) => y_true,
ClassifierOutput::Probabilities(y_true) => {
// Find the key with the highest probabilities
y_true
.iter()
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
.unwrap()
.0
}
};
if y_true == y_pred {
self.n_correct += sample_weight;
}
self.n_samples += sample_weight;
self.cm.update(y_true, y_pred, sample_weight);
}
fn revert(
&mut self,
y_true: &ClassifierTarget,
y_pred: &ClassifierOutput<F>,
sample_weight: Option<F>,
) {
let sample_weight = sample_weight.unwrap_or_else(|| F::one());
let y_true = match y_true {
ClassifierOutput::Prediction(y_true) => y_true,
ClassifierOutput::Probabilities(y_true) => {
// Find the key with the highest probabilities
y_true
.iter()
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
.unwrap()
.0
}
};
if y_true == y_pred {
self.n_correct -= sample_weight;
}
self.n_samples -= sample_weight;

if self.n_samples < F::zero() {
self.n_samples = F::zero();
}
self.cm.revert(y_true, y_pred, sample_weight);
}
fn get(&self) -> F {
if self.n_samples == F::zero() {
return F::zero();
}
self.n_correct / self.n_samples
self.cm
.total_true_positives()
.div(F::from(self.cm.total_weight).unwrap())
}

fn is_multiclass(&self) -> bool {
true
}
Expand Down
59 changes: 11 additions & 48 deletions src/metrics/rocauc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ use num::{Float, FromPrimitive};
///
/// # Examples
///
/// ```
/// use light_river::metrics::rocauc::ROCAUC;
/// use light_river::metrics::traits::ClassificationMetric;
/// ```rust
/// use light_river::metrics::ROCAUC;
/// use light_river::common::{ClassifierTarget, ClassifierOutput};
/// use std::collections::HashMap;
///
Expand All @@ -43,22 +42,21 @@ use num::{Float, FromPrimitive};
/// ];
/// let y_true: Vec<bool> = vec![false, false, true, true];
///
/// let mut metric = ROCAUC::new(Some(10), ClassifierTarget::from(true));
/// let mut metric = ROCAUC::new(Some(10), ClassifierTarget::from(true));
///
/// for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
/// metric.update(yp, &ClassifierTarget::from(*yt), Some(1.0));
/// }
///
/// println!("ROCAUC: {:.2}%", metric.get() * (100.0 as f64));
/// // ROCAUC: 87.50%
/// println!("ROCAUC: {:.2}%", metric.get() * 100.0);
/// ```
///
/// # Notes
///
/// The true ROC AUC might differ from the approximation. The accuracy can be improved by increasing the number
/// of thresholds, but this comes at the cost of more computation time and memory usage.
///
pub struct ROCAUC<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> {
struct ROCAUC<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> {
n_threshold: Option<usize>,
pos_val: ClassifierTarget,
thresholds: Vec<F>,
Expand Down Expand Up @@ -97,8 +95,8 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign>
{
fn update(
&mut self,
y_true: &ClassifierTarget,
y_pred: &ClassifierOutput<F>,

Check failure on line 98 in src/metrics/rocauc.rs

View workflow job for this annotation

GitHub Actions / clippy

method `update` has an incompatible type for trait

error[E0053]: method `update` has an incompatible type for trait --> src/metrics/rocauc.rs:98:17 | 98 | y_pred: &ClassifierOutput<F>, | ^^^^^^^^^^^^^^^^^^^^ | | | expected `common::ClassifierTarget`, found `common::ClassifierOutput<F>` | help: change the parameter type to match the trait: `&common::ClassifierTarget` | note: type in trait --> src/metrics/traits.rs:12:17 | 12 | y_true: &ClassifierTarget, | ^^^^^^^^^^^^^^^^^ = note: expected signature `fn(&mut metrics::rocauc::ROCAUC<F>, &common::ClassifierTarget, &common::ClassifierOutput<F>, std::option::Option<_>)` found signature `fn(&mut metrics::rocauc::ROCAUC<F>, &common::ClassifierOutput<F>, &common::ClassifierTarget, std::option::Option<_>)`
y_true: &ClassifierTarget,
sample_weight: Option<F>,
) {
// Get the probability of the positive class
Expand All @@ -118,8 +116,8 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign>

fn revert(
&mut self,
y_true: &ClassifierTarget,
y_pred: &ClassifierOutput<F>,

Check failure on line 119 in src/metrics/rocauc.rs

View workflow job for this annotation

GitHub Actions / clippy

method `revert` has an incompatible type for trait

error[E0053]: method `revert` has an incompatible type for trait --> src/metrics/rocauc.rs:119:17 | 119 | y_pred: &ClassifierOutput<F>, | ^^^^^^^^^^^^^^^^^^^^ | | | expected `common::ClassifierTarget`, found `common::ClassifierOutput<F>` | help: change the parameter type to match the trait: `&common::ClassifierTarget` | note: type in trait --> src/metrics/traits.rs:18:17 | 18 | y_true: &ClassifierTarget, | ^^^^^^^^^^^^^^^^^ = note: expected signature `fn(&mut metrics::rocauc::ROCAUC<F>, &common::ClassifierTarget, &common::ClassifierOutput<F>, std::option::Option<_>)` found signature `fn(&mut metrics::rocauc::ROCAUC<F>, &common::ClassifierOutput<F>, &common::ClassifierTarget, std::option::Option<_>)`
y_true: &ClassifierTarget,
sample_weight: Option<F>,
) {
let p_pred = y_pred.get_probabilities();
Expand All @@ -143,6 +141,7 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign>
let true_negatives: F = cm.true_negatives(&self.pos_val);
let false_positives: F = cm.false_positives(&self.pos_val);
let false_negatives: F = cm.false_negatives(&self.pos_val);

// Handle the case of zero division
let mut tpr: Option<F> = None;
if true_positives + false_negatives != F::zero() {
Expand Down Expand Up @@ -185,48 +184,12 @@ mod tests {
ClassifierOutput::Prediction(ClassifierTarget::from("bird")),
ClassifierOutput::Prediction(ClassifierTarget::from("cat")),
];
let y_true: Vec<ClassifierTarget> = vec![
ClassifierTarget::from("cat".to_string()),
ClassifierTarget::from("cat".to_string()),
ClassifierTarget::from("dog".to_string()),
ClassifierTarget::from("cat".to_string()),
];
let pos_val = ClassifierTarget::from("cat");
let mut metric: ROCAUC<f32> = ROCAUC::new(Some(10), pos_val);

for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
metric.update(&yt, yp, Some(1.0));
println!("{}", metric.get());
}
println!("{}", metric.get());
}
#[test]
fn another() {
use std::collections::HashMap;

let y_pred = vec![
ClassifierOutput::Probabilities(HashMap::from([
(ClassifierTarget::from(true), 0.1),
(ClassifierTarget::from(false), 0.9),
])),
ClassifierOutput::Probabilities(HashMap::from([(ClassifierTarget::from(true), 0.4)])),
ClassifierOutput::Probabilities(HashMap::from([
(ClassifierTarget::from(true), 0.35),
(ClassifierTarget::from(false), 0.65),
])),
ClassifierOutput::Probabilities(HashMap::from([
(ClassifierTarget::from(true), 0.8),
(ClassifierTarget::from(false), 0.2),
])),
];
let y_true: Vec<bool> = vec![false, false, true, true];
let y_true: Vec<&str> = vec!["cat", "cat", "dog", "cat"];

let mut metric: ROCAUC<f64> = ROCAUC::new(Some(10), ClassifierTarget::from(true));
let mut metric = ROCAUC::new(Some(10), ClassifierTarget::from("cat"));

for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
metric.update(&ClassifierTarget::from(*yt), yp, Some(1.0));
metric.update(yp, &ClassifierTarget::from(*yt), Some(1.0));
}

println!("ROCAUC: {:.2}%", metric.get() * (100.0 as f64));
}
}

0 comments on commit 566d803

Please sign in to comment.