Skip to content

Commit

Permalink
feat(sklplots): add normalize option to confusion matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein authored and dberenbaum committed Apr 21, 2023
1 parent a4df3a9 commit bca4682
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 67 deletions.
2 changes: 1 addition & 1 deletion src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def log_sklearn_plot(self, kind, labels, predictions, name=None, **kwargs):
if name in self._plots:
data = self._plots[name]
elif kind in SKLEARN_PLOTS and SKLEARN_PLOTS[kind].could_log(val):
data = SKLEARN_PLOTS[kind](name, self.plots_dir)
data = SKLEARN_PLOTS[kind](name, self.plots_dir, **kwargs)
self._plots[data.name] = data
else:
raise InvalidPlotTypeError(name)
Expand Down
118 changes: 65 additions & 53 deletions src/dvclive/plots/sklearn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from pathlib import Path

from dvclive.serialize import dump_json
Expand All @@ -9,7 +10,7 @@ class SKLearnPlot(Data):
suffixes = [".json"]
subfolder = "sklearn"

def __init__(self, name: str, output_folder: str) -> None:
def __init__(self, name: str, output_folder: str, **kwargs) -> None: # noqa: ARG002
super().__init__(name, output_folder)
self.name = self.name.replace(".json", "")

Expand All @@ -25,22 +26,22 @@ def could_log(val: object) -> bool:
return True
return False

@staticmethod
def get_properties():
def get_properties(self):
raise NotImplementedError


class Roc(SKLearnPlot):
@staticmethod
def get_properties():
return {
"template": "simple",
"x": "fpr",
"y": "tpr",
"title": "Receiver operating characteristic (ROC)",
"x_label": "False Positive Rate",
"y_label": "True Positive Rate",
}
DEFAULT_PROPERTIES = {
"template": "simple",
"x": "fpr",
"y": "tpr",
"title": "Receiver operating characteristic (ROC)",
"x_label": "False Positive Rate",
"y_label": "True Positive Rate",
}

def get_properties(self):
return copy.deepcopy(self.DEFAULT_PROPERTIES)

def dump(self, val, **kwargs) -> None:
from sklearn import metrics
Expand All @@ -58,16 +59,17 @@ def dump(self, val, **kwargs) -> None:


class PrecisionRecall(SKLearnPlot):
@staticmethod
def get_properties():
return {
"template": "simple",
"x": "recall",
"y": "precision",
"title": "Precision-Recall Curve",
"x_label": "Recall",
"y_label": "Precision",
}
DEFAULT_PROPERTIES = {
"template": "simple",
"x": "recall",
"y": "precision",
"title": "Precision-Recall Curve",
"x_label": "Recall",
"y_label": "Precision",
}

def get_properties(self):
return copy.deepcopy(self.DEFAULT_PROPERTIES)

def dump(self, val, **kwargs) -> None:
from sklearn import metrics
Expand All @@ -86,16 +88,17 @@ def dump(self, val, **kwargs) -> None:


class Det(SKLearnPlot):
@staticmethod
def get_properties():
return {
"template": "simple",
"x": "fpr",
"y": "fnr",
"title": "Detection error tradeoff (DET)",
"x_label": "False Positive Rate",
"y_label": "False Negative Rate",
}
DEFAULT_PROPERTIES = {
"template": "simple",
"x": "fpr",
"y": "fnr",
"title": "Detection error tradeoff (DET)",
"x_label": "False Positive Rate",
"y_label": "False Negative Rate",
}

def get_properties(self):
return copy.deepcopy(self.DEFAULT_PROPERTIES)

def dump(self, val, **kwargs) -> None:
from sklearn import metrics
Expand All @@ -114,16 +117,24 @@ def dump(self, val, **kwargs) -> None:


class ConfusionMatrix(SKLearnPlot):
@staticmethod
def get_properties():
return {
"template": "confusion",
"x": "actual",
"y": "predicted",
"title": "Confusion Matrix",
"x_label": "True Label",
"y_label": "Predicted Label",
}
DEFAULT_PROPERTIES = {
"template": "confusion",
"x": "actual",
"y": "predicted",
"title": "Confusion Matrix",
"x_label": "True Label",
"y_label": "Predicted Label",
}

def __init__(self, name: str, output_folder: str, **kwargs) -> None:
super().__init__(name, output_folder)
self.normalized = kwargs.get("normalized") or False

