Commit 70cf481

Merge branch 'release/0.3.0'

tayden committed Aug 24, 2022
2 parents 0daeb2f + ab8dcfc
Showing 11 changed files with 387 additions and 68 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -102,3 +102,7 @@ venv.bak/
 
 # mypy
 .mypy_cache/
+
+# IDE files
+.idea
+.vscode
9 changes: 0 additions & 9 deletions .vscode/settings.json

This file was deleted.

3 changes: 2 additions & 1 deletion README.md
@@ -24,7 +24,8 @@ print(auroc(scores, labels))
 
 ### AUPR
 
-Calculate and return the area under the Precision Recall curve using unthresholded predictions on the data and a binary true label.
+Calculate and return the area under the Precision Recall curve using unthresholded predictions on the data and a binary true
+label.
 
 ```python
 from ood_metrics import aupr
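The hunk ends inside the README's fenced example. As a hedged sketch of how `aupr` is called, with `scores` and `labels` values invented here but chosen to match the package's own tests:

```python
from ood_metrics import aupr

scores = [0.1, 0.2, 0.3, 0.4]  # unthresholded novelty scores
labels = [0, 0, 1, 1]          # binary ground truth; 1 marks the positive class

# Every positive is ranked above every negative, so the area under the
# precision-recall curve is 1.0 (this mirrors test_aupr further down).
print(aupr(scores, labels))  # 1.0
```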
1 change: 0 additions & 1 deletion VERSION

This file was deleted.

3 changes: 3 additions & 0 deletions ood_metrics/__init__.py
@@ -1,2 +1,5 @@
+import pkg_resources
+__version__ = pkg_resources.get_distribution('ood_metrics').version
+
 from .metrics import *
 from .plots import *
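With the `VERSION` file deleted above, the package version now comes from the installed distribution's metadata via `pkg_resources`. A minimal sketch of what this exposes, assuming `ood_metrics` is pip-installed:

```python
import ood_metrics

# __version__ is resolved from the installed distribution's metadata,
# so for this release it should print "0.3.0".
print(ood_metrics.__version__)
```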
28 changes: 14 additions & 14 deletions ood_metrics/metrics.py
@@ -1,8 +1,8 @@
-from sklearn.metrics import roc_curve, auc, precision_recall_curve
 import numpy as np
+from sklearn.metrics import auc, precision_recall_curve, roc_curve
 
 
-def auroc(preds, labels, pos_label = 1):
+def auroc(preds, labels, pos_label=1):
     """Calculate and return the area under the ROC curve using unthresholded predictions on the data and a binary true label.
     preds: array, shape = [n_samples]
@@ -18,7 +18,7 @@ def auroc(preds, labels, pos_label = 1):
     return auc(fpr, tpr)
 
 
-def aupr(preds, labels, pos_label = 1):
+def aupr(preds, labels, pos_label=1):
     """Calculate and return the area under the Precision Recall curve using unthresholded predictions on the data and a binary true label.
     preds: array, shape = [n_samples]
@@ -34,7 +34,7 @@ def aupr(preds, labels, pos_label = 1):
     return auc(recall, precision)
 
 
-def fpr_at_95_tpr(preds, labels, pos_label = 1):
+def fpr_at_95_tpr(preds, labels, pos_label=1):
     """Return the FPR when TPR is at minimum 95%.
     preds: array, shape = [n_samples]
@@ -47,20 +47,20 @@ def fpr_at_95_tpr(preds, labels, pos_label = 1):
     pos_label: label of the positive class (1 by default)
     """
     fpr, tpr, _ = roc_curve(labels, preds, pos_label=pos_label)
 
     if all(tpr < 0.95):
         # No threshold allows TPR >= 0.95
         return 0
     elif all(tpr >= 0.95):
         # All thresholds allow TPR >= 0.95, so find lowest possible FPR
-        idxs = [i for i, x in enumerate(tpr) if x>=0.95]
+        idxs = [i for i, x in enumerate(tpr) if x >= 0.95]
         return min(map(lambda idx: fpr[idx], idxs))
     else:
         # Linear interp between values to get FPR at TPR == 0.95
         return np.interp(0.95, tpr, fpr)
 
 
-def detection_error(preds, labels, pos_label = 1):
+def detection_error(preds, labels, pos_label=1):
     """Return the misclassification probability when TPR is 95%.
     preds: array, shape = [n_samples]
@@ -79,17 +79,17 @@ def detection_error(preds, labels, pos_label = 1):
     neg_ratio = 1 - pos_ratio
 
     # Get indexes of all TPR >= 95%
-    idxs = [i for i, x in enumerate(tpr) if x>=0.95]
+    idxs = [i for i, x in enumerate(tpr) if x >= 0.95]
 
     # Calc error for a given threshold (i.e. idx)
     # Calc is the (# of negatives * FNR) + (# of positives * FPR)
     _detection_error = lambda idx: neg_ratio * (1 - tpr[idx]) + pos_ratio * fpr[idx]
 
     # Return the minimum detection error such that TPR >= 0.95
     return min(map(_detection_error, idxs))
 
 
-def calc_metrics(predictions, labels, pos_label = 1):
+def calc_metrics(predictions, labels, pos_label=1):
     """Using predictions and labels, return a dictionary containing all novelty
     detection performance statistics.
@@ -105,7 +105,7 @@ def calc_metrics(predictions, labels, pos_label = 1):
     pos_label: label of the positive class (1 by default)
     """
 
     return {
         'fpr_at_95_tpr': fpr_at_95_tpr(predictions, labels, pos_label=pos_label),
         'detection_error': detection_error(predictions, labels, pos_label=pos_label),
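A hedged usage sketch for the functions reformatted above; the inputs are toy values borrowed from the test suite further down, and the top-level imports assume the star-imports in `__init__.py` re-export these names:

```python
from ood_metrics import calc_metrics, detection_error, fpr_at_95_tpr

scores = [0.1, 0.2, 0.3, 0.4]  # unthresholded predictions; higher = more novel
labels = [0, 0, 1, 1]          # binary labels; 1 is the positive (novel) class

# The scores separate the classes perfectly, so TPR reaches 1.0
# before any false positives appear.
print(fpr_at_95_tpr(scores, labels))    # 0.0
print(detection_error(scores, labels))  # 0.0

# calc_metrics bundles the individual statistics into one dictionary,
# keyed as in its return statement ('fpr_at_95_tpr', 'detection_error', ...).
print(calc_metrics(scores, labels))
```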
24 changes: 12 additions & 12 deletions ood_metrics/plots.py
@@ -1,7 +1,8 @@
-from sklearn.metrics import roc_curve, auc, precision_recall_curve
 import matplotlib.pyplot as plt
 import numpy as np
-from .metrics import fpr_at_95_tpr, auroc, aupr
+from sklearn.metrics import auc, precision_recall_curve, roc_curve
+
+from .metrics import aupr, auroc, fpr_at_95_tpr
 
 
 def plot_roc(preds, labels, title="Receiver operating characteristic"):
@@ -16,13 +17,13 @@ def plot_roc(preds, labels, title="Receiver operating characteristic"):
     title: string, optional (default="Receiver operating characteristic")
         The title for the chart
     """
 
     # Compute values for curve
     fpr, tpr, _ = roc_curve(labels, preds)
 
     # Compute FPR (95% TPR)
     tpr95 = fpr_at_95_tpr(preds, labels)
 
     # Compute AUROC
     roc_auc = auroc(preds, labels)
 
@@ -41,8 +42,8 @@ def plot_roc(preds, labels, title="Receiver operating characteristic"):
     plt.title(title)
     plt.legend(loc="lower right")
     plt.show()
 
 
 def plot_pr(preds, labels, title="Precision recall curve"):
     """Plot a Precision-Recall curve based on unthresholded predictions and true binary labels.
@@ -55,7 +56,7 @@ def plot_pr(preds, labels, title="Precision recall curve"):
     title: string, optional (default="Precision recall curve")
         The title for the chart
     """
 
     # Compute values for curve
     precision, recall, _ = precision_recall_curve(labels, preds)
     prc_auc = auc(recall, precision)
@@ -64,7 +65,7 @@ def plot_pr(preds, labels, title="Precision recall curve"):
     lw = 2
     plt.plot(recall, precision, color='darkorange',
              lw=lw, label='PRC curve (area = %0.2f)' % prc_auc)
     # plt.plot([0, 1], [1, 0], color='navy', lw=lw, linestyle='--')
     plt.xlim([0.0, 1.0])
     plt.ylim([0.0, 1.05])
     plt.xlabel('Recall')
@@ -78,9 +79,9 @@ def plot_barcode(preds, labels):
     """Plot a visualization showing inliers and outliers sorted by their prediction of novelty."""
     # the bar
     x = sorted([a for a in zip(preds, labels)], key=lambda x: x[0])
-    x = np.array([[49,163,84] if a[1] == 1 else [173,221,142] for a in x])
+    x = np.array([[49, 163, 84] if a[1] == 1 else [173, 221, 142] for a in x])
     # x = np.array([a[1] for a in x]) # for bw image
 
     axprops = dict(xticks=[], yticks=[])
     barprops = dict(aspect='auto', cmap=plt.cm.binary_r, interpolation='nearest')
@@ -91,4 +92,3 @@ def plot_barcode(preds, labels):
     ax.imshow(x.reshape((1, -1, 3)), **barprops)
 
     plt.show()
-
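Not part of the commit: a sketch of driving the three plotting helpers on synthetic scores (the Gaussian mixture below is invented for illustration):

```python
import numpy as np

from ood_metrics import plot_barcode, plot_pr, plot_roc

rng = np.random.default_rng(0)
# Inliers score low, outliers score high, with some overlap near the boundary.
scores = np.concatenate([rng.normal(0.0, 1.0, 100), rng.normal(2.0, 1.0, 100)])
labels = np.array([0] * 100 + [1] * 100)

plot_roc(scores, labels)      # ROC curve annotated with AUROC and FPR at 95% TPR
plot_pr(scores, labels)       # precision-recall curve with its AUC in the legend
plot_barcode(scores, labels)  # samples sorted by score, colored by true label
```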
18 changes: 11 additions & 7 deletions ood_metrics/test_metrics.py
@@ -1,24 +1,28 @@
-from .metrics import auroc, aupr, detection_error, fpr_at_95_tpr
+from .metrics import aupr, auroc, detection_error, fpr_at_95_tpr
 
 
 def test_auroc():
     assert auroc([0.1, 0.2, 0.3, 0.4], [0, 0, 1, 1]) == 1.0
     assert auroc([0.4, 0.3, 0.2, 0.1], [1, 1, 0, 0]) == 1.0
     assert auroc([0.4, 0.3, 0.2, 0.1], [0, 1, 1, 0]) == 0.5
     assert auroc([0.4, 0.3, 0.2, 0.1], [-1, 1, 1, -1]) == 0.5
     assert auroc([0.1, 0.2, 0.3, 0.4], [1, 1, 0, 0]) == 0.0
-    assert auroc([0.1, 0.2, 0.3, 0.4], [1, 0, 1, 1]) == 2./3
+    assert auroc([0.1, 0.2, 0.3, 0.4], [1, 0, 1, 1]) == 2. / 3
 
 
 def test_aupr():
     assert aupr([0.1, 0.2, 0.3, 0.4], [0, 0, 1, 1]) == 1.0
-    assert round(aupr(list(range(10000)), [i%2 for i in range(10000)]), 2) == 0.5
+    assert round(aupr(list(range(10000)), [i % 2 for i in range(10000)]), 2) == 0.5
 
 
 def test_fpr_at_95_tpr():
     assert fpr_at_95_tpr([0.1, 0.2, 0.3, 0.4], [0, 0, 1, 1]) == 0.0
     assert fpr_at_95_tpr([0.1, 0.2, 0.3, 0.4], [1, 1, 0, 0]) == 1.0
-    assert round(fpr_at_95_tpr(list(range(10000)), [i%2 for i in range(10000)]), 2) == 0.95
+    assert round(fpr_at_95_tpr(list(range(10000)), [i % 2 for i in range(10000)]), 2) == 0.95
 
 
 def test_detection_error():
     assert detection_error([0.1, 0.2, 0.3, 0.4], [0, 0, 1, 1]) == 0.0
-    assert round(detection_error(list(range(100)), [1]*3 + [0]*97), 2) == 0.03
-    assert round(detection_error(list(range(100)), [1]*4 + [0]*96), 2) == 0.04
-    assert round(detection_error(list(range(10000)), [i%2 for i in range(10000)]), 2) == 0.5
+    assert round(detection_error(list(range(100)), [1] * 3 + [0] * 97), 2) == 0.03
+    assert round(detection_error(list(range(100)), [1] * 4 + [0] * 96), 2) == 0.04
+    assert round(detection_error(list(range(10000)), [i % 2 for i in range(10000)]), 2) == 0.5
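The `0.03` assertion follows directly from the formula in `detection_error` above; a hand check (a sketch, not part of the commit):

```python
from ood_metrics import detection_error

labels = [1] * 3 + [0] * 97  # 3 positives out of 100, given the LOWEST scores
preds = list(range(100))

# TPR >= 0.95 is reached only once every sample is flagged positive, at which
# point FPR == 1.0. Plugging into neg_ratio * (1 - TPR) + pos_ratio * FPR:
#   0.97 * (1 - 1.0) + 0.03 * 1.0 == 0.03
assert round(detection_error(preds, labels), 2) == 0.03
```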