Skip to content

Commit bca4682

Browse files
shcheklein authored and dberenbaum committed
feat(sklplots): add normalize option to confusion matrix
1 parent a4df3a9 commit bca4682

File tree

5 files changed

+91
-67
lines changed

5 files changed

+91
-67
lines changed

src/dvclive/live.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -291,7 +291,7 @@ def log_sklearn_plot(self, kind, labels, predictions, name=None, **kwargs):
291291
if name in self._plots:
292292
data = self._plots[name]
293293
elif kind in SKLEARN_PLOTS and SKLEARN_PLOTS[kind].could_log(val):
294-
data = SKLEARN_PLOTS[kind](name, self.plots_dir)
294+
data = SKLEARN_PLOTS[kind](name, self.plots_dir, **kwargs)
295295
self._plots[data.name] = data
296296
else:
297297
raise InvalidPlotTypeError(name)

src/dvclive/plots/sklearn.py

Lines changed: 65 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import copy
12
from pathlib import Path
23

34
from dvclive.serialize import dump_json
@@ -9,7 +10,7 @@ class SKLearnPlot(Data):
910
suffixes = [".json"]
1011
subfolder = "sklearn"
1112

12-
def __init__(self, name: str, output_folder: str) -> None:
13+
def __init__(self, name: str, output_folder: str, **kwargs) -> None: # noqa: ARG002
1314
super().__init__(name, output_folder)
1415
self.name = self.name.replace(".json", "")
1516

@@ -25,22 +26,22 @@ def could_log(val: object) -> bool:
2526
return True
2627
return False
2728

28-
@staticmethod
29-
def get_properties():
29+
def get_properties(self):
3030
raise NotImplementedError
3131

3232

3333
class Roc(SKLearnPlot):
34-
@staticmethod
35-
def get_properties():
36-
return {
37-
"template": "simple",
38-
"x": "fpr",
39-
"y": "tpr",
40-
"title": "Receiver operating characteristic (ROC)",
41-
"x_label": "False Positive Rate",
42-
"y_label": "True Positive Rate",
43-
}
34+
DEFAULT_PROPERTIES = {
35+
"template": "simple",
36+
"x": "fpr",
37+
"y": "tpr",
38+
"title": "Receiver operating characteristic (ROC)",
39+
"x_label": "False Positive Rate",
40+
"y_label": "True Positive Rate",
41+
}
42+
43+
def get_properties(self):
44+
return copy.deepcopy(self.DEFAULT_PROPERTIES)
4445

4546
def dump(self, val, **kwargs) -> None:
4647
from sklearn import metrics
@@ -58,16 +59,17 @@ def dump(self, val, **kwargs) -> None:
5859

5960

6061
class PrecisionRecall(SKLearnPlot):
61-
@staticmethod
62-
def get_properties():
63-
return {
64-
"template": "simple",
65-
"x": "recall",
66-
"y": "precision",
67-
"title": "Precision-Recall Curve",
68-
"x_label": "Recall",
69-
"y_label": "Precision",
70-
}
62+
DEFAULT_PROPERTIES = {
63+
"template": "simple",
64+
"x": "recall",
65+
"y": "precision",
66+
"title": "Precision-Recall Curve",
67+
"x_label": "Recall",
68+
"y_label": "Precision",
69+
}
70+
71+
def get_properties(self):
72+
return copy.deepcopy(self.DEFAULT_PROPERTIES)
7173

7274
def dump(self, val, **kwargs) -> None:
7375
from sklearn import metrics
@@ -86,16 +88,17 @@ def dump(self, val, **kwargs) -> None:
8688

8789

8890
class Det(SKLearnPlot):
89-
@staticmethod
90-
def get_properties():
91-
return {
92-
"template": "simple",
93-
"x": "fpr",
94-
"y": "fnr",
95-
"title": "Detection error tradeoff (DET)",
96-
"x_label": "False Positive Rate",
97-
"y_label": "False Negative Rate",
98-
}
91+
DEFAULT_PROPERTIES = {
92+
"template": "simple",
93+
"x": "fpr",
94+
"y": "fnr",
95+
"title": "Detection error tradeoff (DET)",
96+
"x_label": "False Positive Rate",
97+
"y_label": "False Negative Rate",
98+
}
99+
100+
def get_properties(self):
101+
return copy.deepcopy(self.DEFAULT_PROPERTIES)
99102

100103
def dump(self, val, **kwargs) -> None:
101104
from sklearn import metrics
@@ -114,16 +117,24 @@ def dump(self, val, **kwargs) -> None:
114117

115118

116119
class ConfusionMatrix(SKLearnPlot):
117-
@staticmethod
118-
def get_properties():
119-
return {
120-
"template": "confusion",
121-
"x": "actual",
122-
"y": "predicted",
123-
"title": "Confusion Matrix",
124-
"x_label": "True Label",
125-
"y_label": "Predicted Label",
126-
}
120+
DEFAULT_PROPERTIES = {
121+
"template": "confusion",
122+
"x": "actual",
123+
"y": "predicted",
124+
"title": "Confusion Matrix",
125+
"x_label": "True Label",
126+
"y_label": "Predicted Label",
127+
}
128+
129+
def __init__(self, name: str, output_folder: str, **kwargs) -> None:
130+
super().__init__(name, output_folder)
131+
self.normalized = kwargs.get("normalized") or False
132+
133+
def get_properties(self):
134+
properties = copy.deepcopy(self.DEFAULT_PROPERTIES)
135+
if self.normalized:
136+
properties["template"] = "confusion_normalized"
137+
return properties
127138

