Skip to content

Commit

Permalink
fix plot_roc_curve, plot_ks_statistic, and plot_precision_recall_curv…
Browse files Browse the repository at this point in the history
…e when passed Python lists instead of Numpy arrays (#32)
  • Loading branch information
reiinakano authored May 17, 2017
1 parent f25825c commit 0ef56e9
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
8 changes: 8 additions & 0 deletions scikitplot/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ def plot_roc_curve(y_true, y_probas, title='ROC Curves', curves=('micro', 'macro
:align: center
:alt: ROC Curves
"""
y_true = np.array(y_true)
y_probas = np.array(y_probas)

if 'micro' not in curves and 'macro' not in curves and 'each_class' not in curves:
raise ValueError('Invalid argument for curves as it only takes "micro", "macro", or "each_class"')
Expand Down Expand Up @@ -282,6 +284,9 @@ def plot_ks_statistic(y_true, y_probas, title='KS Statistic Plot', ax=None, figs
:align: center
:alt: KS Statistic
"""
y_true = np.array(y_true)
y_probas = np.array(y_probas)

classes = np.unique(y_true)
if len(classes) != 2:
raise ValueError('Cannot calculate KS statistic for data with '
Expand Down Expand Up @@ -359,6 +364,9 @@ def plot_precision_recall_curve(y_true, y_probas, title='Precision-Recall Curve'
:align: center
:alt: Precision Recall Curve
"""
y_true = np.array(y_true)
y_probas = np.array(y_probas)

classes = np.unique(y_true)
probas = y_probas

Expand Down
13 changes: 13 additions & 0 deletions scikitplot/tests/test_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sklearn.exceptions import NotFittedError
import numpy as np
import matplotlib.pyplot as plt
import scikitplot.plotters as skplt


def convert_labels_into_string(y_true):
Expand Down Expand Up @@ -201,6 +202,9 @@ def test_ax(self):
out_ax = clf.plot_confusion_matrix(self.X, self.y, ax=ax)
assert ax is out_ax

def test_array_like(self):
ax = skplt.plot_confusion_matrix([0, 1], [1, 0])


class TestPlotROCCurve(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -272,6 +276,9 @@ def test_invalid_curve_arg(self):
self.assertRaises(ValueError, clf.plot_roc_curve, self.X, self.y,
curves='zzz')

def test_array_like(self):
ax = skplt.plot_roc_curve([0, 1], [[0.8, 0.2], [0.2, 0.8]])


class TestPlotKSStatistic(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -333,6 +340,9 @@ def test_ax(self):
out_ax = clf.plot_ks_statistic(self.X, self.y, ax=ax)
assert ax is out_ax

def test_array_like(self):
ax = skplt.plot_ks_statistic([0, 1], [[0.8, 0.2], [0.2, 0.8]])


class TestPlotPrecisionRecall(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -403,6 +413,9 @@ def test_invalid_curve_arg(self):
self.assertRaises(ValueError, clf.plot_precision_recall_curve, self.X, self.y,
curves='zzz')

def test_array_like(self):
ax = skplt.plot_precision_recall_curve([0, 1], [[0.8, 0.2], [0.2, 0.8]])


class TestFeatureImportances(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit 0ef56e9

Please sign in to comment.