def get_properties(self):
properties = copy.deepcopy(self.DEFAULT_PROPERTIES)
if self.normalized:
properties["template"] = "confusion_normalized"
return properties

def dump(self, val, **kwargs) -> None: # noqa: ARG002
cm = [
Expand All @@ -134,16 +145,17 @@ def dump(self, val, **kwargs) -> None: # noqa: ARG002


class Calibration(SKLearnPlot):
@staticmethod
def get_properties():
return {
"template": "simple",
"x": "prob_pred",
"y": "prob_true",
"title": "Calibration Curve",
"x_label": "Mean Predicted Probability",
"y_label": "Fraction of Positives",
}
DEFAULT_PROPERTIES = {
"template": "simple",
"x": "prob_pred",
"y": "prob_true",
"title": "Calibration Curve",
"x_label": "Mean Predicted Probability",
"y_label": "Fraction of Positives",
}

def get_properties(self):
return copy.deepcopy(self.DEFAULT_PROPERTIES)

def dump(self, val, **kwargs) -> None:
from sklearn import calibration
Expand Down
17 changes: 6 additions & 11 deletions src/dvclive/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,12 @@ def get_plot_renderers(plots_folder, live):
name = file.relative_to(plots_folder).with_suffix("").as_posix()
properties = {}

if name in SKLEARN_PLOTS:
properties = SKLEARN_PLOTS[name].get_properties()
data_field = name
else:
# Plot with custom name
logged_plot = live._plots[name]
for default_name, plot_class in SKLEARN_PLOTS.items():
if isinstance(logged_plot, plot_class):
properties = plot_class.get_properties()
data_field = default_name
break
logged_plot = live._plots[name]
for default_name, plot_class in SKLEARN_PLOTS.items():
if isinstance(logged_plot, plot_class):
properties = logged_plot.get_properties()
data_field = default_name
break

data = json.loads(file.read_text())

Expand Down
17 changes: 17 additions & 0 deletions tests/test_dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ def test_make_dvcyaml_all_plots(tmp_dir):
live.log_metric("bar", 2)
live.log_image("img.png", Image.new("RGB", (10, 10), (250, 250, 250)))
live.log_sklearn_plot("confusion_matrix", [0, 0, 1, 1], [0, 1, 1, 0])
live.log_sklearn_plot(
"confusion_matrix",
[0, 0, 1, 1],
[0, 1, 1, 0],
name="confusion_matrix_normalized",
normalized=True,
)
live.log_sklearn_plot("roc", [0, 0, 1, 1], [0.0, 0.5, 0.5, 0.0], "custom_name_roc")
make_dvcyaml(live)

Expand All @@ -84,6 +91,16 @@ def test_make_dvcyaml_all_plots(tmp_dir):
"y_label": "Predicted Label",
},
},
{
"plots/sklearn/confusion_matrix_normalized.json": {
"template": "confusion_normalized",
"title": "Confusion Matrix",
"x": "actual",
"x_label": "True Label",
"y": "predicted",
"y_label": "Predicted Label",
}
},
{
"plots/sklearn/custom_name_roc.json": {
"template": "simple",
Expand Down
4 changes: 2 additions & 2 deletions tests/test_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def test_get_plot_renderers(tmp_dir, mocker):
{"fpr": 1.0, "rev": "workspace", "threshold": 0.1, "tpr": 0.5},
{"fpr": 1.0, "rev": "workspace", "threshold": 0.0, "tpr": 1.0},
]
assert plot_renderer.properties == Roc.get_properties()
assert plot_renderer.properties == Roc.DEFAULT_PROPERTIES

for name in ("confusion_matrix", "train/cm"):
plot_renderer = plot_renderers_dict[name]
Expand All @@ -172,7 +172,7 @@ def test_get_plot_renderers(tmp_dir, mocker):
{"actual": "1", "rev": "workspace", "predicted": "0"},
{"actual": "1", "rev": "workspace", "predicted": "1"},
]
assert plot_renderer.properties == ConfusionMatrix.get_properties()
assert plot_renderer.properties == ConfusionMatrix.DEFAULT_PROPERTIES


def test_report_auto_doesnt_set_notebook(tmp_dir, mocker):
Expand Down

0 comments on commit bca4682

Please sign in to comment.