From bca4682b13e31f59a2ef49e4200f68caadc9e76f Mon Sep 17 00:00:00 2001 From: Ivan Shcheklein Date: Thu, 20 Apr 2023 15:31:39 -0700 Subject: [PATCH] feat(sklplots): add normalize option to confusion matrix --- src/dvclive/live.py | 2 +- src/dvclive/plots/sklearn.py | 118 +++++++++++++++++++---------------- src/dvclive/report.py | 17 ++--- tests/test_dvc.py | 17 +++++ tests/test_report.py | 4 +- 5 files changed, 91 insertions(+), 67 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 9f3ee3ee..65afd825 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -291,7 +291,7 @@ def log_sklearn_plot(self, kind, labels, predictions, name=None, **kwargs): if name in self._plots: data = self._plots[name] elif kind in SKLEARN_PLOTS and SKLEARN_PLOTS[kind].could_log(val): - data = SKLEARN_PLOTS[kind](name, self.plots_dir) + data = SKLEARN_PLOTS[kind](name, self.plots_dir, **kwargs) self._plots[data.name] = data else: raise InvalidPlotTypeError(name) diff --git a/src/dvclive/plots/sklearn.py b/src/dvclive/plots/sklearn.py index 0547d51f..ad00bcce 100644 --- a/src/dvclive/plots/sklearn.py +++ b/src/dvclive/plots/sklearn.py @@ -1,3 +1,4 @@ +import copy from pathlib import Path from dvclive.serialize import dump_json @@ -9,7 +10,7 @@ class SKLearnPlot(Data): suffixes = [".json"] subfolder = "sklearn" - def __init__(self, name: str, output_folder: str) -> None: + def __init__(self, name: str, output_folder: str, **kwargs) -> None: # noqa: ARG002 super().__init__(name, output_folder) self.name = self.name.replace(".json", "") @@ -25,22 +26,22 @@ def could_log(val: object) -> bool: return True return False - @staticmethod - def get_properties(): + def get_properties(self): raise NotImplementedError class Roc(SKLearnPlot): - @staticmethod - def get_properties(): - return { - "template": "simple", - "x": "fpr", - "y": "tpr", - "title": "Receiver operating characteristic (ROC)", - "x_label": "False Positive Rate", - "y_label": "True Positive Rate", - } + DEFAULT_PROPERTIES = { + "template": "simple", + "x": "fpr", + "y": "tpr", + "title": "Receiver operating characteristic (ROC)", + "x_label": "False Positive Rate", + "y_label": "True Positive Rate", + } + + def get_properties(self): + return copy.deepcopy(self.DEFAULT_PROPERTIES) def dump(self, val, **kwargs) -> None: from sklearn import metrics @@ -58,16 +59,17 @@ def dump(self, val, **kwargs) -> None: class PrecisionRecall(SKLearnPlot): - @staticmethod - def get_properties(): - return { - "template": "simple", - "x": "recall", - "y": "precision", - "title": "Precision-Recall Curve", - "x_label": "Recall", - "y_label": "Precision", - } + DEFAULT_PROPERTIES = { + "template": "simple", + "x": "recall", + "y": "precision", + "title": "Precision-Recall Curve", + "x_label": "Recall", + "y_label": "Precision", + } + + def get_properties(self): + return copy.deepcopy(self.DEFAULT_PROPERTIES) def dump(self, val, **kwargs) -> None: from sklearn import metrics @@ -86,16 +88,17 @@ def dump(self, val, **kwargs) -> None: class Det(SKLearnPlot): - @staticmethod - def get_properties(): - return { - "template": "simple", - "x": "fpr", - "y": "fnr", - "title": "Detection error tradeoff (DET)", - "x_label": "False Positive Rate", - "y_label": "False Negative Rate", - } + DEFAULT_PROPERTIES = { + "template": "simple", + "x": "fpr", + "y": "fnr", + "title": "Detection error tradeoff (DET)", + "x_label": "False Positive Rate", + "y_label": "False Negative Rate", + } + + def get_properties(self): + return copy.deepcopy(self.DEFAULT_PROPERTIES) def dump(self, val, **kwargs) -> None: from sklearn import metrics @@ -114,16 +117,24 @@ def dump(self, val, **kwargs) -> None: class ConfusionMatrix(SKLearnPlot): - @staticmethod - def get_properties(): - return { - "template": "confusion", - "x": "actual", - "y": "predicted", - "title": "Confusion Matrix", - "x_label": "True Label", - "y_label": "Predicted Label", - } + DEFAULT_PROPERTIES = { + "template": "confusion", + "x": "actual", + "y": "predicted", + "title": "Confusion Matrix", + "x_label": "True Label", + "y_label": "Predicted Label", + } + + def __init__(self, name: str, output_folder: str, **kwargs) -> None: + super().__init__(name, output_folder) + self.normalized = kwargs.get("normalized") or False + + def get_properties(self): + properties = copy.deepcopy(self.DEFAULT_PROPERTIES) + if self.normalized: + properties["template"] = "confusion_normalized" + return properties def dump(self, val, **kwargs) -> None: # noqa: ARG002 cm = [ @@ -134,16 +145,17 @@ def dump(self, val, **kwargs) -> None: # noqa: ARG002 class Calibration(SKLearnPlot): - @staticmethod - def get_properties(): - return { - "template": "simple", - "x": "prob_pred", - "y": "prob_true", - "title": "Calibration Curve", - "x_label": "Mean Predicted Probability", - "y_label": "Fraction of Positives", - } + DEFAULT_PROPERTIES = { + "template": "simple", + "x": "prob_pred", + "y": "prob_true", + "title": "Calibration Curve", + "x_label": "Mean Predicted Probability", + "y_label": "Fraction of Positives", + } + + def get_properties(self): + return copy.deepcopy(self.DEFAULT_PROPERTIES) def dump(self, val, **kwargs) -> None: from sklearn import calibration diff --git a/src/dvclive/report.py b/src/dvclive/report.py index 3545fff7..b8333e07 100644 --- a/src/dvclive/report.py +++ b/src/dvclive/report.py @@ -93,17 +93,12 @@ def get_plot_renderers(plots_folder, live): name = file.relative_to(plots_folder).with_suffix("").as_posix() properties = {} - if name in SKLEARN_PLOTS: - properties = SKLEARN_PLOTS[name].get_properties() - data_field = name - else: - # Plot with custom name - logged_plot = live._plots[name] - for default_name, plot_class in SKLEARN_PLOTS.items(): - if isinstance(logged_plot, plot_class): - properties = plot_class.get_properties() - data_field = default_name - break + logged_plot = live._plots[name] + for default_name, plot_class in SKLEARN_PLOTS.items(): + if isinstance(logged_plot, plot_class): + properties = logged_plot.get_properties() + data_field = default_name + break data = json.loads(file.read_text()) diff --git a/tests/test_dvc.py b/tests/test_dvc.py index 9a1eebb1..c96b1886 100644 --- a/tests/test_dvc.py +++ b/tests/test_dvc.py @@ -65,6 +65,13 @@ def test_make_dvcyaml_all_plots(tmp_dir): live.log_metric("bar", 2) live.log_image("img.png", Image.new("RGB", (10, 10), (250, 250, 250))) live.log_sklearn_plot("confusion_matrix", [0, 0, 1, 1], [0, 1, 1, 0]) + live.log_sklearn_plot( + "confusion_matrix", + [0, 0, 1, 1], + [0, 1, 1, 0], + name="confusion_matrix_normalized", + normalized=True, + ) live.log_sklearn_plot("roc", [0, 0, 1, 1], [0.0, 0.5, 0.5, 0.0], "custom_name_roc") make_dvcyaml(live) @@ -84,6 +91,16 @@ def test_make_dvcyaml_all_plots(tmp_dir): "y_label": "Predicted Label", }, }, + { + "plots/sklearn/confusion_matrix_normalized.json": { + "template": "confusion_normalized", + "title": "Confusion Matrix", + "x": "actual", + "x_label": "True Label", + "y": "predicted", + "y_label": "Predicted Label", + } + }, { "plots/sklearn/custom_name_roc.json": { "template": "simple", diff --git a/tests/test_report.py b/tests/test_report.py index 1de98c6f..4131ef4e 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -162,7 +162,7 @@ def test_get_plot_renderers(tmp_dir, mocker): {"fpr": 1.0, "rev": "workspace", "threshold": 0.1, "tpr": 0.5}, {"fpr": 1.0, "rev": "workspace", "threshold": 0.0, "tpr": 1.0}, ] - assert plot_renderer.properties == Roc.get_properties() + assert plot_renderer.properties == Roc.DEFAULT_PROPERTIES for name in ("confusion_matrix", "train/cm"): plot_renderer = plot_renderers_dict[name] @@ -172,7 +172,7 @@ def test_get_plot_renderers(tmp_dir, mocker): {"actual": "1", "rev": "workspace", "predicted": "0"}, {"actual": "1", "rev": "workspace", "predicted": "1"}, ] - assert plot_renderer.properties == ConfusionMatrix.get_properties() + assert plot_renderer.properties == ConfusionMatrix.DEFAULT_PROPERTIES def test_report_auto_doesnt_set_notebook(tmp_dir, mocker):