128139
def dump(self, val, **kwargs) -> None: # noqa: ARG002
129140
cm = [
@@ -134,16 +145,17 @@ def dump(self, val, **kwargs) -> None: # noqa: ARG002
134145

135146

136147
class Calibration(SKLearnPlot):
137-
@staticmethod
138-
def get_properties():
139-
return {
140-
"template": "simple",
141-
"x": "prob_pred",
142-
"y": "prob_true",
143-
"title": "Calibration Curve",
144-
"x_label": "Mean Predicted Probability",
145-
"y_label": "Fraction of Positives",
146-
}
148+
DEFAULT_PROPERTIES = {
149+
"template": "simple",
150+
"x": "prob_pred",
151+
"y": "prob_true",
152+
"title": "Calibration Curve",
153+
"x_label": "Mean Predicted Probability",
154+
"y_label": "Fraction of Positives",
155+
}
156+
157+
def get_properties(self):
158+
return copy.deepcopy(self.DEFAULT_PROPERTIES)
147159

148160
def dump(self, val, **kwargs) -> None:
149161
from sklearn import calibration

src/dvclive/report.py

Lines changed: 6 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -93,17 +93,12 @@ def get_plot_renderers(plots_folder, live):
9393
name = file.relative_to(plots_folder).with_suffix("").as_posix()
9494
properties = {}
9595

96-
if name in SKLEARN_PLOTS:
97-
properties = SKLEARN_PLOTS[name].get_properties()
98-
data_field = name
99-
else:
100-
# Plot with custom name
101-
logged_plot = live._plots[name]
102-
for default_name, plot_class in SKLEARN_PLOTS.items():
103-
if isinstance(logged_plot, plot_class):
104-
properties = plot_class.get_properties()
105-
data_field = default_name
106-
break
96+
logged_plot = live._plots[name]
97+
for default_name, plot_class in SKLEARN_PLOTS.items():
98+
if isinstance(logged_plot, plot_class):
99+
properties = logged_plot.get_properties()
100+
data_field = default_name
101+
break
107102

108103
data = json.loads(file.read_text())
109104

tests/test_dvc.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,13 @@ def test_make_dvcyaml_all_plots(tmp_dir):
6565
live.log_metric("bar", 2)
6666
live.log_image("img.png", Image.new("RGB", (10, 10), (250, 250, 250)))
6767
live.log_sklearn_plot("confusion_matrix", [0, 0, 1, 1], [0, 1, 1, 0])
68+
live.log_sklearn_plot(
69+
"confusion_matrix",
70+
[0, 0, 1, 1],
71+
[0, 1, 1, 0],
72+
name="confusion_matrix_normalized",
73+
normalized=True,
74+
)
6875
live.log_sklearn_plot("roc", [0, 0, 1, 1], [0.0, 0.5, 0.5, 0.0], "custom_name_roc")
6976
make_dvcyaml(live)
7077

@@ -84,6 +91,16 @@ def test_make_dvcyaml_all_plots(tmp_dir):
8491
"y_label": "Predicted Label",
8592
},
8693
},
94+
{
95+
"plots/sklearn/confusion_matrix_normalized.json": {
96+
"template": "confusion_normalized",
97+
"title": "Confusion Matrix",
98+
"x": "actual",
99+
"x_label": "True Label",
100+
"y": "predicted",
101+
"y_label": "Predicted Label",
102+
}
103+
},
87104
{
88105
"plots/sklearn/custom_name_roc.json": {
89106
"template": "simple",

tests/test_report.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -162,7 +162,7 @@ def test_get_plot_renderers(tmp_dir, mocker):
162162
{"fpr": 1.0, "rev": "workspace", "threshold": 0.1, "tpr": 0.5},
163163
{"fpr": 1.0, "rev": "workspace", "threshold": 0.0, "tpr": 1.0},
164164
]
165-
assert plot_renderer.properties == Roc.get_properties()
165+
assert plot_renderer.properties == Roc.DEFAULT_PROPERTIES
166166

167167
for name in ("confusion_matrix", "train/cm"):
168168
plot_renderer = plot_renderers_dict[name]
@@ -172,7 +172,7 @@ def test_get_plot_renderers(tmp_dir, mocker):
172172
{"actual": "1", "rev": "workspace", "predicted": "0"},
173173
{"actual": "1", "rev": "workspace", "predicted": "1"},
174174
]
175-
assert plot_renderer.properties == ConfusionMatrix.get_properties()
175+
assert plot_renderer.properties == ConfusionMatrix.DEFAULT_PROPERTIES
176176

177177

178178
def test_report_auto_doesnt_set_notebook(tmp_dir, mocker):

0 commit comments

Comments
 (0)