1
+ import copy
1
2
from pathlib import Path
2
3
3
4
from dvclive .serialize import dump_json
@@ -9,7 +10,7 @@ class SKLearnPlot(Data):
9
10
suffixes = [".json" ]
10
11
subfolder = "sklearn"
11
12
12
- def __init__ (self , name : str , output_folder : str ) -> None :
13
+ def __init__ (self , name : str , output_folder : str , ** kwargs ) -> None : # noqa: ARG002
13
14
super ().__init__ (name , output_folder )
14
15
self .name = self .name .replace (".json" , "" )
15
16
@@ -25,22 +26,22 @@ def could_log(val: object) -> bool:
25
26
return True
26
27
return False
27
28
28
- @staticmethod
29
- def get_properties ():
29
+ def get_properties (self ):
30
30
raise NotImplementedError
31
31
32
32
33
33
class Roc (SKLearnPlot ):
34
- @staticmethod
35
- def get_properties ():
36
- return {
37
- "template" : "simple" ,
38
- "x" : "fpr" ,
39
- "y" : "tpr" ,
40
- "title" : "Receiver operating characteristic (ROC)" ,
41
- "x_label" : "False Positive Rate" ,
42
- "y_label" : "True Positive Rate" ,
43
- }
34
+ DEFAULT_PROPERTIES = {
35
+ "template" : "simple" ,
36
+ "x" : "fpr" ,
37
+ "y" : "tpr" ,
38
+ "title" : "Receiver operating characteristic (ROC)" ,
39
+ "x_label" : "False Positive Rate" ,
40
+ "y_label" : "True Positive Rate" ,
41
+ }
42
+
43
+ def get_properties (self ):
44
+ return copy .deepcopy (self .DEFAULT_PROPERTIES )
44
45
45
46
def dump (self , val , ** kwargs ) -> None :
46
47
from sklearn import metrics
@@ -58,16 +59,17 @@ def dump(self, val, **kwargs) -> None:
58
59
59
60
60
61
class PrecisionRecall (SKLearnPlot ):
61
- @staticmethod
62
- def get_properties ():
63
- return {
64
- "template" : "simple" ,
65
- "x" : "recall" ,
66
- "y" : "precision" ,
67
- "title" : "Precision-Recall Curve" ,
68
- "x_label" : "Recall" ,
69
- "y_label" : "Precision" ,
70
- }
62
+ DEFAULT_PROPERTIES = {
63
+ "template" : "simple" ,
64
+ "x" : "recall" ,
65
+ "y" : "precision" ,
66
+ "title" : "Precision-Recall Curve" ,
67
+ "x_label" : "Recall" ,
68
+ "y_label" : "Precision" ,
69
+ }
70
+
71
+ def get_properties (self ):
72
+ return copy .deepcopy (self .DEFAULT_PROPERTIES )
71
73
72
74
def dump (self , val , ** kwargs ) -> None :
73
75
from sklearn import metrics
@@ -86,16 +88,17 @@ def dump(self, val, **kwargs) -> None:
86
88
87
89
88
90
class Det (SKLearnPlot ):
89
- @staticmethod
90
- def get_properties ():
91
- return {
92
- "template" : "simple" ,
93
- "x" : "fpr" ,
94
- "y" : "fnr" ,
95
- "title" : "Detection error tradeoff (DET)" ,
96
- "x_label" : "False Positive Rate" ,
97
- "y_label" : "False Negative Rate" ,
98
- }
91
+ DEFAULT_PROPERTIES = {
92
+ "template" : "simple" ,
93
+ "x" : "fpr" ,
94
+ "y" : "fnr" ,
95
+ "title" : "Detection error tradeoff (DET)" ,
96
+ "x_label" : "False Positive Rate" ,
97
+ "y_label" : "False Negative Rate" ,
98
+ }
99
+
100
+ def get_properties (self ):
101
+ return copy .deepcopy (self .DEFAULT_PROPERTIES )
99
102
100
103
def dump (self , val , ** kwargs ) -> None :
101
104
from sklearn import metrics
@@ -114,16 +117,24 @@ def dump(self, val, **kwargs) -> None:
114
117
115
118
116
119
class ConfusionMatrix (SKLearnPlot ):
117
- @staticmethod
118
- def get_properties ():
119
- return {
120
- "template" : "confusion" ,
121
- "x" : "actual" ,
122
- "y" : "predicted" ,
123
- "title" : "Confusion Matrix" ,
124
- "x_label" : "True Label" ,
125
- "y_label" : "Predicted Label" ,
126
- }
120
+ DEFAULT_PROPERTIES = {
121
+ "template" : "confusion" ,
122
+ "x" : "actual" ,
123
+ "y" : "predicted" ,
124
+ "title" : "Confusion Matrix" ,
125
+ "x_label" : "True Label" ,
126
+ "y_label" : "Predicted Label" ,
127
+ }
128
+
129
+ def __init__ (self , name : str , output_folder : str , ** kwargs ) -> None :
130
+ super ().__init__ (name , output_folder )
131
+ self .normalized = kwargs .get ("normalized" ) or False
132
+
133
+ def get_properties (self ):
134
+ properties = copy .deepcopy (self .DEFAULT_PROPERTIES )
135
+ if self .normalized :
136
+ properties ["template" ] = "confusion_normalized"
137
+ return properties
127
138
128
139
def dump (self , val , ** kwargs ) -> None : # noqa: ARG002
129
140
cm = [
@@ -134,16 +145,17 @@ def dump(self, val, **kwargs) -> None: # noqa: ARG002
134
145
135
146
136
147
class Calibration (SKLearnPlot ):
137
- @staticmethod
138
- def get_properties ():
139
- return {
140
- "template" : "simple" ,
141
- "x" : "prob_pred" ,
142
- "y" : "prob_true" ,
143
- "title" : "Calibration Curve" ,
144
- "x_label" : "Mean Predicted Probability" ,
145
- "y_label" : "Fraction of Positives" ,
146
- }
148
+ DEFAULT_PROPERTIES = {
149
+ "template" : "simple" ,
150
+ "x" : "prob_pred" ,
151
+ "y" : "prob_true" ,
152
+ "title" : "Calibration Curve" ,
153
+ "x_label" : "Mean Predicted Probability" ,
154
+ "y_label" : "Fraction of Positives" ,
155
+ }
156
+
157
+ def get_properties (self ):
158
+ return copy .deepcopy (self .DEFAULT_PROPERTIES )
147
159
148
160
def dump (self , val , ** kwargs ) -> None :
149
161
from sklearn import calibration
0 commit comments