Skip to content

Commit

Permalink
Add roc_curve + test
Browse files Browse the repository at this point in the history
  • Loading branch information
chuvalniy committed Jan 23, 2024
1 parent e88a5cb commit db1c272
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 0 deletions.
Binary file modified src/metrics/__pycache__/classification.cpython-310.pyc
Binary file not shown.
Binary file modified src/metrics/__pycache__/regression.cpython-310.pyc
Binary file not shown.
36 changes: 36 additions & 0 deletions src/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,39 @@ def f1_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:

f1 = 2 * precision * recall / (precision + recall)
return f1


def roc_curve(y_true: np.ndarray, y_pred: np.ndarray, threshold_step: float = 0.1) -> (np.ndarray, np.ndarray):
"""
Calculate ROC curve values.
:param y_true: Target labels (n_samples, ).
:param y_pred: Target predictions probabilities (n_samples, ).
:param threshold_step: Threshold step for calculating rates (float).
:return: True positive rates (threshold_steps, ), False positive rates (threshold_steps, ).
"""
thresholds = np.arange(0.1, 1, threshold_step)

tpr = np.zeros_like(thresholds)
fpr = np.zeros_like(thresholds)

for i, threshold in enumerate(thresholds):
predictions = np.where(y_pred >= threshold, 1, 0)

# Calculate true positive rates.
true_positives = np.sum(y_true * predictions)
false_negatives = np.sum(y_true * (1 - predictions))
if true_positives == 0 and false_negatives == 0:
tpr[i] = 0.0
else:
tpr[i] = true_positives / (true_positives + false_negatives)

# Calculate false positive rates.
true_negatives = np.sum((1 - y_true) * (1 - predictions))
false_positives = np.sum((1 - y_true) * predictions)
if false_positives == 0 and true_negatives == 0:
fpr[i] = 0.0
else:
fpr[i] = true_negatives / (true_negatives + false_positives)

return tpr, fpr
Binary file not shown.
94 changes: 94 additions & 0 deletions tests/metrics/classifciation/test_roc_curve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import numpy as np

from src.metrics import roc_curve


def test_roc_curve_identical_labels():
y_true = np.array([1, 0, 1, 0, 1])
y_pred = np.array([1, 0, 1, 0, 1])

expected_tpr = np.ones(shape=(9,))
expected_fpr = np.ones(shape=(9,))

tpr, fpr = roc_curve(y_true, y_pred)

assert np.allclose(expected_tpr, tpr, atol=1e-5, rtol=1e-5)
assert np.allclose(expected_fpr, fpr, atol=1e-5, rtol=1e-5)


def test_roc_curve_different_labels():
y_true = np.array([1, 1, 1, 1, 1])
y_pred = np.array([0, 0, 0, 0, 0])

expected_tpr = np.zeros(shape=(9,))
expected_fpr = np.zeros(shape=(9,))

tpr, fpr = roc_curve(y_true, y_pred)

assert np.allclose(expected_tpr, tpr, atol=1e-5, rtol=1e-5)
assert np.allclose(expected_fpr, fpr, atol=1e-5, rtol=1e-5)


def test_roc_curve_reversed_different_labels():
y_true = np.array([0, 0, 0, 0, 0])
y_pred = np.array([1, 1, 1, 1, 1])

expected_tpr = np.zeros(shape=(9,))
expected_fpr = np.zeros(shape=(9,))

tpr, fpr = roc_curve(y_true, y_pred)

assert np.allclose(expected_tpr, tpr, atol=1e-5, rtol=1e-5)
assert np.allclose(expected_fpr, fpr, atol=1e-5, rtol=1e-5)


def test_roc_curve_reversed_labels():
y_true = np.array([1, 0, 1, 0, 1])
y_pred = np.array([0, 1, 0, 1, 0])

expected_tpr = np.zeros(shape=(9,))
expected_fpr = np.zeros(shape=(9,))

tpr, fpr = roc_curve(y_true, y_pred)

assert np.allclose(expected_tpr, tpr, atol=1e-5, rtol=1e-5)
assert np.allclose(expected_fpr, fpr, atol=1e-5, rtol=1e-5)


def test_roc_cruve_all_true():
y_true = np.array([1, 0, 0, 0, 1])
y_pred = np.array([1, 1, 1, 1, 1])

expected_tpr = np.ones(shape=(9,))
expected_fpr = np.zeros(shape=(9,))

tpr, fpr = roc_curve(y_true, y_pred)

assert np.allclose(expected_tpr, tpr, atol=1e-5, rtol=1e-5)
assert np.allclose(expected_fpr, fpr, atol=1e-5, rtol=1e-5)


def test_roc_curve_equal_prob():
y_true = np.array([1, 0, 0, 0, 1])
y_pred = np.array([0.5, 0.5, 0.5, 0.5, 0.5])

expected_tpr = np.array([1, 1, 1, 1, 1, 0, 0, 0, 0])
expected_fpr = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1])

tpr, fpr = roc_curve(y_true, y_pred)

assert np.allclose(expected_tpr, tpr, atol=1e-5, rtol=1e-5)
assert np.allclose(expected_fpr, fpr, atol=1e-5, rtol=1e-5)


def test_roc_curve_close_to_target():
y_true = np.array([1, 0, 0, 0, 1])
y_pred = np.array([0.9, 0.1, 0.1, 0.1, 0.9])

expected_tpr = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1])
expected_fpr = np.array([0, 1, 1, 1, 1, 1, 1, 1, 1])

tpr, fpr = roc_curve(y_true, y_pred)

assert np.allclose(expected_tpr, tpr, atol=1e-5, rtol=1e-5)
assert np.allclose(expected_fpr, fpr, atol=1e-5, rtol=1e-5)

0 comments on commit db1c272

Please sign in to comment.