Skip to content

Commit

Permalink
add roc auc
Browse files Browse the repository at this point in the history
  • Loading branch information
AdilZouitine committed Oct 5, 2023
1 parent 742e459 commit 729fbbc
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 49 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ num = "0.4.0"
tempfile = "3.4.0"
maplit = "1.0.2"
reqwest = { version = "0.11.4", features = ["blocking"] }
zip = "0.6.4"
zip = "0.6.4"
rand = "0.8.5"
53 changes: 9 additions & 44 deletions src/metrics/accuracy.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign};

use crate::common::{ClassifierOutput, ClassifierTarget};
use crate::metrics::confusion::ConfusionMatrix;
use crate::metrics::traits::ClassificationMetric;
use num::{Float, FromPrimitive};

struct Accuracy<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> {
n_samples: F,
n_correct: F,
cm: ConfusionMatrix<F>,
}
impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> Accuracy<F> {
pub fn new() -> Self {

Check warning on line 12 in src/metrics/accuracy.rs

View workflow job for this annotation

GitHub Actions / clippy

associated function `new` is never used

warning: associated function `new` is never used --> src/metrics/accuracy.rs:12:12 | 11 | impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> Accuracy<F> { | ------------------------------------------------------------------------------------------ associated function in this implementation 12 | pub fn new() -> Self { | ^^^ | = note: `#[warn(dead_code)]` on by default
Self {
n_samples: F::zero(),
n_correct: F::zero(),
cm: ConfusionMatrix::new(),
}
}
}
Expand All @@ -27,56 +26,22 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign>
y_pred: &ClassifierTarget,
sample_weight: Option<F>,
) {
let sample_weight = sample_weight.unwrap_or_else(|| F::one());
let y_true = match y_true {
ClassifierOutput::Prediction(y_true) => y_true,
ClassifierOutput::Probabilities(y_true) => {
// Find the key with the highest probabilities
y_true
.iter()
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
.unwrap()
.0
}
};
if y_true == y_pred {
self.n_correct += sample_weight;
}
self.n_samples += sample_weight;
self.cm.update(y_true, y_pred, sample_weight);
}
fn revert(
&mut self,
y_true: &ClassifierOutput<F>,
y_pred: &ClassifierTarget,
sample_weight: Option<F>,
) {
let sample_weight = sample_weight.unwrap_or_else(|| F::one());
let y_true = match y_true {
ClassifierOutput::Prediction(y_true) => y_true,
ClassifierOutput::Probabilities(y_true) => {
// Find the key with the highest probabilities
y_true
.iter()
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
.unwrap()
.0
}
};
if y_true == y_pred {
self.n_correct -= sample_weight;
}
self.n_samples -= sample_weight;

if self.n_samples < F::zero() {
self.n_samples = F::zero();
}
self.cm.revert(y_true, y_pred, sample_weight);
}
fn get(&self) -> F {
if self.n_samples == F::zero() {
return F::zero();
}
self.n_correct / self.n_samples
self.cm
.total_true_positives()
.div(F::from(self.cm.total_weight).unwrap())
}

fn is_multiclass(&self) -> bool {
true
}
Expand Down
8 changes: 4 additions & 4 deletions src/metrics/confusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub struct ConfusionMatrix<F: Float + FromPrimitive + AddAssign + SubAssign + Mu
data: HashMap<ClassifierTarget, HashMap<ClassifierTarget, F>>,
sum_row: HashMap<ClassifierTarget, F>,
sum_col: HashMap<ClassifierTarget, F>,
total_weight: F,
pub total_weight: F,
}

impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> ConfusionMatrix<F> {
Expand Down Expand Up @@ -133,10 +133,10 @@ impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> C
&mut self,
y_pred: &ClassifierOutput<F>,
y_true: &ClassifierTarget,
sample_weight: F,
sample_weight: Option<F>,
) {
self.n_samples -= sample_weight;
self._update(y_pred, y_true, -sample_weight);
self.n_samples -= sample_weight.unwrap_or(F::one());
self._update(y_pred, y_true, -sample_weight.unwrap_or(F::one()));
}
pub fn get(&self, label: &ClassifierTarget) -> HashMap<ClassifierTarget, F> {
// return rows of the label in the confusion matrix
Expand Down
1 change: 1 addition & 0 deletions src/metrics/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod accuracy;
pub mod confusion;
pub mod rocauc;
pub mod traits;
195 changes: 195 additions & 0 deletions src/metrics/rocauc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign};

use crate::common::{ClassifierOutput, ClassifierTarget};
use crate::metrics::confusion::ConfusionMatrix;
use crate::metrics::traits::ClassificationMetric;
use num::{Float, FromPrimitive};

/// Receiver Operating Characteristic Area Under the Curve (ROC AUC).
///
/// This metric provides an approximation of the true ROC AUC. Computing the true ROC AUC would
/// require storing all the predictions and ground truths, which may not be efficient. The approximation
/// error is typically insignificant as long as the predicted probabilities are well calibrated. Regardless,
/// this metric can be used to reliably compare models with each other.
///
/// # Parameters
///
/// - `n_threshold`: The number of thresholds used for discretizing the ROC curve. A higher value will lead to
/// more accurate results, but will also require more computation time and memory.
/// - `pos_val`: Value to treat as "positive".
///
/// # Examples
///
/// ```rust
/// use light_river::metrics::ROCAUC;
/// use light_river::common::{ClassifierTarget, ClassifierOutput};
/// use std::collections::HashMap;
///
/// let y_pred = vec![
/// ClassifierOutput::Probabilities(HashMap::from([
/// (ClassifierTarget::from(true), 0.1),
/// (ClassifierTarget::from(false), 0.9),
/// ])),
/// ClassifierOutput::Probabilities(HashMap::from([(ClassifierTarget::from(true), 0.4)])),
/// ClassifierOutput::Probabilities(HashMap::from([
/// (ClassifierTarget::from(true), 0.35),
/// (ClassifierTarget::from(false), 0.65),
/// ])),
/// ClassifierOutput::Probabilities(HashMap::from([
/// (ClassifierTarget::from(true), 0.8),
/// (ClassifierTarget::from(false), 0.2),
/// ])),
/// ];
/// let y_true: Vec<bool> = vec![false, false, true, true];
///
/// let mut metric = ROCAUC::new(Some(10), ClassifierTarget::from(true));
///
/// for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
/// metric.update(yp, &ClassifierTarget::from(*yt), Some(1.0));
/// }
///
/// println!("ROCAUC: {:.2}%", metric.get() * 100.0);
/// ```
///
/// # Notes
///
/// The true ROC AUC might differ from the approximation. The accuracy can be improved by increasing the number
/// of thresholds, but this comes at the cost of more computation time and memory usage.
///
struct ROCAUC<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> {

Check warning on line 59 in src/metrics/rocauc.rs

View workflow job for this annotation

GitHub Actions / clippy

name `ROCAUC` contains a capitalized acronym

warning: name `ROCAUC` contains a capitalized acronym --> src/metrics/rocauc.rs:59:8 | 59 | struct ROCAUC<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> { | ^^^^^^ help: consider making the acronym lowercase, except the initial letter: `Rocauc` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#upper_case_acronyms = note: `#[warn(clippy::upper_case_acronyms)]` on by default
n_threshold: Option<usize>,
pos_val: ClassifierTarget,
thresholds: Vec<F>,
cms: Vec<ConfusionMatrix<F>>,
}
impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> ROCAUC<F> {
pub fn new(n_threshold: Option<usize>, pos_val: ClassifierTarget) -> Self {

Check warning on line 66 in src/metrics/rocauc.rs

View workflow job for this annotation

GitHub Actions / clippy

associated function `new` is never used

warning: associated function `new` is never used --> src/metrics/rocauc.rs:66:12 | 65 | impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> ROCAUC<F> { | ---------------------------------------------------------------------------------------- associated function in this implementation 66 | pub fn new(n_threshold: Option<usize>, pos_val: ClassifierTarget) -> Self { | ^^^
let n_threshold = n_threshold.unwrap_or(10);

let mut thresholds = Vec::with_capacity(n_threshold);

for i in 0..n_threshold {
thresholds.push(
F::from(i).unwrap() / (F::from(n_threshold).unwrap() - F::from(1.0).unwrap()),
);
}
thresholds[0] -= F::from(1e-7).unwrap();
thresholds[n_threshold - 1] += F::from(1e-7).unwrap();

let mut cms = Vec::with_capacity(n_threshold);
for _ in 0..n_threshold {
cms.push(ConfusionMatrix::new());
}

Self {
n_threshold: Some(n_threshold),
pos_val: pos_val,

Check warning on line 86 in src/metrics/rocauc.rs

View workflow job for this annotation

GitHub Actions / clippy

redundant field names in struct initialization

warning: redundant field names in struct initialization --> src/metrics/rocauc.rs:86:13 | 86 | pos_val: pos_val, | ^^^^^^^^^^^^^^^^ help: replace it with: `pos_val` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#redundant_field_names = note: `#[warn(clippy::redundant_field_names)]` on by default
thresholds: thresholds,

Check warning on line 87 in src/metrics/rocauc.rs

View workflow job for this annotation

GitHub Actions / clippy

redundant field names in struct initialization

warning: redundant field names in struct initialization --> src/metrics/rocauc.rs:87:13 | 87 | thresholds: thresholds, | ^^^^^^^^^^^^^^^^^^^^^^ help: replace it with: `thresholds` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#redundant_field_names
cms: cms,

Check warning on line 88 in src/metrics/rocauc.rs

View workflow job for this annotation

GitHub Actions / clippy

redundant field names in struct initialization

warning: redundant field names in struct initialization --> src/metrics/rocauc.rs:88:13 | 88 | cms: cms, | ^^^^^^^^ help: replace it with: `cms` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#redundant_field_names
}
}
}

impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign> ROCAUC<F> {
    /// Reduces one observation to the two values every threshold needs: the predicted
    /// probability of the positive class (zero when the output carries no entry for it) and the
    /// ground truth as a boolean "is positive" target.
    fn binarize(
        &self,
        y_pred: &ClassifierOutput<F>,
        y_true: &ClassifierTarget,
    ) -> (F, ClassifierTarget) {
        let score = y_pred
            .get_probabilities()
            .get(&self.pos_val)
            .copied()
            .unwrap_or_else(F::zero);
        let is_positive = ClassifierTarget::from(y_true.eq(&self.pos_val));
        (score, is_positive)
    }
}

impl<F: Float + FromPrimitive + AddAssign + SubAssign + MulAssign + DivAssign>
    ClassificationMetric<F> for ROCAUC<F>
{
    /// Adds one observation: for every threshold, the score of the positive class is turned
    /// into a boolean decision (`score >= threshold`) and recorded in that threshold's
    /// confusion matrix against the boolean ground truth.
    fn update(
        &mut self,
        y_pred: &ClassifierOutput<F>,
        y_true: &ClassifierTarget,
        sample_weight: Option<F>,
    ) {
        let (score, is_positive) = self.binarize(y_pred, y_true);

        for (threshold, cm) in self.thresholds.iter().zip(self.cms.iter_mut()) {
            let decision =
                ClassifierOutput::Prediction(ClassifierTarget::from(score >= *threshold));
            cm.update(&decision, &is_positive, sample_weight);
        }
    }

    /// Removes one observation; the exact mirror of `update`, calling `revert` on every
    /// per-threshold confusion matrix.
    fn revert(
        &mut self,
        y_pred: &ClassifierOutput<F>,
        y_true: &ClassifierTarget,
        sample_weight: Option<F>,
    ) {
        let (score, is_positive) = self.binarize(y_pred, y_true);

        for (threshold, cm) in self.thresholds.iter().zip(self.cms.iter_mut()) {
            let decision =
                ClassifierOutput::Prediction(ClassifierTarget::from(score >= *threshold));
            cm.revert(&decision, &is_positive, sample_weight);
        }
    }

    /// Returns the approximated area under the ROC curve.
    ///
    /// A true/false positive rate whose denominator is zero is reported as zero.
    fn get(&self) -> F {
        // `update`/`revert` only ever store the binarized labels `true`/`false` in the
        // confusion matrices, so the positive class must be looked up as `true` here.
        // Looking it up under `self.pos_val` is only correct when `pos_val` itself is `true`.
        let positive = ClassifierTarget::from(true);

        let rate = |hits: F, misses: F| {
            let total = hits + misses;
            if total == F::zero() {
                F::zero()
            } else {
                hits / total
            }
        };

        let (tprs, fprs): (Vec<F>, Vec<F>) = self
            .cms
            .iter()
            .map(|cm| {
                let true_positives = cm.true_positives(&positive);
                let true_negatives = cm.true_negatives(&positive);
                let false_positives = cm.false_positives(&positive);
                let false_negatives = cm.false_negatives(&positive);
                (
                    rate(true_positives, false_negatives),
                    rate(false_positives, true_negatives),
                )
            })
            .unzip();

        // Trapezoidal integration over consecutive (fpr, tpr) points.
        let two = F::from(2.0).unwrap();
        let mut auc = F::zero();
        for (fpr, tpr) in fprs.windows(2).zip(tprs.windows(2)) {
            auc += (fpr[1] - fpr[0]) * (tpr[1] + tpr[0]) / two;
        }

        // Thresholds increase along the vectors, so the false positive rate runs from high to
        // low and the integral comes out negated.
        -auc
    }

    fn is_multiclass(&self) -> bool {
        false
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Same data as the example in the `ROCAUC` documentation, with the result checked.
    #[test]
    fn test_rocauc_doc_example() {
        let y_pred = vec![
            ClassifierOutput::Probabilities(HashMap::from([
                (ClassifierTarget::from(true), 0.1),
                (ClassifierTarget::from(false), 0.9),
            ])),
            ClassifierOutput::Probabilities(HashMap::from([(ClassifierTarget::from(true), 0.4)])),
            ClassifierOutput::Probabilities(HashMap::from([
                (ClassifierTarget::from(true), 0.35),
                (ClassifierTarget::from(false), 0.65),
            ])),
            ClassifierOutput::Probabilities(HashMap::from([
                (ClassifierTarget::from(true), 0.8),
                (ClassifierTarget::from(false), 0.2),
            ])),
        ];
        let y_true: Vec<bool> = vec![false, false, true, true];

        let mut metric: ROCAUC<f64> = ROCAUC::new(Some(10), ClassifierTarget::from(true));

        for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
            metric.update(yp, &ClassifierTarget::from(*yt), Some(1.0));
        }

        // With 10 thresholds the ROC points are (1, 1), (0.5, 1), (0, 0.5) and (0, 0), whose
        // trapezoidal area is 0.5 + 0.375.
        let auc: f64 = metric.get();
        assert!((auc - 0.875).abs() < 1e-9, "unexpected ROC AUC: {}", auc);
    }

    /// Smoke test with string labels and hard predictions; it only exercises `update` and
    /// makes no claim about the resulting value.
    #[test]
    fn test_rocauc() {
        let y_pred = vec![
            ClassifierOutput::Prediction(ClassifierTarget::from("cat")),
            ClassifierOutput::Prediction(ClassifierTarget::from("dog")),
            ClassifierOutput::Prediction(ClassifierTarget::from("bird")),
            ClassifierOutput::Prediction(ClassifierTarget::from("cat")),
        ];
        let y_true: Vec<&str> = vec!["cat", "cat", "dog", "cat"];

        let mut metric = ROCAUC::new(Some(10), ClassifierTarget::from("cat"));

        for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
            metric.update(yp, &ClassifierTarget::from(*yt), Some(1.0));
        }
    }
}

0 comments on commit 729fbbc

Please sign in to comment.