diff --git a/src/evidently/future/metric_types.py b/src/evidently/future/metric_types.py
index 95a86f079d..7937b0ad68 100644
--- a/src/evidently/future/metric_types.py
+++ b/src/evidently/future/metric_types.py
@@ -380,8 +380,6 @@ def __call__(self, context: "Context", metric: "MetricCalculationBase", value: T

 MetricId = str

-ByLabelValueTests = Dict[Label, List[SingleValueTest]]
-

 def metric_tests_widget(tests: List[MetricTestResult]) -> BaseWidgetInfo:
     return BaseWidgetInfo(
@@ -738,8 +736,11 @@ def run_test(
         return self.test.run(context, calculation, metric_result)


+SingleValueMetricTests = Optional[List[MetricTest]]
+
+
 class SingleValueMetric(Metric[TSingleValueMetricCalculation]):
-    tests: Optional[List[MetricTest]] = None
+    tests: SingleValueMetricTests = None

     def get_bound_tests(self, context: "Context") -> List[BoundTest]:
         if self.tests is None and context.configuration.include_tests:
@@ -768,8 +769,11 @@ def run_test(
         return self.test.run(context, calculation, value)


+ByLabelMetricTests = Optional[Dict[Label, List[MetricTest]]]
+
+
 class ByLabelMetric(Metric["ByLabelCalculation"]):
-    tests: Optional[Dict[Label, List[MetricTest]]] = None
+    tests: ByLabelMetricTests = None

     def get_bound_tests(self, context: "Context") -> List[BoundTest]:
         if self.tests is None and context.configuration.include_tests:
@@ -806,8 +810,8 @@ def run_test(


 class CountMetric(Metric["CountCalculation"]):
-    tests: Optional[List[MetricTest]] = None
-    share_tests: Optional[List[MetricTest]] = None
+    tests: SingleValueMetricTests = None
+    share_tests: SingleValueMetricTests = None

     def get_bound_tests(self, context: "Context") -> Sequence[BoundTest]:
         if self.tests is None and self.share_tests is None and context.configuration.include_tests:
@@ -846,8 +850,8 @@ def run_test(


 class MeanStdMetric(Metric["MeanStdCalculation"]):
-    mean_tests: Optional[List[MetricTest]] = None
-    std_tests: Optional[List[MetricTest]] = None
+    mean_tests: SingleValueMetricTests = None
+    std_tests: SingleValueMetricTests = None

     def get_bound_tests(self, context: "Context") -> Sequence[BoundTest]:
         if self.mean_tests is None and self.mean_tests is None and context.configuration.include_tests:
diff --git a/src/evidently/future/presets/classification.py b/src/evidently/future/presets/classification.py
index d6cb93cd08..07a88b8e50 100644
--- a/src/evidently/future/presets/classification.py
+++ b/src/evidently/future/presets/classification.py
@@ -4,9 +4,11 @@

 from evidently.future.container import MetricContainer
 from evidently.future.datasets import BinaryClassification
+from evidently.future.metric_types import ByLabelMetricTests
 from evidently.future.metric_types import Metric
 from evidently.future.metric_types import MetricId
 from evidently.future.metric_types import MetricResult
+from evidently.future.metric_types import SingleValueMetricTests
 from evidently.future.metrics import FNR
 from evidently.future.metrics import FPR
 from evidently.future.metrics import TNR
@@ -43,7 +45,27 @@ def __init__(
         conf_matrix: bool = False,
         pr_curve: bool = False,
         pr_table: bool = False,
+        accuracy_tests: SingleValueMetricTests = None,
+        precision_tests: SingleValueMetricTests = None,
+        recall_tests: SingleValueMetricTests = None,
+        f1score_tests: SingleValueMetricTests = None,
+        rocauc_tests: SingleValueMetricTests = None,
+        logloss_tests: SingleValueMetricTests = None,
+        tpr_tests: SingleValueMetricTests = None,
+        tnr_tests: SingleValueMetricTests = None,
+        fpr_tests: SingleValueMetricTests = None,
+        fnr_tests: SingleValueMetricTests = None,
     ):
+        self._accuracy_tests = accuracy_tests
+        self._precision_tests = precision_tests
+        self._recall_tests = recall_tests
+        self._f1score_tests = f1score_tests
+        self._rocauc_test = rocauc_tests
+        self._logloss_test = logloss_tests
+        self._tpr_test = tpr_tests
+        self._tnr_test = tnr_tests
+        self._fpr_test = fpr_tests
+        self._fnr_test = fnr_tests
         self._probas_threshold = probas_threshold
         self._conf_matrix = conf_matrix
         self._pr_curve = pr_curve
@@ -56,42 +78,28 @@ def generate_metrics(self, context: "Context") -> List[Metric]:

         metrics: List[Metric]
+        metrics = [
+            Accuracy(probas_threshold=self._probas_threshold, tests=self._accuracy_tests),
+            Precision(probas_threshold=self._probas_threshold, tests=self._precision_tests),
+            Recall(probas_threshold=self._probas_threshold, tests=self._recall_tests),
+            F1Score(probas_threshold=self._probas_threshold, tests=self._f1score_tests),
+        ]
+        if classification.prediction_probas is not None:
+            metrics.extend(
+                [
+                    RocAuc(probas_threshold=self._probas_threshold, tests=self._rocauc_test),
+                    LogLoss(probas_threshold=self._probas_threshold, tests=self._logloss_test),
+                ]
+            )
         if isinstance(classification, BinaryClassification):
-            metrics = [
-                Accuracy(probas_threshold=self._probas_threshold),
-                Precision(probas_threshold=self._probas_threshold),
-                Recall(probas_threshold=self._probas_threshold),
-                F1Score(probas_threshold=self._probas_threshold),
-            ]
-            if classification.prediction_probas is not None:
-                metrics.extend(
-                    [
-                        RocAuc(probas_threshold=self._probas_threshold),
-                        LogLoss(probas_threshold=self._probas_threshold),
-                    ]
-                )
             metrics.extend(
                 [
-                    TPR(probas_threshold=self._probas_threshold),
-                    TNR(probas_threshold=self._probas_threshold),
-                    FPR(probas_threshold=self._probas_threshold),
-                    FNR(probas_threshold=self._probas_threshold),
+                    TPR(probas_threshold=self._probas_threshold, tests=self._tpr_test),
+                    TNR(probas_threshold=self._probas_threshold, tests=self._tnr_test),
+                    FPR(probas_threshold=self._probas_threshold, tests=self._fpr_test),
+                    FNR(probas_threshold=self._probas_threshold, tests=self._fnr_test),
                 ]
             )
-        else:
-            metrics = [
-                Accuracy(probas_threshold=self._probas_threshold),
-                Precision(probas_threshold=self._probas_threshold),
-                Recall(probas_threshold=self._probas_threshold),
-                F1Score(probas_threshold=self._probas_threshold),
-            ]
-            if classification.prediction_probas is not None:
-                metrics.extend(
-                    [
-                        RocAuc(probas_threshold=self._probas_threshold),
-                        LogLoss(probas_threshold=self._probas_threshold),
-                    ]
-                )
         return metrics

     def render(self, context: "Context", results: Dict[MetricId, MetricResult]) -> List[BaseWidgetInfo]:
@@ -123,23 +131,35 @@ def render(self, context: "Context", results: Dict[MetricId, MetricResult]) -> L


 class ClassificationQualityByLabel(MetricContainer):
-    def __init__(self, probas_threshold: Optional[float] = None, k: Optional[int] = None):
+    def __init__(
+        self,
+        probas_threshold: Optional[float] = None,
+        k: Optional[int] = None,
+        f1score_tests: ByLabelMetricTests = None,
+        precision_tests: ByLabelMetricTests = None,
+        recall_tests: ByLabelMetricTests = None,
+        rocauc_tests: ByLabelMetricTests = None,
+    ):
         self._probas_threshold = probas_threshold
         self._k = k
+        self._f1score_tests = f1score_tests
+        self._precision_tests = precision_tests
+        self._recall_tests = recall_tests
+        self._rocauc_tests = rocauc_tests

     def generate_metrics(self, context: "Context") -> List[Metric]:
         classification = context.data_definition.get_classification("default")
         if classification is None:
             raise ValueError("Cannot use ClassificationPreset without a classification configration")
         return [
-            F1ByLabel(probas_threshold=self._probas_threshold, k=self._k),
-            PrecisionByLabel(probas_threshold=self._probas_threshold, k=self._k),
-            RecallByLabel(probas_threshold=self._probas_threshold, k=self._k),
+            F1ByLabel(probas_threshold=self._probas_threshold, k=self._k, tests=self._f1score_tests),
+            PrecisionByLabel(probas_threshold=self._probas_threshold, k=self._k, tests=self._precision_tests),
+            RecallByLabel(probas_threshold=self._probas_threshold, k=self._k, tests=self._recall_tests),
         ] + (
             []
             if classification.prediction_probas is None
             else [
-                RocAucByLabel(probas_threshold=self._probas_threshold, k=self._k),
+                RocAucByLabel(probas_threshold=self._probas_threshold, k=self._k, tests=self._rocauc_tests),
             ]
         )
@@ -156,7 +176,11 @@ def render(self, context: "Context", results: Dict[MetricId, MetricResult]):


 class ClassificationDummyQuality(MetricContainer):
-    def __init__(self, probas_threshold: Optional[float] = None, k: Optional[int] = None):
+    def __init__(
+        self,
+        probas_threshold: Optional[float] = None,
+        k: Optional[int] = None,
+    ):
         self._probas_threshold = probas_threshold
         self._k = k
@@ -178,23 +202,59 @@ def render(self, context: "Context", results: Dict[MetricId, MetricResult]) -> L


 class ClassificationPreset(MetricContainer):
-    def __init__(self, probas_threshold: Optional[float] = None):
+    def __init__(
+        self,
+        probas_threshold: Optional[float] = None,
+        accuracy_tests: SingleValueMetricTests = None,
+        precision_tests: SingleValueMetricTests = None,
+        recall_tests: SingleValueMetricTests = None,
+        f1score_tests: SingleValueMetricTests = None,
+        rocauc_tests: SingleValueMetricTests = None,
+        logloss_tests: SingleValueMetricTests = None,
+        tpr_tests: SingleValueMetricTests = None,
+        tnr_tests: SingleValueMetricTests = None,
+        fpr_tests: SingleValueMetricTests = None,
+        fnr_tests: SingleValueMetricTests = None,
+        f1score_by_label_tests: ByLabelMetricTests = None,
+        precision_by_label_tests: ByLabelMetricTests = None,
+        recall_by_label_tests: ByLabelMetricTests = None,
+        rocauc_by_label_tests: ByLabelMetricTests = None,
+    ):
         self._probas_threshold = probas_threshold
         self._quality = ClassificationQuality(
             probas_threshold=probas_threshold,
             conf_matrix=True,
             pr_curve=True,
             pr_table=True,
+            accuracy_tests=accuracy_tests,
+            precision_tests=precision_tests,
+            recall_tests=recall_tests,
+            f1score_tests=f1score_tests,
+            rocauc_tests=rocauc_tests,
+            logloss_tests=logloss_tests,
+            tpr_tests=tpr_tests,
+            tnr_tests=tnr_tests,
+            fpr_tests=fpr_tests,
+            fnr_tests=fnr_tests,
+        )
+        self._quality_by_label = ClassificationQualityByLabel(
+            probas_threshold=probas_threshold,
+            f1score_tests=f1score_by_label_tests,
+            precision_tests=precision_by_label_tests,
+            recall_tests=recall_by_label_tests,
+            rocauc_tests=rocauc_by_label_tests,
+        )
+        self._roc_auc: Optional[RocAuc] = RocAuc(
+            probas_threshold=probas_threshold,
+            tests=rocauc_tests,
         )
-        self._quality_by_label = ClassificationQualityByLabel(probas_threshold=probas_threshold)
-        self._roc_auc: Optional[RocAuc] = None

     def generate_metrics(self, context: "Context") -> List[Metric]:
         classification = context.data_definition.get_classification("default")
         if classification is None:
             raise ValueError("Cannot use ClassificationPreset without a classification configration")
-        if classification.prediction_probas is not None:
-            self._roc_auc = RocAuc()
+        if classification.prediction_probas is None:
+            self._roc_auc = None
         return (
             self._quality.metrics(context)
             + self._quality_by_label.metrics(context)
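Usage sketch (not part of the diff above): the snippet shows how the new per-metric test arguments introduced here could be passed to the presets. The `gte`/`lte` helpers imported from `evidently.future.tests` and the label keys in the by-label dicts are assumptions for illustration; any `List[MetricTest]` accepted by the `*_tests` parameters works the same way.

# Illustrative only -- assumes `gte`/`lte` MetricTest helpers exist in
# evidently.future.tests; the label keys (0/1) are placeholders.
from evidently.future.presets.classification import ClassificationPreset
from evidently.future.presets.classification import ClassificationQualityByLabel
from evidently.future.tests import gte  # assumed helper returning a MetricTest
from evidently.future.tests import lte  # assumed helper returning a MetricTest

# SingleValueMetricTests arguments take Optional[List[MetricTest]].
preset = ClassificationPreset(
    probas_threshold=0.5,
    accuracy_tests=[gte(0.8)],
    f1score_tests=[gte(0.75)],
    fpr_tests=[lte(0.2)],
    # ByLabelMetricTests arguments take Optional[Dict[Label, List[MetricTest]]].
    precision_by_label_tests={0: [gte(0.7)], 1: [gte(0.7)]},
)

# The by-label container can also be configured on its own.
by_label = ClassificationQualityByLabel(
    probas_threshold=0.5,
    recall_tests={0: [gte(0.6)], 1: [gte(0.6)]},
)

Passing no test arguments keeps the previous behaviour: each metric falls back to its default bound tests when `include_tests` is enabled in the report configuration, as in `get_bound_tests` above.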