From 1babc8b7cddfb08d434b6f01d7474810fc6dc084 Mon Sep 17 00:00:00 2001 From: Jirka Date: Sat, 24 Jun 2023 00:31:55 +0200 Subject: [PATCH 01/11] reverting optional task for v0.11 --- src/torchmetrics/classification/accuracy.py | 189 +++++++++++++-- src/torchmetrics/classification/auroc.py | 134 +++++++++-- .../classification/average_precision.py | 114 +++++++-- .../classification/calibration_error.py | 90 ++++++-- .../classification/cohen_kappa.py | 81 +++++-- .../classification/confusion_matrix.py | 103 +++++++-- src/torchmetrics/classification/dice.py | 94 +------- src/torchmetrics/classification/f_beta.py | 200 +++++++++++++--- src/torchmetrics/classification/hamming.py | 95 ++++++-- src/torchmetrics/classification/hinge.py | 102 ++++++-- src/torchmetrics/classification/jaccard.py | 106 +++++++-- .../classification/matthews_corrcoef.py | 87 +++++-- .../classification/precision_recall.py | 218 +++++++++++++++--- .../classification/precision_recall_curve.py | 102 ++++++-- src/torchmetrics/classification/roc.py | 123 +++++++--- .../classification/specificity.py | 112 +++++++-- .../classification/stat_scores.py | 170 ++++++++++++-- .../functional/classification/accuracy.py | 184 ++++++++++++++- .../functional/classification/auroc.py | 153 +++++++++++- .../classification/average_precision.py | 127 ++++++++++ .../classification/calibration_error.py | 32 +++ .../functional/classification/cohen_kappa.py | 43 ++++ .../classification/confusion_matrix.py | 79 ++++++- .../functional/classification/dice.py | 9 + .../functional/classification/f_beta.py | 67 ++++++ .../functional/classification/hamming.py | 33 ++- .../functional/classification/hinge.py | 115 ++++++++- .../functional/classification/jaccard.py | 53 +++++ .../classification/matthews_corrcoef.py | 30 +++ .../classification/precision_recall.py | 47 ++++ .../classification/precision_recall_curve.py | 169 ++++++++++++++ .../functional/classification/roc.py | 143 +++++++++++- .../functional/classification/specificity.py | 41 ++++ .../functional/classification/stat_scores.py | 12 + 34 files changed, 3040 insertions(+), 417 deletions(-) diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index a73045d9d39..65c0e380f1f 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -14,17 +14,27 @@ from typing import Any, Optional import torch -from torch import Tensor +from torch import Tensor, tensor from typing_extensions import Literal -from torchmetrics.functional.classification.accuracy import _accuracy_reduce +from torchmetrics.functional.classification.accuracy import ( + _accuracy_compute, + _accuracy_reduce, + _accuracy_update, + _check_subset_validity, + _mode, + _subset_accuracy_compute, + _subset_accuracy_update, +) from torchmetrics.metric import Metric from torchmetrics.classification.stat_scores import ( # isort:skip BinaryStatScores, MulticlassStatScores, - MultilabelStatScores, + MultilabelStatScores, StatScores, ) +from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.enums import AverageMethod class BinaryAccuracy(BinaryStatScores): @@ -312,8 +322,17 @@ def compute(self) -> Tensor: ) -class Accuracy: - r"""Computes `Accuracy`_ +class Accuracy(StatScores): + r"""Accuracy. + + .. note:: + From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes Accuracy_: .. math:: \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) @@ -329,40 +348,168 @@ class Accuracy: >>> import torch >>> target = torch.tensor([0, 1, 2, 3]) >>> preds = torch.tensor([0, 2, 1, 3]) - >>> accuracy = Accuracy(task="multiclass", num_classes=4) + >>> accuracy = Accuracy() >>> accuracy(preds, target) tensor(0.5000) >>> target = torch.tensor([0, 1, 2]) >>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]]) - >>> accuracy = Accuracy(task="multiclass", num_classes=3, top_k=2) + >>> accuracy = Accuracy(top_k=2) >>> accuracy(preds, target) tensor(0.6667) """ + is_differentiable = False + higher_is_better = True + full_state_update: bool = False + correct: Tensor + total: Tensor def __new__( cls, - task: Literal["binary", "multiclass", "multilabel"], + task: Literal["binary", "multiclass", "multilabel"] = None, threshold: float = 0.5, num_classes: Optional[int] = None, num_labels: Optional[int] = None, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - multidim_average: Literal["global", "samplewise"] = "global", + multidim_average: Optional[Literal["global", "samplewise"]] = "global", top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + mdmc_average: Optional[str] = None, + multiclass: Optional[bool] = None, + subset_accuracy: bool = False, **kwargs: Any, ) -> Metric: - kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryAccuracy(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return MulticlassAccuracy(num_classes, top_k, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelAccuracy(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + if task is not None: + assert multidim_average is not None + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryAccuracy(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassAccuracy(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelAccuracy(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + + def __init__( + self, + task: Literal["binary", "multiclass", "multilabel"] = None, + threshold: float = 0.5, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + top_k: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + mdmc_average: Optional[str] = None, + multiclass: Optional[bool] = None, + subset_accuracy: bool = False, + **kwargs: Any, + ) -> None: + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + + _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) + if "reduce" not in kwargs: + kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average + if "mdmc_reduce" not in kwargs: + kwargs["mdmc_reduce"] = mdmc_average + + super().__init__( + threshold=threshold, + top_k=top_k, + num_classes=num_classes, + multiclass=multiclass, + ignore_index=ignore_index, + **kwargs, ) + + if top_k is not None and (not isinstance(top_k, int) or top_k <= 0): + raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}") + + self.average = average + self.threshold = threshold + self.top_k = top_k + self.subset_accuracy = subset_accuracy + self.mode: DataType = None # type: ignore + self.multiclass = multiclass + self.ignore_index = ignore_index + + if self.subset_accuracy: + self.add_state("correct", default=tensor(0), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + """ + mode = _mode(preds, target, self.threshold, self.top_k, self.num_classes, self.multiclass, self.ignore_index) + + if not self.mode: + self.mode = mode + elif self.mode != mode: + raise ValueError(f"You can not use {mode} inputs with {self.mode} inputs.") + + if self.subset_accuracy and not _check_subset_validity(self.mode): + self.subset_accuracy = False + + if self.subset_accuracy: + correct, total = _subset_accuracy_update( + preds, target, threshold=self.threshold, top_k=self.top_k, ignore_index=self.ignore_index + ) + self.correct += correct + self.total += total + else: + if not self.mode: + raise RuntimeError("You have to have determined mode.") + tp, fp, tn, fn = _accuracy_update( + preds, + target, + reduce=self.reduce, + mdmc_reduce=self.mdmc_reduce, + threshold=self.threshold, + num_classes=self.num_classes, + top_k=self.top_k, + multiclass=self.multiclass, + ignore_index=self.ignore_index, + mode=self.mode, + ) + + # Update states + if self.reduce != "samples" and self.mdmc_reduce != "samplewise": + self.tp += tp + self.fp += fp + self.tn += tn + self.fn += fn + else: + self.tp.append(tp) + self.fp.append(fp) + self.tn.append(tn) + self.fn.append(fn) + + def compute(self) -> Tensor: + """Computes accuracy based on inputs passed in to ``update`` previously.""" + if not self.mode: + raise RuntimeError("You have to have determined mode.") + if self.subset_accuracy: + return _subset_accuracy_compute(self.correct, self.total) + tp, fp, tn, fn = self._get_final_stats() + return _accuracy_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce, self.mode) diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index 14a6d9f3e88..910802e78e4 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -23,6 +23,8 @@ MultilabelPrecisionRecallCurve, ) from torchmetrics.functional.classification.auroc import ( + _auroc_compute, + _auroc_update, _binary_auroc_arg_validation, _binary_auroc_compute, _multiclass_auroc_arg_validation, @@ -31,7 +33,9 @@ _multilabel_auroc_compute, ) from torchmetrics.metric import Metric +from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import DataType class BinaryAUROC(BinaryPrecisionRecallCurve): @@ -314,8 +318,18 @@ def compute(self) -> Tensor: return _multilabel_auroc_compute(state, self.num_labels, self.average, self.thresholds, self.ignore_index) -class AUROC: - r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_). The AUROC score summarizes the +class AUROC(Metric): + r"""Area Under the Receiver Operating Characteristic Curve. + + .. note:: + From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + + Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_). The AUROC score summarizes the ROC curve into an single number that describes the performance of a model for multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing. @@ -327,24 +341,30 @@ class AUROC: Legacy Example: >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) >>> target = torch.tensor([0, 0, 1, 1, 1]) - >>> auroc = AUROC(task="binary") + >>> auroc = AUROC(pos_label=1) >>> auroc(preds, target) tensor(0.5000) + Example (multiclass case): >>> preds = torch.tensor([[0.90, 0.05, 0.05], ... [0.05, 0.90, 0.05], ... [0.05, 0.05, 0.90], ... [0.85, 0.05, 0.10], ... [0.10, 0.10, 0.80]]) >>> target = torch.tensor([0, 1, 1, 2, 2]) - >>> auroc = AUROC(task="multiclass", num_classes=3) + >>> auroc = AUROC(num_classes=3) >>> auroc(preds, target) tensor(0.7778) """ + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + preds: List[Tensor] + target: List[Tensor] def __new__( cls, - task: Literal["binary", "multiclass", "multilabel"], + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, thresholds: Optional[Union[int, List[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, @@ -352,17 +372,99 @@ def __new__( max_fpr: Optional[float] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + pos_label: Optional[int] = None, **kwargs: Any, ) -> Metric: - kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryAUROC(max_fpr, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - return MulticlassAUROC(num_classes, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelAUROC(num_labels, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + if task is not None: + kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryAUROC(max_fpr, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassAUROC(num_classes, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelAUROC(num_labels, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + + def __init__( + self, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + max_fpr: Optional[float] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + pos_label: Optional[int] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + self.num_classes = num_classes + self.pos_label = pos_label + self.average = average + self.max_fpr = max_fpr + + allowed_average = (None, "macro", "weighted", "micro") + if self.average not in allowed_average: + raise ValueError( + f"Argument `average` expected to be one of the following: {allowed_average} but got {average}" + ) + + if self.max_fpr is not None: + if not isinstance(max_fpr, float) or not 0 < max_fpr <= 1: + raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}") + + self.mode: DataType = None # type: ignore + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + + rank_zero_warn( + "Metric `AUROC` will save all targets and predictions in buffer." + " For large datasets this may lead to large memory footprint." + ) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + """ + preds, target, mode = _auroc_update(preds, target) + + self.preds.append(preds) + self.target.append(target) + + if self.mode and self.mode != mode: + raise ValueError( + "The mode of data (binary, multi-label, multi-class) should be constant, but changed" + f" between batches from {self.mode} to {mode}" + ) + self.mode = mode + + def compute(self) -> Tensor: + """Computes AUROC based on inputs passed in to ``update`` previously.""" + if not self.mode: + raise RuntimeError("You have to have determined mode.") + preds = dim_zero_cat(self.preds) + target = dim_zero_cat(self.target) + return _auroc_compute( + preds, + target, + self.mode, + self.num_classes, + self.pos_label, + self.average, + self.max_fpr, ) diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index be736fa35fe..80eaa84c7ea 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -23,6 +23,8 @@ MultilabelPrecisionRecallCurve, ) from torchmetrics.functional.classification.average_precision import ( + _average_precision_compute, + _average_precision_update, _binary_average_precision_compute, _multiclass_average_precision_arg_validation, _multiclass_average_precision_compute, @@ -30,6 +32,7 @@ _multilabel_average_precision_compute, ) from torchmetrics.metric import Metric +from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat @@ -315,8 +318,17 @@ def compute(self) -> Tensor: ) -class AveragePrecision: - r"""Computes the average precision (AP) score. The AP score summarizes a precision-recall curve as an weighted +class AveragePrecision(Metric): + r"""Average Precision. + + .. note:: + From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes the average precision (AP) score. The AP score summarizes a precision-recall curve as an weighted mean of precisions at each threshold, with the difference in recall from the previous threshold as weight: .. math:: @@ -333,40 +345,108 @@ class AveragePrecision: Legacy Example: >>> pred = torch.tensor([0, 0.1, 0.8, 0.4]) >>> target = torch.tensor([0, 1, 1, 1]) - >>> average_precision = AveragePrecision(task="binary") + >>> average_precision = AveragePrecision(pos_label=1) >>> average_precision(pred, target) tensor(1.) + Example (multiclass case): >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], ... [0.05, 0.05, 0.05, 0.75, 0.05]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> average_precision = AveragePrecision(task="multiclass", num_classes=5, average=None) + >>> average_precision = AveragePrecision(num_classes=5, average=None) >>> average_precision(pred, target) - tensor([1.0000, 1.0000, 0.2500, 0.2500, nan]) + [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + preds: List[Tensor] + target: List[Tensor] + def __new__( cls, - task: Literal["binary", "multiclass", "multilabel"], + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, thresholds: Optional[Union[int, List[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, average: Optional[Literal["macro", "weighted", "none"]] = "macro", ignore_index: Optional[int] = None, validate_args: bool = True, + pos_label: Optional[int] = None, **kwargs: Any, ) -> Metric: - kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryAveragePrecision(**kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - return MulticlassAveragePrecision(num_classes, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelAveragePrecision(num_labels, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + if task is not None: + kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryAveragePrecision(**kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassAveragePrecision(num_classes, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelAveragePrecision(num_labels, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + + def __init__( + self, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, + validate_args: bool = True, + pos_label: Optional[int] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + self.num_classes = num_classes + self.pos_label = pos_label + allowed_average = ("micro", "macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError(f"Expected argument `average` to be one of {allowed_average}" f" but got {average}") + self.average = average + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + + rank_zero_warn( + "Metric `AveragePrecision` will save all targets and predictions in buffer." + " For large datasets this may lead to large memory footprint." + ) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + """ + preds, target, num_classes, pos_label = _average_precision_update( + preds, target, self.num_classes, self.pos_label, self.average ) + self.preds.append(preds) + self.target.append(target) + self.num_classes = num_classes + self.pos_label = pos_label + + def compute(self) -> Union[Tensor, List[Tensor]]: + """Compute the average precision score. + """ + preds = dim_zero_cat(self.preds) + target = dim_zero_cat(self.target) + if not self.num_classes: + raise ValueError(f"`num_classes` bas to be positive number, but got {self.num_classes}") + return _average_precision_compute(preds, target, self.num_classes, self.pos_label, self.average) diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index f863addbe8b..2b0021c5e30 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, List, Optional import torch from torch import Tensor @@ -23,6 +23,7 @@ _binary_calibration_error_update, _binary_confusion_matrix_format, _ce_compute, + _ce_update, _multiclass_calibration_error_arg_validation, _multiclass_calibration_error_tensor_validation, _multiclass_calibration_error_update, @@ -30,6 +31,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.prints import rank_zero_warn class BinaryCalibrationError(Metric): @@ -220,8 +222,17 @@ def compute(self) -> Tensor: return _ce_compute(confidences, accuracies, self.n_bins, norm=self.norm) -class CalibrationError: - r"""`Computes the Top-label Calibration Error`_. The expected calibration error can be used to quantify how well +class CalibrationError(Metric): + r"""Calibration Error. + + .. note:: + From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + `Computes the Top-label Calibration Error`_. The expected calibration error can be used to quantify how well a given model is calibrated e.g. how well the predicted output probabilities of the model matches the actual probabilities of the ground truth distribution. @@ -245,10 +256,16 @@ class CalibrationError: :mod:`BinaryCalibrationError` and :mod:`MulticlassCalibrationError` for the specific details of each argument influence and examples. """ + is_differentiable: bool = False + higher_is_better: bool = False + full_state_update: bool = False + DISTANCES = {"l1", "l2", "max"} + confidences: List[Tensor] + accuracies: List[Tensor] def __new__( cls, - task: Literal["binary", "multiclass"] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, n_bins: int = 15, norm: Literal["l1", "l2", "max"] = "l1", num_classes: Optional[int] = None, @@ -256,12 +273,59 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: - kwargs.update(dict(n_bins=n_bins, norm=norm, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryCalibrationError(**kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - return MulticlassCalibrationError(num_classes, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) + if task is not None: + kwargs.update(dict(n_bins=n_bins, norm=norm, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryCalibrationError(**kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassCalibrationError(num_classes, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + + def __init__( + self, + n_bins: int = 15, + norm: str = "l1", + **kwargs: Any, + ): + super().__init__(**kwargs) + + if norm not in self.DISTANCES: + raise ValueError(f"Norm {norm} is not supported. Please select from l1, l2, or max. ") + + if not isinstance(n_bins, int) or n_bins <= 0: + raise ValueError(f"Expected argument `n_bins` to be a int larger than 0 but got {n_bins}") + self.n_bins = n_bins + self.bin_boundaries = torch.linspace(0, 1, n_bins + 1) + self.norm = norm + + self.add_state("confidences", [], dist_reduce_fx="cat") + self.add_state("accuracies", [], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Computes top-level confidences and accuracies for the input probabilities and appends them to internal + state. + """ + confidences, accuracies = _ce_update(preds, target) + + self.confidences.append(confidences) + self.accuracies.append(accuracies) + + def compute(self) -> Tensor: + """Computes calibration error across all confidences and accuracies. + """ + confidences = dim_zero_cat(self.confidences) + accuracies = dim_zero_cat(self.accuracies) + return _ce_compute(confidences, accuracies, self.bin_boundaries.to(self.device), norm=self.norm) diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index 70679526ce1..511d7176055 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -20,10 +20,13 @@ from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix from torchmetrics.functional.classification.cohen_kappa import ( _binary_cohen_kappa_arg_validation, + _cohen_kappa_compute, _cohen_kappa_reduce, + _cohen_kappa_update, _multiclass_cohen_kappa_arg_validation, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.prints import rank_zero_warn class BinaryCohenKappa(BinaryConfusionMatrix): @@ -177,8 +180,17 @@ def compute(self) -> Tensor: return _cohen_kappa_reduce(self.confmat, self.weights) -class CohenKappa: - r"""Calculates `Cohen's kappa score`_ that measures inter-annotator agreement. It is defined as. +class CohenKappa(Metric): + r"""Cohen Kappa. + + .. note:: + From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Calculates `Cohen's kappa score`_ that measures inter-annotator agreement. It is defined as .. math:: \kappa = (p_o - p_e) / (1 - p_e) @@ -196,14 +208,18 @@ class labels. Legacy Example: >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0, 1, 0, 0]) - >>> cohenkappa = CohenKappa(task="multiclass", num_classes=2) + >>> cohenkappa = CohenKappa(num_classes=2) >>> cohenkappa(preds, target) tensor(0.5000) """ + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + confmat: Tensor def __new__( cls, - task: Literal["binary", "multiclass"], + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, threshold: float = 0.5, num_classes: Optional[int] = None, weights: Optional[Literal["linear", "quadratic", "none"]] = None, @@ -211,12 +227,51 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: - kwargs.update(dict(weights=weights, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryCohenKappa(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - return MulticlassCohenKappa(num_classes, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) + if task is not None: + kwargs.update(dict(weights=weights, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryCohenKappa(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassCohenKappa(num_classes, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + + def __init__( + self, + num_classes: int, + weights: Optional[str] = None, + threshold: float = 0.5, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.num_classes = num_classes + self.weights = weights + self.threshold = threshold + + allowed_weights = ("linear", "quadratic", "none", None) + if self.weights not in allowed_weights: + raise ValueError(f"Argument weights needs to one of the following: {allowed_weights}") + + self.add_state("confmat", default=torch.zeros(num_classes, num_classes, dtype=torch.long), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + """ + confmat = _cohen_kappa_update(preds, target, self.num_classes, self.threshold) + self.confmat += confmat + + def compute(self) -> Tensor: + """Computes cohen kappa score.""" + return _cohen_kappa_compute(self.confmat, self.weights) diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index e2fd33b9105..598db010bc6 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -23,6 +23,8 @@ _binary_confusion_matrix_format, _binary_confusion_matrix_tensor_validation, _binary_confusion_matrix_update, + _confusion_matrix_compute, + _confusion_matrix_update, _multiclass_confusion_matrix_arg_validation, _multiclass_confusion_matrix_compute, _multiclass_confusion_matrix_format, @@ -35,6 +37,7 @@ _multilabel_confusion_matrix_update, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.prints import rank_zero_warn class BinaryConfusionMatrix(Metric): @@ -312,8 +315,15 @@ def compute(self) -> Tensor: return _multilabel_confusion_matrix_compute(self.confmat, self.normalize) -class ConfusionMatrix: - r"""Computes the `confusion matrix`_. +class ConfusionMatrix(Metric): + r"""Confusion Matrix. + + .. note:: + From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of @@ -323,48 +333,105 @@ class ConfusionMatrix: Legacy Example: >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0, 1, 0, 0]) - >>> confmat = ConfusionMatrix(task="binary", num_classes=2) + >>> confmat = ConfusionMatrix(num_classes=2) >>> confmat(preds, target) tensor([[2, 0], [1, 1]]) + Example (multiclass data): >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) - >>> confmat = ConfusionMatrix(task="multiclass", num_classes=3) + >>> confmat = ConfusionMatrix(num_classes=3) >>> confmat(preds, target) tensor([[1, 1, 0], [0, 1, 0], [0, 0, 1]]) + Example (multilabel data): >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) - >>> confmat = ConfusionMatrix(task="multilabel", num_labels=3) + >>> confmat = ConfusionMatrix(num_classes=3, multilabel=True) >>> confmat(preds, target) tensor([[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]]) """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + confmat: Tensor def __new__( cls, - task: Literal["binary", "multiclass", "multilabel"], + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, threshold: float = 0.5, num_classes: Optional[int] = None, num_labels: Optional[int] = None, normalize: Optional[Literal["true", "pred", "all", "none"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + multilabel: bool = False, **kwargs: Any, ) -> Metric: - kwargs.update(dict(normalize=normalize, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryConfusionMatrix(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - return MulticlassConfusionMatrix(num_classes, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelConfusionMatrix(num_labels, threshold, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) + if task is not None: + kwargs.update(dict(normalize=normalize, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryConfusionMatrix(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassConfusionMatrix(num_classes, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelConfusionMatrix(num_labels, threshold, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + + def __init__( + self, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + threshold: float = 0.5, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + normalize: Optional[str] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + multilabel: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.num_classes = num_classes + self.normalize = normalize + self.threshold = threshold + self.multilabel = multilabel + + allowed_normalize = ("true", "pred", "all", "none", None) + if self.normalize not in allowed_normalize: + raise ValueError(f"Argument average needs to one of the following: {allowed_normalize}") + + if multilabel: + default = torch.zeros(num_classes, 2, 2, dtype=torch.long) + else: + default = torch.zeros(num_classes, num_classes, dtype=torch.long) + self.add_state("confmat", default=default, dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + """ + confmat = _confusion_matrix_update(preds, target, self.num_classes, self.threshold, self.multilabel) + self.confmat += confmat + + def compute(self) -> Tensor: + """Computes confusion matrix. + """ + return _confusion_matrix_compute(self.confmat, self.normalize) diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py index 926770f6b54..6067d7309a1 100644 --- a/src/torchmetrics/classification/dice.py +++ b/src/torchmetrics/classification/dice.py @@ -11,19 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Tuple, no_type_check +from typing import Any, Optional -import torch from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.stat_scores import StatScores from torchmetrics.functional.classification.dice import _dice_compute -from torchmetrics.functional.classification.stat_scores import _stat_scores_update -from torchmetrics.metric import Metric -from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod +from torchmetrics.utilities.enums import AverageMethod -class Dice(Metric): +class Dice(StatScores): r"""Computes `Dice`_: .. math:: \text{Dice} = \frac{\text{2 * TP}}{\text{2 * TP} + \text{FP} + \text{FN}} @@ -143,86 +141,18 @@ def __init__( if "mdmc_reduce" not in kwargs: kwargs["mdmc_reduce"] = mdmc_average - self.reduce = average - self.mdmc_reduce = mdmc_average - self.num_classes = num_classes - self.threshold = threshold - self.multiclass = multiclass - self.ignore_index = ignore_index - self.top_k = top_k - - if average not in ["micro", "macro", "samples"]: - raise ValueError(f"The `reduce` {average} is not valid.") - - if mdmc_average not in [None, "samplewise", "global"]: - raise ValueError(f"The `mdmc_reduce` {mdmc_average} is not valid.") - - if average == "macro" and (not num_classes or num_classes < 1): - raise ValueError("When you set `average` as 'macro', you have to provide the number of classes.") - - if num_classes and ignore_index is not None and (not ignore_index < num_classes or num_classes == 1): - raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") - - default: Callable = lambda: [] - reduce_fn: Optional[str] = "cat" - if mdmc_average != "samplewise" and average != "samples": - if average == "micro": - zeros_shape = [] - elif average == "macro": - zeros_shape = [num_classes] - else: - raise ValueError(f'Wrong reduce="{average}"') - default = lambda: torch.zeros(zeros_shape, dtype=torch.long) - reduce_fn = "sum" - - for s in ("tp", "fp", "tn", "fn"): - self.add_state(s, default=default(), dist_reduce_fx=reduce_fn) + super().__init__( + threshold=threshold, + top_k=top_k, + num_classes=num_classes, + multiclass=multiclass, + ignore_index=ignore_index, + **kwargs, + ) self.average = average self.zero_division = zero_division - @no_type_check - def update(self, preds: Tensor, target: Tensor) -> None: - """Update state with predictions and targets. - - Args: - preds: Predictions from model (probabilities, logits or labels) - target: Ground truth values - """ - tp, fp, tn, fn = _stat_scores_update( - preds, - target, - reduce=self.reduce, - mdmc_reduce=self.mdmc_reduce, - threshold=self.threshold, - num_classes=self.num_classes, - top_k=self.top_k, - multiclass=self.multiclass, - ignore_index=self.ignore_index, - ) - - # Update states - if self.reduce != AverageMethod.SAMPLES and self.mdmc_reduce != MDMCAverageMethod.SAMPLEWISE: - self.tp += tp - self.fp += fp - self.tn += tn - self.fn += fn - else: - self.tp.append(tp) - self.fp.append(fp) - self.tn.append(tn) - self.fn.append(fn) - - @no_type_check - def _get_final_stats(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - """Performs concatenation on the stat scores if neccesary, before passing them to a compute function.""" - tp = torch.cat(self.tp) if isinstance(self.tp, list) else self.tp - fp = torch.cat(self.fp) if isinstance(self.fp, list) else self.fp - tn = torch.cat(self.tn) if isinstance(self.tn, list) else self.tn - fn = torch.cat(self.fn) if isinstance(self.fn, list) else self.fn - return tp, fp, tn, fn - - @no_type_check def compute(self) -> Tensor: """Computes the dice score based on inputs passed in to ``update`` previously. diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index 49fd39005c5..b70d7347e52 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -17,14 +17,22 @@ from torch import Tensor from typing_extensions import Literal -from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores +from torchmetrics.classification.stat_scores import ( + BinaryStatScores, + MulticlassStatScores, + MultilabelStatScores, + StatScores, +) from torchmetrics.functional.classification.f_beta import ( _binary_fbeta_score_arg_validation, + _fbeta_compute, _fbeta_reduce, _multiclass_fbeta_score_arg_validation, _multilabel_fbeta_score_arg_validation, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import AverageMethod +from torchmetrics.utilities.prints import rank_zero_warn class BinaryFBetaScore(BinaryStatScores): @@ -700,8 +708,17 @@ def __init__( ) -class FBetaScore: - r"""Computes `F-score`_ metric: +class FBetaScore(StatScores): + r"""F-Beta Score. + + .. note:: + From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes `F-score`_ metric: .. math:: F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} @@ -716,14 +733,15 @@ class FBetaScore: >>> import torch >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) - >>> f_beta = FBetaScore(task="multiclass", num_classes=3, beta=0.5) + >>> f_beta = FBetaScore(num_classes=3, beta=0.5) >>> f_beta(preds, target) tensor(0.3333) """ + full_state_update: bool = False def __new__( cls, - task: Literal["binary", "multiclass", "multilabel"], + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, beta: float = 1.0, threshold: float = 0.5, num_classes: Optional[int] = None, @@ -733,26 +751,92 @@ def __new__( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + mdmc_average: Optional[str] = None, + multiclass: Optional[bool] = None, **kwargs: Any, ) -> Metric: - assert multidim_average is not None - kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryFBetaScore(beta, threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return MulticlassFBetaScore(beta, num_classes, top_k, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelFBetaScore(beta, num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + if task is not None: + assert multidim_average is not None + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryFBetaScore(beta, threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassFBetaScore(beta, num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelFBetaScore(beta, num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + + def __init__( + self, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + beta: float = 1.0, + threshold: float = 0.5, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + top_k: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + mdmc_average: Optional[str] = None, + multiclass: Optional[bool] = None, + **kwargs: Any, + ) -> None: + self.beta = beta + allowed_average = list(AverageMethod) + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + + _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) + if "reduce" not in kwargs: + kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average + if "mdmc_reduce" not in kwargs: + kwargs["mdmc_reduce"] = mdmc_average + + super().__init__( + threshold=threshold, + top_k=top_k, + num_classes=num_classes, + multiclass=multiclass, + ignore_index=ignore_index, + **kwargs, ) + self.average = average + + def compute(self) -> Tensor: + """Computes f-beta over state.""" + tp, fp, tn, fn = self._get_final_stats() + return _fbeta_compute(tp, fp, tn, fn, self.beta, self.ignore_index, self.average, self.mdmc_reduce) + +class F1Score(FBetaScore): + r"""F1 Score. -class F1Score: - r"""Computes F-1 score: + .. note:: + From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes F-1 score: .. math:: F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}} @@ -766,14 +850,18 @@ class F1Score: >>> import torch >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) - >>> f1 = F1Score(task="multiclass", num_classes=3) + >>> f1 = F1Score(num_classes=3) >>> f1(preds, target) tensor(0.3333) """ + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + def __new__( cls, - task: Literal["binary", "multiclass", "multilabel"], + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, threshold: float = 0.5, num_classes: Optional[int] = None, num_labels: Optional[int] = None, @@ -782,19 +870,61 @@ def __new__( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + mdmc_average: Optional[str] = None, + multiclass: Optional[bool] = None, **kwargs: Any, ) -> Metric: - assert multidim_average is not None - kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryF1Score(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return MulticlassF1Score(num_classes, top_k, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelF1Score(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + if task is not None: + assert multidim_average is not None + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryF1Score(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassF1Score(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelF1Score(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + + def __init__( + self, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + threshold: float = 0.5, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + top_k: Optional[int] = 1, + ignore_index: Optional[int] = None, + validate_args: bool = True, + mdmc_average: Optional[str] = None, + multiclass: Optional[bool] = None, + **kwargs: Any, + ) -> None: + super().__init__( + num_classes=num_classes, + beta=1.0, + threshold=threshold, + average=average, + mdmc_average=mdmc_average, + ignore_index=ignore_index, + top_k=top_k, + multiclass=multiclass, + **kwargs, ) diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index cbb7a5ce987..163ef938f8c 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -14,12 +14,17 @@ from typing import Any, Optional import torch -from torch import Tensor +from torch import Tensor, tensor from typing_extensions import Literal from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores -from torchmetrics.functional.classification.hamming import _hamming_distance_reduce +from torchmetrics.functional.classification.hamming import ( + _hamming_distance_compute, + _hamming_distance_reduce, + _hamming_distance_update, +) from torchmetrics.metric import Metric +from torchmetrics.utilities.prints import rank_zero_warn class BinaryHammingDistance(BinaryStatScores): @@ -312,8 +317,17 @@ def compute(self) -> Tensor: ) -class HammingDistance: - r"""Computes the average `Hamming distance`_ (also known as Hamming loss): +class HammingDistance(Metric): + r"""Hamming distance. + + .. note:: + From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes the average `Hamming distance`_ (also known as Hamming loss): .. math:: \text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) @@ -330,14 +344,19 @@ class HammingDistance: Legacy Example: >>> target = torch.tensor([[0, 1], [1, 1]]) >>> preds = torch.tensor([[0, 1], [0, 1]]) - >>> hamming_distance = HammingDistance(task="multilabel", num_labels=2) + >>> hamming_distance = HammingDistance() >>> hamming_distance(preds, target) tensor(0.2500) """ + is_differentiable: bool = False + higher_is_better: bool = False + full_state_update: bool = False + correct: Tensor + total: Tensor def __new__( cls, - task: Literal["binary", "multiclass", "multilabel"], + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, threshold: float = 0.5, num_classes: Optional[int] = None, num_labels: Optional[int] = None, @@ -348,18 +367,54 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + if task is not None: + assert multidim_average is not None + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryHammingDistance(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassHammingDistance(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelHammingDistance(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + + def __init__( + self, + threshold: float = 0.5, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) - assert multidim_average is not None - kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryHammingDistance(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return MulticlassHammingDistance(num_classes, top_k, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelHammingDistance(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) + self.add_state("correct", default=tensor(0), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + + self.threshold = threshold + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + """ + correct, total = _hamming_distance_update(preds, target, self.threshold) + + self.correct += correct + self.total += total + + def compute(self) -> Tensor: + """Computes hamming distance based on inputs passed in to ``update`` previously.""" + return _hamming_distance_compute(self.correct, self.total) diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py index 1c8b05986ab..81437042270 100644 --- a/src/torchmetrics/classification/hinge.py +++ b/src/torchmetrics/classification/hinge.py @@ -11,24 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, Optional, Union import torch from torch import Tensor from typing_extensions import Literal from torchmetrics.functional.classification.hinge import ( + MulticlassMode, _binary_confusion_matrix_format, _binary_hinge_loss_arg_validation, _binary_hinge_loss_tensor_validation, _binary_hinge_loss_update, + _hinge_compute, _hinge_loss_compute, + _hinge_update, _multiclass_confusion_matrix_format, _multiclass_hinge_loss_arg_validation, _multiclass_hinge_loss_tensor_validation, _multiclass_hinge_loss_update, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.prints import rank_zero_warn class BinaryHingeLoss(Metric): @@ -201,8 +205,17 @@ def compute(self) -> Tensor: return _hinge_loss_compute(self.measures, self.total) -class HingeLoss: - r"""Computes the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs). +class HingeLoss(Metric): + r"""Hinge Loss. + + .. note:: + From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs). This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the ``task`` argument to either ``'binary'`` or ``'multiclass'``. See the documentation of @@ -212,27 +225,34 @@ class HingeLoss: Legacy Example: >>> import torch >>> target = torch.tensor([0, 1, 1]) - >>> preds = torch.tensor([0.5, 0.7, 0.1]) - >>> hinge = HingeLoss(task="binary") + >>> preds = torch.tensor([-2.2, 2.4, 0.1]) + >>> hinge = HingeLoss() >>> hinge(preds, target) - tensor(0.9000) + tensor(0.3000) + Example (default / multiclass case): >>> target = torch.tensor([0, 1, 2]) >>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]]) - >>> hinge = HingeLoss(task="multiclass", num_classes=3) + >>> hinge = HingeLoss() >>> hinge(preds, target) - tensor(1.5551) + tensor(2.9000) + Example (multiclass example, one vs all mode): >>> target = torch.tensor([0, 1, 2]) >>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]]) - >>> hinge = HingeLoss(task="multiclass", num_classes=3, multiclass_mode="one-vs-all") + >>> hinge = HingeLoss(multiclass_mode="one-vs-all") >>> hinge(preds, target) - tensor([1.3743, 1.1945, 1.2359]) + tensor([2.2333, 1.5000, 1.2333]) """ + is_differentiable: bool = True + higher_is_better: bool = False + full_state_update: bool = False + measure: Tensor + total: Tensor def __new__( cls, - task: Literal["binary", "multiclass"], + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, num_classes: Optional[int] = None, squared: bool = False, multiclass_mode: Optional[Literal["crammer-singer", "one-vs-all"]] = "crammer-singer", @@ -240,12 +260,54 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: - kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryHingeLoss(squared, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - return MulticlassHingeLoss(num_classes, squared, multiclass_mode, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) + if task is not None: + kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryHingeLoss(squared, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert multiclass_mode is not None + return MulticlassHingeLoss(num_classes, squared, multiclass_mode, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + + def __init__( + self, + squared: bool = False, + multiclass_mode: Optional[Union[str, MulticlassMode]] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + self.add_state("measure", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + if multiclass_mode not in (None, MulticlassMode.CRAMMER_SINGER, MulticlassMode.ONE_VS_ALL): + raise ValueError( + "The `multiclass_mode` should be either None / 'crammer-singer' / MulticlassMode.CRAMMER_SINGER" + "(default) or 'one-vs-all' / MulticlassMode.ONE_VS_ALL," + f" got {multiclass_mode}." + ) + + self.squared = squared + self.multiclass_mode = multiclass_mode + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + measure, total = _hinge_update(preds, target, squared=self.squared, multiclass_mode=self.multiclass_mode) + + self.measure = measure + self.measure + self.total = total + self.total + + def compute(self) -> Tensor: + return _hinge_compute(self.measure, self.total) diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index c637a8a94b6..1b37c92397e 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -18,12 +18,15 @@ from typing_extensions import Literal from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix +from torchmetrics.classification.confusion_matrix import ConfusionMatrix from torchmetrics.functional.classification.jaccard import ( + _jaccard_from_confmat, _jaccard_index_reduce, _multiclass_jaccard_index_arg_validation, _multilabel_jaccard_index_arg_validation, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.prints import rank_zero_warn class BinaryJaccardIndex(BinaryConfusionMatrix): @@ -252,8 +255,17 @@ def compute(self) -> Tensor: return _jaccard_index_reduce(self.confmat, average=self.average) -class JaccardIndex: - r"""Calculates the Jaccard index for multilabel tasks. The `Jaccard index`_ (also known as the intersetion over +class JaccardIndex(ConfusionMatrix): + r"""Jaccard Index. + + .. note:: + From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Calculates the Jaccard index for multilabel tasks. The `Jaccard index`_ (also known as the intersetion over union or jaccard similarity coefficient) is an statistic that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the intersection divided by the union of the sample sets: @@ -269,31 +281,93 @@ class JaccardIndex: >>> target = torch.randint(0, 2, (10, 25, 25)) >>> pred = torch.tensor(target) >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] - >>> jaccard = JaccardIndex(task="multiclass", num_classes=2) + >>> jaccard = JaccardIndex(num_classes=2) >>> jaccard(pred, target) tensor(0.9660) """ + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False def __new__( cls, - task: Literal["binary", "multiclass", "multilabel"], + num_classes: int, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, threshold: float = 0.5, - num_classes: Optional[int] = None, num_labels: Optional[int] = None, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", ignore_index: Optional[int] = None, + absent_score: float = 0.0, + multilabel: bool = False, validate_args: bool = True, **kwargs: Any, ) -> Metric: - kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryJaccardIndex(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - return MulticlassJaccardIndex(num_classes, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelJaccardIndex(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + if task is not None: + kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryJaccardIndex(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassJaccardIndex(num_classes, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelJaccardIndex(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + + def __init__( + self, + num_classes: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, + absent_score: float = 0.0, + threshold: float = 0.5, + multilabel: bool = False, + **kwargs: Any, + ) -> None: + kwargs["normalize"] = kwargs.get("normalize") + + super().__init__( + num_classes=num_classes, + threshold=threshold, + multilabel=multilabel, + **kwargs, ) + self.average = average + self.ignore_index = ignore_index + self.absent_score = absent_score + + def compute(self) -> Tensor: + """Computes intersection over union (IoU)""" + if self.multilabel: + return torch.stack( + [ + _jaccard_from_confmat( + confmat, + 2, + self.average, + self.ignore_index, + self.absent_score + )[1] + for confmat in self.confmat + ] + ) + else: + return _jaccard_from_confmat( + self.confmat, + self.num_classes, + self.average, + self.ignore_index, + self.absent_score + ) diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index aca91fd5a18..5e47b159fae 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -18,8 +18,13 @@ from typing_extensions import Literal from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix -from torchmetrics.functional.classification.matthews_corrcoef import _matthews_corrcoef_reduce +from torchmetrics.functional.classification.matthews_corrcoef import ( + _matthews_corrcoef_compute, + _matthews_corrcoef_reduce, + _matthews_corrcoef_update, +) from torchmetrics.metric import Metric +from torchmetrics.utilities.prints import rank_zero_warn class BinaryMatthewsCorrCoef(BinaryConfusionMatrix): @@ -215,8 +220,17 @@ def compute(self) -> Tensor: return _matthews_corrcoef_reduce(self.confmat) -class MatthewsCorrCoef: - r"""Calculates `Matthews correlation coefficient`_ . This metric measures the general correlation or quality of +class MatthewsCorrCoef(Metric): + r"""Matthews correlation coefficient. + + .. note:: + From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Calculates `Matthews correlation coefficient`_ . This metric measures the general correlation or quality of a classification. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the @@ -227,30 +241,67 @@ class MatthewsCorrCoef: Legacy Example: >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0, 1, 0, 0]) - >>> matthews_corrcoef = MatthewsCorrCoef(task='binary') + >>> matthews_corrcoef = MatthewsCorrCoef(num_classes=2) >>> matthews_corrcoef(preds, target) tensor(0.5774) """ + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + confmat: Tensor def __new__( cls, - task: Literal["binary", "multiclass", "multilabel"] = None, + num_classes: int, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, threshold: float = 0.5, - num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, ) -> Metric: - kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryMatthewsCorrCoef(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - return MulticlassMatthewsCorrCoef(num_classes, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelMatthewsCorrCoef(num_labels, threshold, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) + if task is not None: + kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryMatthewsCorrCoef(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassMatthewsCorrCoef(num_classes, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelMatthewsCorrCoef(num_labels, threshold, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + + def __init__( + self, + num_classes: int, + threshold: float = 0.5, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.num_classes = num_classes + self.threshold = threshold + + self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + """ + confmat = _matthews_corrcoef_update(preds, target, self.num_classes, self.threshold) + self.confmat += confmat + + def compute(self) -> Tensor: + """Computes matthews correlation coefficient.""" + return _matthews_corrcoef_compute(self.confmat) diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index 146822bfd38..e64255e06e6 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -17,9 +17,20 @@ from torch import Tensor from typing_extensions import Literal -from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores -from torchmetrics.functional.classification.precision_recall import _precision_recall_reduce +from torchmetrics.classification.stat_scores import ( + BinaryStatScores, + MulticlassStatScores, + MultilabelStatScores, + StatScores, +) +from torchmetrics.functional.classification.precision_recall import ( + _precision_compute, + _precision_recall_reduce, + _recall_compute, +) from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import AverageMethod +from torchmetrics.utilities.prints import rank_zero_warn class BinaryPrecision(BinaryStatScores): @@ -592,8 +603,17 @@ def compute(self) -> Tensor: ) -class Precision: - r"""Computes `Precision`_: +class Precision(StatScores): + r"""Precision. + + .. note:: + From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes `Precision`_: .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} @@ -609,17 +629,20 @@ class Precision: >>> import torch >>> preds = torch.tensor([2, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) - >>> precision = Precision(task="multiclass", average='macro', num_classes=3) + >>> precision = Precision(average='macro', num_classes=3) >>> precision(preds, target) tensor(0.1667) - >>> precision = Precision(task="multiclass", average='micro', num_classes=3) + >>> precision = Precision(average='micro') >>> precision(preds, target) tensor(0.2500) """ + is_differentiable = False + higher_is_better = True + full_state_update: bool = False def __new__( cls, - task: Literal["binary", "multiclass", "multilabel"], + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, threshold: float = 0.5, num_classes: Optional[int] = None, num_labels: Optional[int] = None, @@ -628,26 +651,88 @@ def __new__( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + multiclass: Optional[bool] = None, + mdmc_average: Optional[str] = None, **kwargs: Any, ) -> Metric: - assert multidim_average is not None - kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryPrecision(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return MulticlassPrecision(num_classes, top_k, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelPrecision(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + if task is not None: + assert multidim_average is not None + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryPrecision(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassPrecision(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelPrecision(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + + def __init__( + self, + num_classes: Optional[int] = None, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + multiclass: Optional[bool] = None, + **kwargs: Any, + ) -> None: + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + + _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) + if "reduce" not in kwargs: + kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average + if "mdmc_reduce" not in kwargs: + kwargs["mdmc_reduce"] = mdmc_average + + super().__init__( + threshold=threshold, + top_k=top_k, + num_classes=num_classes, + multiclass=multiclass, + ignore_index=ignore_index, + **kwargs, ) + self.average = average + + def compute(self) -> Tensor: + """Computes the precision score based on inputs passed in to ``update`` previously. + """ + tp, fp, _, fn = self._get_final_stats() + return _precision_compute(tp, fp, fn, self.average, self.mdmc_reduce) + + +class Recall(StatScores): + r"""Recall. -class Recall: - r"""Computes `Recall`_: + .. note:: + From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes `Recall`_: .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} @@ -663,17 +748,20 @@ class Recall: >>> import torch >>> preds = torch.tensor([2, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) - >>> recall = Recall(task="multiclass", average='macro', num_classes=3) + >>> recall = Recall(average='macro', num_classes=3) >>> recall(preds, target) tensor(0.3333) - >>> recall = Recall(task="multiclass", average='micro', num_classes=3) + >>> recall = Recall(average='micro') >>> recall(preds, target) tensor(0.2500) """ + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False def __new__( cls, - task: Literal["binary", "multiclass", "multilabel"], + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, threshold: float = 0.5, num_classes: Optional[int] = None, num_labels: Optional[int] = None, @@ -682,19 +770,73 @@ def __new__( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + mdmc_average: Optional[str] = None, + multiclass: Optional[bool] = None, **kwargs: Any, ) -> Metric: - assert multidim_average is not None - kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryRecall(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return MulticlassRecall(num_classes, top_k, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelRecall(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + if task is not None: + assert multidim_average is not None + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryRecall(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassRecall(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelRecall(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + + def __init__( + self, + num_classes: Optional[int] = None, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + multiclass: Optional[bool] = None, + **kwargs: Any, + ) -> None: + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + + _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) + if "reduce" not in kwargs: + kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average + if "mdmc_reduce" not in kwargs: + kwargs["mdmc_reduce"] = mdmc_average + + super().__init__( + threshold=threshold, + top_k=top_k, + num_classes=num_classes, + multiclass=multiclass, + ignore_index=ignore_index, + **kwargs, ) + + self.average = average + + def compute(self) -> Tensor: + """Computes the recall score based on inputs passed in to ``update`` previously. + + """ + tp, fp, _, fn = self._get_final_stats() + return _recall_compute(tp, fp, fn, self.average, self.mdmc_reduce) diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index ad367a41249..f22b47cd068 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -34,8 +34,11 @@ _multilabel_precision_recall_curve_format, _multilabel_precision_recall_curve_tensor_validation, _multilabel_precision_recall_curve_update, + _precision_recall_curve_compute, + _precision_recall_curve_update, ) from torchmetrics.metric import Metric +from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat @@ -417,8 +420,17 @@ def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], Li return _multilabel_precision_recall_curve_compute(state, self.num_labels, self.thresholds, self.ignore_index) -class PrecisionRecallCurve: - r"""Computes the precision-recall curve. The curve consist of multiple pairs of precision and recall values +class PrecisionRecallCurve(Metric): + r"""Precision Recall Curve. + + .. note:: + From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes the precision-recall curve. The curve consist of multiple pairs of precision and recall values evaluated at different thresholds, such that the tradeoff between the two values can been seen. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the @@ -429,7 +441,7 @@ class PrecisionRecallCurve: Legacy Example: >>> pred = torch.tensor([0, 0.1, 0.8, 0.4]) >>> target = torch.tensor([0, 1, 1, 0]) - >>> pr_curve = PrecisionRecallCurve(task="binary") + >>> pr_curve = PrecisionRecallCurve(pos_label=1) >>> precision, recall, thresholds = pr_curve(pred, target) >>> precision tensor([0.6667, 0.5000, 1.0000, 1.0000]) @@ -438,12 +450,13 @@ class PrecisionRecallCurve: >>> thresholds tensor([0.1000, 0.4000, 0.8000]) + Example (multiclass case): >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], ... [0.05, 0.05, 0.05, 0.75, 0.05]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> pr_curve = PrecisionRecallCurve(task="multiclass", num_classes=5) + >>> pr_curve = PrecisionRecallCurve(num_classes=5) >>> precision, recall, thresholds = pr_curve(pred, target) >>> precision [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), @@ -454,25 +467,82 @@ class PrecisionRecallCurve: [tensor(0.7500), tensor(0.7500), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor(0.0500)] """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + preds: List[Tensor] + target: List[Tensor] + def __new__( cls, - task: Literal["binary", "multiclass", "multilabel"], + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, thresholds: Optional[Union[int, List[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + pos_label: Optional[int] = None, **kwargs: Any, ) -> Metric: - kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryPrecisionRecallCurve(**kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - return MulticlassPrecisionRecallCurve(num_classes, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelPrecisionRecallCurve(num_labels, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + if task is not None: + kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryPrecisionRecallCurve(**kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassPrecisionRecallCurve(num_classes, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelPrecisionRecallCurve(num_labels, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + + def __init__( + self, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + self.num_classes = num_classes + self.pos_label = pos_label + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + + rank_zero_warn( + "Metric `PrecisionRecallCurve` will save all targets and predictions in buffer." + " For large datasets this may lead to large memory footprint." ) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + """ + preds, target, num_classes, pos_label = _precision_recall_curve_update( + preds, target, self.num_classes, self.pos_label + ) + self.preds.append(preds) + self.target.append(target) + self.num_classes = num_classes + self.pos_label = pos_label + + def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + """Compute the precision-recall curve. + """ + preds = dim_zero_cat(self.preds) + target = dim_zero_cat(self.target) + if not self.num_classes: + raise ValueError(f"`num_classes` bas to be positive number, but got {self.num_classes}") + return _precision_recall_curve_compute(preds, target, self.num_classes, self.pos_label) diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index a26b80ee873..6b138166823 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -26,8 +26,11 @@ _binary_roc_compute, _multiclass_roc_compute, _multilabel_roc_compute, + _roc_compute, + _roc_update, ) from torchmetrics.metric import Metric +from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat @@ -302,8 +305,17 @@ def compute(self) -> Tuple[Tensor, Tensor, Tensor]: return _multilabel_roc_compute(state, self.num_labels, self.thresholds, self.ignore_index) -class ROC: - r"""Computes the Receiver Operating Characteristic (ROC). The curve consist of multiple pairs of true positive +class ROC(Metric): + r"""Receiver Operating Characteristic. + + .. note:: + From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes the Receiver Operating Characteristic (ROC). The curve consist of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, such that the tradeoff between the two values can be seen. @@ -313,40 +325,42 @@ class ROC: influence and examples. Legacy Example: - >>> pred = torch.tensor([0.0, 1.0, 2.0, 3.0]) + >>> pred = torch.tensor([0, 1, 2, 3]) >>> target = torch.tensor([0, 1, 1, 1]) - >>> roc = ROC(task="binary") + >>> roc = ROC(pos_label=1) >>> fpr, tpr, thresholds = roc(pred, target) >>> fpr tensor([0., 0., 0., 0., 1.]) >>> tpr tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]) >>> thresholds - tensor([1.0000, 0.9526, 0.8808, 0.7311, 0.5000]) + tensor([4, 3, 2, 1, 0]) + Example (multiclass case): >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05], ... [0.05, 0.05, 0.05, 0.75]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> roc = ROC(task="multiclass", num_classes=4) + >>> roc = ROC(num_classes=4) >>> fpr, tpr, thresholds = roc(pred, target) >>> fpr [tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])] >>> tpr [tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])] - >>> thresholds # doctest: +NORMALIZE_WHITESPACE - [tensor([1.0000, 0.7500, 0.0500]), - tensor([1.0000, 0.7500, 0.0500]), - tensor([1.0000, 0.7500, 0.0500]), - tensor([1.0000, 0.7500, 0.0500])] + >>> thresholds + [tensor([1.7500, 0.7500, 0.0500]), + tensor([1.7500, 0.7500, 0.0500]), + tensor([1.7500, 0.7500, 0.0500]), + tensor([1.7500, 0.7500, 0.0500])] + Example (multilabel case): >>> pred = torch.tensor([[0.8191, 0.3680, 0.1138], ... [0.3584, 0.7576, 0.1183], ... [0.2286, 0.3468, 0.1338], ... [0.8603, 0.0745, 0.1837]]) >>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]]) - >>> roc = ROC(task='multilabel', num_labels=3) + >>> roc = ROC(num_classes=3, pos_label=1) >>> fpr, tpr, thresholds = roc(pred, target) >>> fpr [tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]), @@ -357,30 +371,85 @@ class ROC: tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]), tensor([0., 1., 1., 1., 1.])] >>> thresholds - [tensor([1.0000, 0.8603, 0.8191, 0.3584, 0.2286]), - tensor([1.0000, 0.7576, 0.3680, 0.3468, 0.0745]), - tensor([1.0000, 0.1837, 0.1338, 0.1183, 0.1138])] + [tensor([1.8603, 0.8603, 0.8191, 0.3584, 0.2286]), + tensor([1.7576, 0.7576, 0.3680, 0.3468, 0.0745]), + tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])] """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + preds: List[Tensor] + target: List[Tensor] + def __new__( cls, - task: Literal["binary", "multiclass", "multilabel"], + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, thresholds: Optional[Union[int, List[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + pos_label: Optional[int] = None, **kwargs: Any, ) -> Metric: - kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryROC(**kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - return MulticlassROC(num_classes, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelROC(num_labels, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + if task is not None: + kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryROC(**kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassROC(num_classes, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelROC(num_labels, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + + def __init__( + self, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + self.num_classes = num_classes + self.pos_label = pos_label + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + + rank_zero_warn( + "Metric `ROC` will save all targets and predictions in buffer." + " For large datasets this may lead to large memory footprint." ) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + """ + preds, target, num_classes, pos_label = _roc_update(preds, target, self.num_classes, self.pos_label) + self.preds.append(preds) + self.target.append(target) + self.num_classes = num_classes + self.pos_label = pos_label + + def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + """Compute the receiver operating characteristic. + """ + preds = dim_zero_cat(self.preds) + target = dim_zero_cat(self.target) + if not self.num_classes: + raise ValueError(f"`num_classes` bas to be positive number, but got {self.num_classes}") + return _roc_compute(preds, target, self.num_classes, self.pos_label) diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index 2a572e153f1..927fb25564f 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -17,9 +17,16 @@ from torch import Tensor from typing_extensions import Literal -from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores -from torchmetrics.functional.classification.specificity import _specificity_reduce +from torchmetrics.classification.stat_scores import ( + BinaryStatScores, + MulticlassStatScores, + MultilabelStatScores, + StatScores, +) +from torchmetrics.functional.classification.specificity import _specificity_compute, _specificity_reduce from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import AverageMethod +from torchmetrics.utilities.prints import rank_zero_warn class BinarySpecificity(BinaryStatScores): @@ -287,8 +294,17 @@ def compute(self) -> Tensor: return _specificity_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average) -class Specificity: - r"""Computes `Specificity`_. +class Specificity(StatScores): + r"""Specificity. + + .. note:: + From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes `Specificity`_. .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} @@ -303,17 +319,20 @@ class Specificity: Legacy Example: >>> preds = torch.tensor([2, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) - >>> specificity = Specificity(task="multiclass", average='macro', num_classes=3) + >>> specificity = Specificity(average='macro', num_classes=3) >>> specificity(preds, target) tensor(0.6111) - >>> specificity = Specificity(task="multiclass", average='micro', num_classes=3) + >>> specificity = Specificity(average='micro') >>> specificity(preds, target) tensor(0.6250) """ + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False def __new__( cls, - task: Literal["binary", "multiclass", "multilabel"], + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, threshold: float = 0.5, num_classes: Optional[int] = None, num_labels: Optional[int] = None, @@ -322,19 +341,72 @@ def __new__( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + mdmc_average: Optional[str] = None, + multiclass: Optional[bool] = None, **kwargs: Any, ) -> Metric: - assert multidim_average is not None - kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinarySpecificity(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return MulticlassSpecificity(num_classes, top_k, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelSpecificity(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + if task is not None: + assert multidim_average is not None + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinarySpecificity(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassSpecificity(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelSpecificity(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + + def __init__( + self, + num_classes: Optional[int] = None, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + multiclass: Optional[bool] = None, + **kwargs: Any, + ) -> None: + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + + _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) + if "reduce" not in kwargs: + kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average + if "mdmc_reduce" not in kwargs: + kwargs["mdmc_reduce"] = mdmc_average + + super().__init__( + threshold=threshold, + top_k=top_k, + num_classes=num_classes, + multiclass=multiclass, + ignore_index=ignore_index, + **kwargs, ) + + self.average = average + + def compute(self) -> Tensor: + """Computes the specificity score based on inputs passed in to ``update`` previously. + """ + tp, fp, tn, fn = self._get_final_stats() + return _specificity_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce) diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 4078b878057..b01d4e08451 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -33,9 +33,13 @@ _multilabel_stat_scores_format, _multilabel_stat_scores_tensor_validation, _multilabel_stat_scores_update, + _stat_scores_compute, + _stat_scores_update, ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod +from torchmetrics.utilities.prints import rank_zero_warn class _AbstractStatScores(Metric): @@ -489,8 +493,17 @@ def compute(self) -> Tensor: return _multilabel_stat_scores_compute(tp, fp, tn, fn, self.average, self.multidim_average) -class StatScores: - r"""Computes the number of true positives, false positives, true negatives, false negatives and the support. +class StatScores(Metric): + r"""StatScores. + + .. note:: + From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes the number of true positives, false positives, true negatives, false negatives and the support. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of @@ -500,19 +513,27 @@ class StatScores: Legacy Example: >>> preds = torch.tensor([1, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) - >>> stat_scores = StatScores(task="multiclass", num_classes=3, average='micro') - >>> stat_scores(preds, target) - tensor([2, 2, 6, 2, 4]) - >>> stat_scores = StatScores(task="multiclass", num_classes=3, average=None) + >>> stat_scores = StatScores(reduce='macro', num_classes=3) >>> stat_scores(preds, target) tensor([[0, 1, 2, 1, 1], [1, 1, 1, 1, 2], [1, 0, 3, 0, 1]]) + >>> stat_scores = StatScores(reduce='micro') + >>> stat_scores(preds, target) + tensor([2, 2, 6, 2, 4]) """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + # TODO: canot be used because if scripting + # tp: Union[Tensor, List[Tensor]] + # fp: Union[Tensor, List[Tensor]] + # tn: Union[Tensor, List[Tensor]] + # fn: Union[Tensor, List[Tensor]] def __new__( cls, - task: Literal["binary", "multiclass", "multilabel"], + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, threshold: float = 0.5, num_classes: Optional[int] = None, num_labels: Optional[int] = None, @@ -521,19 +542,128 @@ def __new__( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + mdmc_average: Optional[str] = None, + multiclass: Optional[bool] = None, **kwargs: Any, ) -> Metric: - assert multidim_average is not None - kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryStatScores(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return MulticlassStatScores(num_classes, top_k, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelStatScores(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + if task is not None: + assert multidim_average is not None + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryStatScores(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassStatScores(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelStatScores(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + + def __init__( + self, + threshold: float = 0.5, + top_k: Optional[int] = None, + reduce: str = "micro", + num_classes: Optional[int] = None, + ignore_index: Optional[int] = None, + mdmc_reduce: Optional[str] = None, + multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.reduce = reduce + self.mdmc_reduce = mdmc_reduce + self.num_classes = num_classes + self.threshold = threshold + self.multiclass = multiclass + self.ignore_index = ignore_index + self.top_k = top_k + + if reduce not in ["micro", "macro", "samples"]: + raise ValueError(f"The `reduce` {reduce} is not valid.") + + if mdmc_reduce not in [None, "samplewise", "global"]: + raise ValueError(f"The `mdmc_reduce` {mdmc_reduce} is not valid.") + + if reduce == "macro" and (not num_classes or num_classes < 1): + raise ValueError("When you set `reduce` as 'macro', you have to provide the number of classes.") + + if num_classes and ignore_index is not None and (not ignore_index < num_classes or num_classes == 1): + raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") + + default: Callable = lambda: [] + reduce_fn: Optional[str] = "cat" + if mdmc_reduce != "samplewise" and reduce != "samples": + if reduce == "micro": + zeros_shape = [] + elif reduce == "macro": + zeros_shape = [num_classes] + else: + raise ValueError(f'Wrong reduce="{reduce}"') + default = lambda: torch.zeros(zeros_shape, dtype=torch.long) + reduce_fn = "sum" + + for s in ("tp", "fp", "tn", "fn"): + self.add_state(s, default=default(), dist_reduce_fx=reduce_fn) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + """ + tp, fp, tn, fn = _stat_scores_update( + preds, + target, + reduce=self.reduce, + mdmc_reduce=self.mdmc_reduce, + threshold=self.threshold, + num_classes=self.num_classes, + top_k=self.top_k, + multiclass=self.multiclass, + ignore_index=self.ignore_index, ) + + # Update states + if self.reduce != AverageMethod.SAMPLES and self.mdmc_reduce != MDMCAverageMethod.SAMPLEWISE: + self.tp += tp + self.fp += fp + self.tn += tn + self.fn += fn + else: + self.tp.append(tp) + self.fp.append(fp) + self.tn.append(tn) + self.fn.append(fn) + + def _get_final_stats(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Performs concatenation on the stat scores if neccesary, before passing them to a compute function.""" + tp = torch.cat(self.tp) if isinstance(self.tp, list) else self.tp + fp = torch.cat(self.fp) if isinstance(self.fp, list) else self.fp + tn = torch.cat(self.tn) if isinstance(self.tn, list) else self.tn + fn = torch.cat(self.fn) if isinstance(self.fn, list) else self.fn + return tp, fp, tn, fn + + def compute(self) -> Tensor: + """Computes the stat scores based on inputs passed in to ``update`` previously. + + """ + tp, fp, tn, fn = self._get_final_stats() + return _stat_scores_compute(tp, fp, tn, fn) diff --git a/src/torchmetrics/functional/classification/accuracy.py b/src/torchmetrics/functional/classification/accuracy.py index 445d50f2953..18a695d1788 100644 --- a/src/torchmetrics/functional/classification/accuracy.py +++ b/src/torchmetrics/functional/classification/accuracy.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Tuple import torch -from torch import Tensor +from torch import Tensor, tensor from typing_extensions import Literal from torchmetrics.functional.classification.stat_scores import ( @@ -30,8 +30,12 @@ _multilabel_stat_scores_format, _multilabel_stat_scores_tensor_validation, _multilabel_stat_scores_update, + _reduce_stat_scores, + _stat_scores_update, ) +from torchmetrics.utilities.checks import _check_classification_inputs, _input_format_classification, _input_squeeze from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod def _accuracy_reduce( @@ -381,6 +385,182 @@ def multilabel_accuracy( return _accuracy_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True) +def _check_subset_validity(mode: DataType) -> bool: + """Checks input mode is valid.""" + return mode in (DataType.MULTILABEL, DataType.MULTIDIM_MULTICLASS) + + +def _mode( + preds: Tensor, + target: Tensor, + threshold: float, + top_k: Optional[int], + num_classes: Optional[int], + multiclass: Optional[bool], + ignore_index: Optional[int] = None, +) -> DataType: + """Finds the mode of the input tensors. + + Example: + >>> target = torch.tensor([0, 1, 2, 3]) + >>> preds = torch.tensor([0, 2, 1, 3]) + >>> _mode(preds, target, 0.5, None, None, None) + + """ + + mode = _check_classification_inputs( + preds, + target, + threshold=threshold, + top_k=top_k, + num_classes=num_classes, + multiclass=multiclass, + ignore_index=ignore_index, + ) + return mode + + +def _accuracy_update( + preds: Tensor, + target: Tensor, + reduce: Optional[str], + mdmc_reduce: Optional[str], + threshold: float, + num_classes: Optional[int], + top_k: Optional[int], + multiclass: Optional[bool], + ignore_index: Optional[int], + mode: DataType, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Updates and returns stat scores (true positives, false positives, true negatives, false negatives) required + to compute accuracy. + """ + + if mode == DataType.MULTILABEL and top_k: + raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.") + preds, target = _input_squeeze(preds, target) + tp, fp, tn, fn = _stat_scores_update( + preds, + target, + reduce=reduce, + mdmc_reduce=mdmc_reduce, + threshold=threshold, + num_classes=num_classes, + top_k=top_k, + multiclass=multiclass, + ignore_index=ignore_index, + mode=mode, + ) + return tp, fp, tn, fn + + +def _accuracy_compute( + tp: Tensor, + fp: Tensor, + tn: Tensor, + fn: Tensor, + average: Optional[str], + mdmc_average: Optional[str], + mode: DataType, +) -> Tensor: + """Computes accuracy from stat scores: true positives, false positives, true negatives, false negatives. + + Example: + >>> preds = torch.tensor([0, 2, 1, 3]) + >>> target = torch.tensor([0, 1, 2, 3]) + >>> threshold = 0.5 + >>> reduce = average = 'micro' + >>> mdmc_average = 'global' + >>> mode = _mode(preds, target, threshold, top_k=None, num_classes=None, multiclass=None) + >>> tp, fp, tn, fn = _accuracy_update( + ... preds, target, reduce, mdmc_average, threshold=0.5, num_classes=None, top_k=None, + ... multiclass=None, ignore_index=None, mode=mode + ... ) + >>> _accuracy_compute(tp, fp, tn, fn, average, mdmc_average, mode) + tensor(0.5000) + + >>> target = torch.tensor([0, 1, 2]) + >>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]]) + >>> top_k, threshold = 2, 0.5 + >>> reduce = average = 'micro' + >>> mdmc_average = 'global' + >>> mode = _mode(preds, target, threshold, top_k, num_classes=None, multiclass=None) + >>> tp, fp, tn, fn = _accuracy_update(preds, target, reduce, mdmc_average, threshold, + ... num_classes=None, top_k=top_k, multiclass=None, ignore_index=None, mode=mode) + >>> _accuracy_compute(tp, fp, tn, fn, average, mdmc_average, mode) + tensor(0.6667) + """ + + simple_average = [AverageMethod.MICRO, AverageMethod.SAMPLES] + if (mode == DataType.BINARY and average in simple_average) or mode == DataType.MULTILABEL: + numerator = tp + tn + denominator = tp + tn + fp + fn + else: + numerator = tp.clone() + denominator = tp + fn + + if mdmc_average != MDMCAverageMethod.SAMPLEWISE: + if average == AverageMethod.MACRO: + cond = tp + fp + fn == 0 + numerator = numerator[~cond] + denominator = denominator[~cond] + + if average == AverageMethod.NONE: + # a class is not present if there exists no TPs, no FPs, and no FNs + meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu() + numerator[meaningless_indeces, ...] = -1 + denominator[meaningless_indeces, ...] = -1 + + return _reduce_stat_scores( + numerator=numerator, + denominator=denominator, + weights=None if average != AverageMethod.WEIGHTED else tp + fn, + average=average, + mdmc_average=mdmc_average, + ) + + +def _subset_accuracy_update( + preds: Tensor, + target: Tensor, + threshold: float, + top_k: Optional[int], + ignore_index: Optional[int] = None, +) -> Tuple[Tensor, Tensor]: + """Updates and returns variables required to compute subset accuracy. + """ + + preds, target = _input_squeeze(preds, target) + preds, target, mode = _input_format_classification( + preds, target, threshold=threshold, top_k=top_k, ignore_index=ignore_index + ) + + if mode == DataType.MULTILABEL and top_k: + raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.") + + if mode == DataType.MULTILABEL: + correct = (preds == target).all(dim=1).sum() + total = tensor(target.shape[0], device=target.device) + elif mode == DataType.MULTICLASS: + correct = (preds * target).sum() + total = target.sum() + elif mode == DataType.MULTIDIM_MULTICLASS: + sample_correct = (preds * target).sum(dim=(1, 2)) + correct = (sample_correct == target.shape[2]).sum() + total = tensor(target.shape[0], device=target.device) + else: + correct, total = tensor(0), tensor(0) + + return correct, total + + +def _subset_accuracy_compute(correct: Tensor, total: Tensor) -> Tensor: + """Computes subset accuracy from number of correct observations and total number of observations. + """ + + return correct.float() / total + + def accuracy( preds: Tensor, target: Tensor, diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py index 989eabd0ef3..53cf335e404 100644 --- a/src/torchmetrics/functional/classification/auroc.py +++ b/src/torchmetrics/functional/classification/auroc.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +import warnings +from typing import List, Optional, Sequence, Tuple, Union import torch from torch import Tensor, tensor @@ -35,9 +36,12 @@ _binary_roc_compute, _multiclass_roc_compute, _multilabel_roc_compute, + roc, ) +from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.compute import _auc_compute_without_check, _safe_divide from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.enums import AverageMethod, DataType from torchmetrics.utilities.prints import rank_zero_warn @@ -414,6 +418,153 @@ def multilabel_auroc( return _multilabel_auroc_compute(state, num_labels, average, thresholds, ignore_index) +def _auroc_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, DataType]: + """Updates and returns variables required to compute Area Under the Receiver Operating Characteristic Curve. + Validates the inputs and returns the mode of the inputs. + + """ + + # use _input_format_classification for validating the input and get the mode of data + _, _, mode = _input_format_classification(preds, target) + + if mode == "multi class multi dim": + n_classes = preds.shape[1] + preds = preds.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1) + target = target.flatten() + if mode == "multi-label" and preds.ndim > 2: + n_classes = preds.shape[1] + preds = preds.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1) + target = target.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1) + + return preds, target, mode + + +def _auroc_compute( + preds: Tensor, + target: Tensor, + mode: DataType, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + average: Optional[str] = "macro", + max_fpr: Optional[float] = None, + sample_weights: Optional[Sequence] = None, +) -> Tensor: + """Computes Area Under the Receiver Operating Characteristic Curve. + + Example: + >>> # binary case + >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) + >>> target = torch.tensor([0, 0, 1, 1, 1]) + >>> preds, target, mode = _auroc_update(preds, target) + >>> _auroc_compute(preds, target, mode, pos_label=1) + tensor(0.5000) + + >>> # multiclass case + >>> preds = torch.tensor([[0.90, 0.05, 0.05], + ... [0.05, 0.90, 0.05], + ... [0.05, 0.05, 0.90], + ... [0.85, 0.05, 0.10], + ... [0.10, 0.10, 0.80]]) + >>> target = torch.tensor([0, 1, 1, 2, 2]) + >>> preds, target, mode = _auroc_update(preds, target) + >>> _auroc_compute(preds, target, mode, num_classes=3) + tensor(0.7778) + """ + + # binary mode override num_classes + if mode == DataType.BINARY: + num_classes = 1 + + # check max_fpr parameter + if max_fpr is not None: + if not isinstance(max_fpr, float) and 0 < max_fpr <= 1: + raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}") + + # max_fpr parameter is only support for binary + if mode != DataType.BINARY: + raise ValueError( + "Partial AUC computation not available in multilabel/multiclass setting," + f" 'max_fpr' must be set to `None`, received `{max_fpr}`." + ) + + # calculate fpr, tpr + if mode == DataType.MULTILABEL: + if average == AverageMethod.MICRO: + fpr, tpr, _ = roc(preds.flatten(), target.flatten(), 1, pos_label, sample_weights) + elif num_classes: + # for multilabel we iteratively evaluate roc in a binary fashion + output = [ + roc(preds[:, i], target[:, i], num_classes=1, pos_label=1, sample_weights=sample_weights) + for i in range(num_classes) + ] + fpr = [o[0] for o in output] + tpr = [o[1] for o in output] + else: + raise ValueError("Detected input to be `multilabel` but you did not provide `num_classes` argument") + else: + if mode != DataType.BINARY: + if num_classes is None: + raise ValueError("Detected input to `multiclass` but you did not provide `num_classes` argument") + if average == AverageMethod.WEIGHTED and len(torch.unique(target)) < num_classes: + # If one or more classes has 0 observations, we should exclude them, as its weight will be 0 + target_bool_mat = torch.zeros((len(target), num_classes), dtype=bool, device=target.device) + target_bool_mat[torch.arange(len(target)), target.long()] = 1 + class_observed = target_bool_mat.sum(axis=0) > 0 + for c in range(num_classes): + if not class_observed[c]: + warnings.warn(f"Class {c} had 0 observations, omitted from AUROC calculation", UserWarning) + preds = preds[:, class_observed] + target = target_bool_mat[:, class_observed] + target = torch.where(target)[1] + num_classes = class_observed.sum() + if num_classes == 1: + raise ValueError("Found 1 non-empty class in `multiclass` AUROC calculation") + fpr, tpr, _ = roc(preds, target, num_classes, pos_label, sample_weights) + + # calculate standard roc auc score + if max_fpr is None or max_fpr == 1: + if mode == DataType.MULTILABEL and average == AverageMethod.MICRO: + pass + elif num_classes != 1: + # calculate auc scores per class + auc_scores = [_auc_compute_without_check(x, y, 1.0) for x, y in zip(fpr, tpr)] + + # calculate average + if average == AverageMethod.NONE: + return tensor(auc_scores) + if average == AverageMethod.MACRO: + return torch.mean(torch.stack(auc_scores)) + if average == AverageMethod.WEIGHTED: + if mode == DataType.MULTILABEL: + support = torch.sum(target, dim=0) + else: + support = _bincount(target.flatten(), minlength=num_classes) + return torch.sum(torch.stack(auc_scores) * support / support.sum()) + + allowed_average = (AverageMethod.NONE.value, AverageMethod.MACRO.value, AverageMethod.WEIGHTED.value) + raise ValueError( + f"Argument `average` expected to be one of the following: {allowed_average} but got {average}" + ) + + return _auc_compute_without_check(fpr, tpr, 1.0) + + _device = fpr.device if isinstance(fpr, Tensor) else fpr[0].device + max_area: Tensor = tensor(max_fpr, device=_device) + # Add a single point at max_fpr and interpolate its tpr value + stop = torch.bucketize(max_area, fpr, out_int32=True, right=True) + weight = (max_area - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1]) + interp_tpr: Tensor = torch.lerp(tpr[stop - 1], tpr[stop], weight) + tpr = torch.cat([tpr[:stop], interp_tpr.view(1)]) + fpr = torch.cat([fpr[:stop], max_area.view(1)]) + + # Compute partial AUC + partial_auc = _auc_compute_without_check(fpr, tpr, 1.0) + + # McClish correction: standardize result to be 0.5 if non-discriminant and 1 if maximal + min_area: Tensor = 0.5 * max_area**2 + return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area)) + + def auroc( preds: Tensor, target: Tensor, diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index 792e1e3600f..62d23824477 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import List, Optional, Tuple, Union import torch @@ -33,6 +34,8 @@ _multilabel_precision_recall_curve_format, _multilabel_precision_recall_curve_tensor_validation, _multilabel_precision_recall_curve_update, + _precision_recall_curve_compute, + _precision_recall_curve_update, ) from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.data import _bincount @@ -398,6 +401,130 @@ def multilabel_average_precision( return _multilabel_average_precision_compute(state, num_labels, average, thresholds, ignore_index) +def _average_precision_update( + preds: Tensor, + target: Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + average: Optional[str] = "macro", +) -> Tuple[Tensor, Tensor, int, Optional[int]]: + """Format the predictions and target based on the ``num_classes``, ``pos_label`` and ``average`` parameter. + + """ + preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label) + if average == "micro" and preds.ndim != target.ndim: + raise ValueError("Cannot use `micro` average with multi-class input") + + return preds, target, num_classes, pos_label + + +def _average_precision_compute( + preds: Tensor, + target: Tensor, + num_classes: int, + pos_label: Optional[int] = None, + average: Optional[str] = "macro", +) -> Union[List[Tensor], Tensor]: + """Computes the average precision score. + + Example: + >>> # binary case + >>> preds = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 1, 1]) + >>> pos_label = 1 + >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, pos_label=pos_label) + >>> _average_precision_compute(preds, target, num_classes, pos_label) + tensor(1.) + + >>> # multiclass case + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> num_classes = 5 + >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes) + >>> _average_precision_compute(preds, target, num_classes, average=None) + [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] + """ + + if average == "micro" and preds.ndim == target.ndim: + preds = preds.flatten() + target = target.flatten() + num_classes = 1 + + precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label) + if average == "weighted": + if preds.ndim == target.ndim and target.ndim > 1: + weights = target.sum(dim=0).float() + else: + weights = _bincount(target, minlength=max(num_classes, 2)).float() + weights = weights / torch.sum(weights) + else: + weights = None + return _average_precision_compute_with_precision_recall(precision, recall, num_classes, average, weights) + + +def _average_precision_compute_with_precision_recall( + precision: Tensor, + recall: Tensor, + num_classes: int, + average: Optional[str] = "macro", + weights: Optional[Tensor] = None, +) -> Union[List[Tensor], Tensor]: + """Computes the average precision score from precision and recall. + + Example: + >>> # binary case + >>> preds = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 1, 1]) + >>> pos_label = 1 + >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, pos_label=pos_label) + >>> precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label) + >>> _average_precision_compute_with_precision_recall(precision, recall, num_classes, average=None) + tensor(1.) + + >>> # multiclass case + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> num_classes = 5 + >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes) + >>> precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes) + >>> _average_precision_compute_with_precision_recall(precision, recall, num_classes, average=None) + [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] + """ + + # Return the step function integral + # The following works because the last entry of precision is + # guaranteed to be 1, as returned by precision_recall_curve + if num_classes == 1: + return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1]) + + res = [] + for p, r in zip(precision, recall): + res.append(-torch.sum((r[1:] - r[:-1]) * p[:-1])) + + # Reduce + if average in ("macro", "weighted"): + res = torch.stack(res) + if torch.isnan(res).any(): + warnings.warn( + "Average precision score for one or more classes was `nan`. Ignoring these classes in average", + UserWarning, + ) + if average == "macro": + return res[~torch.isnan(res)].mean() + weights = torch.ones_like(res) if weights is None else weights + return (res * weights)[~torch.isnan(res)].sum() + if average is None or average == "none": + return res + allowed_average = ("micro", "macro", "weighted", "none", None) + raise ValueError(f"Expected argument `average` to be one of {allowed_average}" f" but got {average}") + + def average_precision( preds: Tensor, target: Tensor, diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py index 4545dec0742..0a0b901aaa9 100644 --- a/src/torchmetrics/functional/classification/calibration_error.py +++ b/src/torchmetrics/functional/classification/calibration_error.py @@ -23,6 +23,8 @@ _multiclass_confusion_matrix_format, _multiclass_confusion_matrix_tensor_validation, ) +from torchmetrics.utilities.checks import _input_format_classification +from torchmetrics.utilities.enums import DataType def _binning_bucketize( @@ -313,6 +315,36 @@ def multiclass_calibration_error( return _ce_compute(confidences, accuracies, n_bins, norm) +def _ce_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: + """Given a predictions and targets tensor, computes the confidences of the top-1 prediction and records their + correctness. + + """ + _, _, mode = _input_format_classification(preds, target) + + if mode == DataType.BINARY: + if not ((0 <= preds) * (preds <= 1)).all(): + preds = preds.sigmoid() + confidences, accuracies = preds, target + elif mode == DataType.MULTICLASS: + if not ((0 <= preds) * (preds <= 1)).all(): + preds = preds.softmax(dim=1) + confidences, predictions = preds.max(dim=1) + accuracies = predictions.eq(target) + elif mode == DataType.MULTIDIM_MULTICLASS: + # reshape tensors + # for preds, move the class dimension to the final axis and flatten the rest + confidences, predictions = torch.transpose(preds, 1, -1).flatten(0, -2).max(dim=1) + # for targets, just flatten the target + accuracies = predictions.eq(target.flatten()) + else: + raise ValueError( + f"Calibration error is not well-defined for data with size {preds.size()} and targets {target.size()}." + ) + # must be cast to float for ddp allgather to work + return confidences.float(), accuracies.float() + + def calibration_error( preds: Tensor, target: Tensor, diff --git a/src/torchmetrics/functional/classification/cohen_kappa.py b/src/torchmetrics/functional/classification/cohen_kappa.py index b4d9c1217a0..f48ecfc06d8 100644 --- a/src/torchmetrics/functional/classification/cohen_kappa.py +++ b/src/torchmetrics/functional/classification/cohen_kappa.py @@ -22,6 +22,8 @@ _binary_confusion_matrix_format, _binary_confusion_matrix_tensor_validation, _binary_confusion_matrix_update, + _confusion_matrix_compute, + _confusion_matrix_update, _multiclass_confusion_matrix_arg_validation, _multiclass_confusion_matrix_format, _multiclass_confusion_matrix_tensor_validation, @@ -227,6 +229,47 @@ class labels. return _cohen_kappa_reduce(confmat, weights) +_cohen_kappa_update = _confusion_matrix_update + + +def _cohen_kappa_compute(confmat: Tensor, weights: Optional[str] = None) -> Tensor: + """Computes Cohen's kappa based on the weighting type. + + Example: + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> confmat = _cohen_kappa_update(preds, target, num_classes=2) + >>> _cohen_kappa_compute(confmat) + tensor(0.5000) + """ + + confmat = _confusion_matrix_compute(confmat) + confmat = confmat.float() if not confmat.is_floating_point() else confmat + n_classes = confmat.shape[0] + sum0 = confmat.sum(dim=0, keepdim=True) + sum1 = confmat.sum(dim=1, keepdim=True) + expected = sum1 @ sum0 / sum0.sum() # outer product + + if weights is None: + w_mat = torch.ones_like(confmat).flatten() + w_mat[:: n_classes + 1] = 0 + w_mat = w_mat.reshape(n_classes, n_classes) + elif weights in ("linear", "quadratic"): + w_mat = torch.zeros_like(confmat) + w_mat += torch.arange(n_classes, dtype=w_mat.dtype, device=w_mat.device) + if weights == "linear": + w_mat = torch.abs(w_mat - w_mat.T) + else: + w_mat = torch.pow(w_mat - w_mat.T, 2.0) + else: + raise ValueError( + f"Received {weights} for argument ``weights`` but should be either" " None, 'linear' or 'quadratic'" + ) + + k = torch.sum(w_mat * confmat) / torch.sum(w_mat * expected) + return 1 - k + + def cohen_kappa( preds: Tensor, target: Tensor, diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 3d5b76e1ad5..ecb50a036e5 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -17,8 +17,9 @@ from torch import Tensor from typing_extensions import Literal -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.enums import DataType from torchmetrics.utilities.prints import rank_zero_warn @@ -591,6 +592,82 @@ def multilabel_confusion_matrix( return _multilabel_confusion_matrix_compute(confmat, normalize) +def _confusion_matrix_update( + preds: Tensor, target: Tensor, num_classes: int, threshold: float = 0.5, multilabel: bool = False +) -> Tensor: + """Updates and returns confusion matrix (without any normalization) based on the mode of the input. + + """ + + preds, target, mode = _input_format_classification(preds, target, threshold) + if mode not in (DataType.BINARY, DataType.MULTILABEL): + preds = preds.argmax(dim=1) + target = target.argmax(dim=1) + if multilabel: + unique_mapping = ((2 * target + preds) + 4 * torch.arange(num_classes, device=preds.device)).flatten() + minlength = 4 * num_classes + else: + unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long) + minlength = num_classes**2 + + bins = _bincount(unique_mapping, minlength=minlength) + if multilabel: + confmat = bins.reshape(num_classes, 2, 2) + else: + confmat = bins.reshape(num_classes, num_classes) + return confmat + + +def _confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor: + """Computes confusion matrix based on the normalization mode. + + Example: + >>> # binary case + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> confmat = _confusion_matrix_update(preds, target, num_classes=2) + >>> _confusion_matrix_compute(confmat) + tensor([[2, 0], + [1, 1]]) + + >>> # multiclass case + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> confmat = _confusion_matrix_update(preds, target, num_classes=3) + >>> _confusion_matrix_compute(confmat) + tensor([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]]) + + >>> # multilabel case + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> confmat = _confusion_matrix_update(preds, target, num_classes=3, multilabel=True) + >>> _confusion_matrix_compute(confmat) + tensor([[[1, 0], [0, 1]], + [[1, 0], [1, 0]], + [[0, 1], [0, 1]]]) + """ + + allowed_normalize = ("true", "pred", "all", "none", None) + if normalize not in allowed_normalize: + raise ValueError(f"Argument average needs to one of the following: {allowed_normalize}") + if normalize is not None and normalize != "none": + confmat = confmat.float() if not confmat.is_floating_point() else confmat + if normalize == "true": + confmat = confmat / confmat.sum(axis=1, keepdim=True) + elif normalize == "pred": + confmat = confmat / confmat.sum(axis=0, keepdim=True) + elif normalize == "all": + confmat = confmat / confmat.sum() + + nan_elements = confmat[torch.isnan(confmat)].nelement() + if nan_elements != 0: + confmat[torch.isnan(confmat)] = 0 + rank_zero_warn(f"{nan_elements} nan values found in confusion matrix have been replaced with zeros.") + return confmat + + def confusion_matrix( preds: Tensor, target: Tensor, diff --git a/src/torchmetrics/functional/classification/dice.py b/src/torchmetrics/functional/classification/dice.py index 301321bfaeb..5379d1d45eb 100644 --- a/src/torchmetrics/functional/classification/dice.py +++ b/src/torchmetrics/functional/classification/dice.py @@ -38,6 +38,15 @@ def _dice_compute( average: Defines the reduction that is applied mdmc_average: Defines how averaging is done for multi-dimensional multi-class inputs (on top of the ``average`` parameter) + + Example: + >>> from torchmetrics.functional.classification.stat_scores import _stat_scores_update + >>> from torchmetrics.functional.classification.dice import _dice_compute + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='micro') + >>> _dice_compute(tp, fp, fn, average='micro', mdmc_average=None) + tensor(0.2500) """ numerator = 2 * tp denominator = 2 * tp + fp + fn diff --git a/src/torchmetrics/functional/classification/f_beta.py b/src/torchmetrics/functional/classification/f_beta.py index 6c7aebba8bc..4e40ada96c6 100644 --- a/src/torchmetrics/functional/classification/f_beta.py +++ b/src/torchmetrics/functional/classification/f_beta.py @@ -30,8 +30,11 @@ _multilabel_stat_scores_format, _multilabel_stat_scores_tensor_validation, _multilabel_stat_scores_update, + _reduce_stat_scores, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import AverageMethod as AvgMethod +from torchmetrics.utilities.enums import MDMCAverageMethod def _fbeta_reduce( @@ -695,6 +698,70 @@ def multilabel_f1_score( ) +def _fbeta_compute( + tp: Tensor, + fp: Tensor, + tn: Tensor, + fn: Tensor, + beta: float, + ignore_index: Optional[int], + average: str, + mdmc_average: Optional[str], +) -> Tensor: + """Computes f_beta metric from stat scores: true positives, false positives, true negatives, false negatives. + + Example: + >>> from torchmetrics.functional.classification.stat_scores import _stat_scores_update + >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) + >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) + >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='micro', num_classes=3) + >>> _fbeta_compute(tp, fp, tn, fn, beta=0.5, ignore_index=None, average='micro', mdmc_average=None) + tensor(0.3333) + """ + if average == AvgMethod.MICRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE: + mask = tp >= 0 + precision = _safe_divide(tp[mask].sum().float(), (tp[mask] + fp[mask]).sum()) + recall = _safe_divide(tp[mask].sum().float(), (tp[mask] + fn[mask]).sum()) + else: + precision = _safe_divide(tp.float(), tp + fp) + recall = _safe_divide(tp.float(), tp + fn) + + num = (1 + beta**2) * precision * recall + denom = beta**2 * precision + recall + denom[denom == 0.0] = 1.0 # avoid division by 0 + + # if classes matter and a given class is not present in both the preds and the target, + # computing the score for this class is meaningless, thus they should be ignored + if average == AvgMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE: + # a class is not present if there exists no TPs, no FPs, and no FNs + meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu() + if ignore_index is None: + ignore_index = meaningless_indeces + else: + ignore_index = torch.unique(torch.cat((meaningless_indeces, torch.tensor([[ignore_index]])))) + + if ignore_index is not None: + if average not in (AvgMethod.MICRO, AvgMethod.SAMPLES) and mdmc_average == MDMCAverageMethod.SAMPLEWISE: + num[..., ignore_index] = -1 + denom[..., ignore_index] = -1 + elif average not in (AvgMethod.MICRO, AvgMethod.SAMPLES): + num[ignore_index, ...] = -1 + denom[ignore_index, ...] = -1 + + if average == AvgMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE: + cond = (tp + fp + fn == 0) | (tp + fp + fn == -3) + num = num[~cond] + denom = denom[~cond] + + return _reduce_stat_scores( + numerator=num, + denominator=denom, + weights=None if average != AvgMethod.WEIGHTED else tp + fn, + average=average, + mdmc_average=mdmc_average, + ) + + def fbeta_score( preds: Tensor, target: Tensor, diff --git a/src/torchmetrics/functional/classification/hamming.py b/src/torchmetrics/functional/classification/hamming.py index 84da16735b1..f1539325d1f 100644 --- a/src/torchmetrics/functional/classification/hamming.py +++ b/src/torchmetrics/functional/classification/hamming.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Tuple, Union import torch from torch import Tensor @@ -31,6 +31,7 @@ _multilabel_stat_scores_tensor_validation, _multilabel_stat_scores_update, ) +from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.compute import _safe_divide @@ -386,6 +387,36 @@ def multilabel_hamming_distance( return _hamming_distance_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True) +def _hamming_distance_update( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, +) -> Tuple[Tensor, int]: + """Returns the number of positions where prediction equals target, and number of predictions. + """ + + preds, target, _ = _input_format_classification(preds, target, threshold=threshold) + + correct = (preds == target).sum() + total = preds.numel() + + return correct, total + + +def _hamming_distance_compute(correct: Tensor, total: Union[int, Tensor]) -> Tensor: + """Computes the Hamming distance. + + Example: + >>> target = torch.tensor([[0, 1], [1, 1]]) + >>> preds = torch.tensor([[0, 1], [0, 1]]) + >>> correct, total = _hamming_distance_update(preds, target) + >>> _hamming_distance_compute(correct, total) + tensor(0.2500) + """ + + return 1 - correct.float() / total + + def hamming_distance( preds: Tensor, target: Tensor, diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py index cb7d441243d..9a2f7a6f4d0 100644 --- a/src/torchmetrics/functional/classification/hinge.py +++ b/src/torchmetrics/functional/classification/hinge.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch from torch import Tensor, tensor @@ -23,7 +23,9 @@ _multiclass_confusion_matrix_format, _multiclass_confusion_matrix_tensor_validation, ) +from torchmetrics.utilities.checks import _input_squeeze from torchmetrics.utilities.data import to_onehot +from torchmetrics.utilities.enums import DataType, EnumStr def _hinge_loss_compute(measure: Tensor, total: Tensor) -> Tensor: @@ -240,6 +242,117 @@ def multiclass_hinge_loss( return _hinge_loss_compute(measures, total) +class MulticlassMode(EnumStr): + """Enum to represent possible multiclass modes of hinge. + + >>> "Crammer-Singer" in list(MulticlassMode) + True + """ + + CRAMMER_SINGER = "crammer-singer" + ONE_VS_ALL = "one-vs-all" + + +def _check_shape_and_type_consistency_hinge( + preds: Tensor, + target: Tensor, +) -> DataType: + """Checks shape and type of ``preds`` and ``target`` and returns mode of the input tensors. + """ + + if target.ndim > 1: + raise ValueError( + f"The `target` should be one dimensional, got `target` with shape={target.shape}.", + ) + + if preds.ndim == 1: + if preds.shape != target.shape: + raise ValueError( + "The `preds` and `target` should have the same shape,", + f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}.", + ) + mode = DataType.BINARY + elif preds.ndim == 2: + if preds.shape[0] != target.shape[0]: + raise ValueError( + "The `preds` and `target` should have the same shape in the first dimension,", + f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}.", + ) + mode = DataType.MULTICLASS + else: + raise ValueError(f"The `preds` should be one or two dimensional, got `preds` with shape={preds.shape}.") + return mode + + +def _hinge_update( + preds: Tensor, + target: Tensor, + squared: bool = False, + multiclass_mode: Optional[Union[str, MulticlassMode]] = None, +) -> Tuple[Tensor, Tensor]: + """Updates and returns sum over Hinge loss scores for each observation and the total number of observations. + """ + preds, target = _input_squeeze(preds, target) + + mode = _check_shape_and_type_consistency_hinge(preds, target) + + if mode == DataType.MULTICLASS: + target = to_onehot(target, max(2, preds.shape[1])).bool() + + if mode == DataType.MULTICLASS and (multiclass_mode is None or multiclass_mode == MulticlassMode.CRAMMER_SINGER): + margin = preds[target] + margin -= torch.max(preds[~target].view(preds.shape[0], -1), dim=1)[0] + elif mode == DataType.BINARY or multiclass_mode == MulticlassMode.ONE_VS_ALL: + target = target.bool() + margin = torch.zeros_like(preds) + margin[target] = preds[target] + margin[~target] = -preds[~target] + else: + raise ValueError( + "The `multiclass_mode` should be either None / 'crammer-singer' / MulticlassMode.CRAMMER_SINGER" + "(default) or 'one-vs-all' / MulticlassMode.ONE_VS_ALL," + f" got {multiclass_mode}." + ) + + measures = 1 - margin + measures = torch.clamp(measures, 0) + + if squared: + measures = measures.pow(2) + + total = tensor(target.shape[0], device=target.device) + return measures.sum(dim=0), total + + +def _hinge_compute(measure: Tensor, total: Tensor) -> Tensor: + """Computes mean Hinge loss. + + Example: + >>> # binary case + >>> target = torch.tensor([0, 1, 1]) + >>> preds = torch.tensor([-2.2, 2.4, 0.1]) + >>> measure, total = _hinge_update(preds, target) + >>> _hinge_compute(measure, total) + tensor(0.3000) + + >>> # multiclass case + >>> target = torch.tensor([0, 1, 2]) + >>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]]) + >>> measure, total = _hinge_update(preds, target) + >>> _hinge_compute(measure, total) + tensor(2.9000) + + >>> # multiclass one-vs-all mode case + >>> target = torch.tensor([0, 1, 2]) + >>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]]) + >>> measure, total = _hinge_update(preds, target, multiclass_mode="one-vs-all") + >>> _hinge_compute(measure, total) + tensor([2.2333, 1.5000, 1.2333]) + """ + + return measure / total + + def hinge_loss( preds: Tensor, target: Tensor, diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index d9c70be4169..6c08eab0c83 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -296,6 +296,59 @@ def multilabel_jaccard_index( return _jaccard_index_reduce(confmat, average=average) +def _jaccard_from_confmat( + confmat: Tensor, + num_classes: int, + average: Optional[str] = "macro", + ignore_index: Optional[int] = None, + absent_score: float = 0.0, +) -> Tensor: + """Computes the intersection over union from confusion matrix. + + """ + allowed_average = ["micro", "macro", "weighted", "none", None] + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + + # Remove the ignored class index from the scores. + if ignore_index is not None and 0 <= ignore_index < num_classes: + confmat[ignore_index] = 0.0 + + if average == "none" or average is None: + intersection = torch.diag(confmat) + union = confmat.sum(0) + confmat.sum(1) - intersection + + # If this class is absent in both target AND pred (union == 0), then use the absent_score for this class. + scores = intersection.float() / union.float() + scores = scores.where(union != 0, torch.tensor(absent_score, dtype=scores.dtype, device=scores.device)) + + if ignore_index is not None and 0 <= ignore_index < num_classes: + scores = torch.cat( + [ + scores[:ignore_index], + scores[ignore_index + 1 :] + ] + ) + return scores + + if average == "macro": + scores = _jaccard_from_confmat( + confmat, num_classes, average="none", ignore_index=ignore_index, absent_score=absent_score + ) + return torch.mean(scores) + + if average == "micro": + intersection = torch.sum(torch.diag(confmat)) + union = torch.sum(torch.sum(confmat, dim=1) + torch.sum(confmat, dim=0) - torch.diag(confmat)) + return intersection.float() / union.float() + + weights = torch.sum(confmat, dim=1).float() / torch.sum(confmat).float() + scores = _jaccard_from_confmat( + confmat, num_classes, average="none", ignore_index=ignore_index, absent_score=absent_score + ) + return torch.sum(weights * scores) + + def jaccard_index( preds: Tensor, target: Tensor, diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index d70510de112..083d1e74e8b 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -22,6 +22,7 @@ _binary_confusion_matrix_format, _binary_confusion_matrix_tensor_validation, _binary_confusion_matrix_update, + _confusion_matrix_update, _multiclass_confusion_matrix_arg_validation, _multiclass_confusion_matrix_format, _multiclass_confusion_matrix_tensor_validation, @@ -230,6 +231,35 @@ def multilabel_matthews_corrcoef( return _matthews_corrcoef_reduce(confmat) +_matthews_corrcoef_update = _confusion_matrix_update + + +def _matthews_corrcoef_compute(confmat: Tensor) -> Tensor: + """Computes Matthews correlation coefficient. + + Example: + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> confmat = _matthews_corrcoef_update(preds, target, num_classes=2) + >>> _matthews_corrcoef_compute(confmat) + tensor(0.5774) + """ + + tk = confmat.sum(dim=1).float() + pk = confmat.sum(dim=0).float() + c = torch.trace(confmat).float() + s = confmat.sum().float() + + cov_ytyp = c * s - sum(tk * pk) + cov_ypyp = s**2 - sum(pk * pk) + cov_ytyt = s**2 - sum(tk * tk) + + if cov_ypyp * cov_ytyt == 0: + return torch.tensor(0, dtype=confmat.dtype, device=confmat.device) + else: + return cov_ytyp / torch.sqrt(cov_ytyt * cov_ypyp) + + def matthews_corrcoef( preds: Tensor, target: Tensor, diff --git a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py index 1c2f20e0e22..be1855a4b5e 100644 --- a/src/torchmetrics/functional/classification/precision_recall.py +++ b/src/torchmetrics/functional/classification/precision_recall.py @@ -30,8 +30,11 @@ _multilabel_stat_scores_format, _multilabel_stat_scores_tensor_validation, _multilabel_stat_scores_update, + _reduce_stat_scores, + _stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod def _precision_recall_reduce( @@ -652,6 +655,50 @@ def multilabel_recall( return _precision_recall_reduce("recall", tp, fp, tn, fn, average=average, multidim_average=multidim_average) +def _precision_compute( + tp: Tensor, + fp: Tensor, + fn: Tensor, + average: Optional[str], + mdmc_average: Optional[str], +) -> Tensor: + """Computes precision from the stat scores: true positives, false positives, true negatives, false negatives. + + Example: + >>> from torchmetrics.functional.classification.stat_scores import _stat_scores_update + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> tp, fp, tn, fn = _stat_scores_update( preds, target, reduce='macro', num_classes=3) + >>> _precision_compute(tp, fp, fn, average='macro', mdmc_average=None) + tensor(0.1667) + >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='micro') + >>> _precision_compute(tp, fp, fn, average='micro', mdmc_average=None) + tensor(0.2500) + """ + + numerator = tp.clone() + denominator = tp + fp + + if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE: + cond = tp + fp + fn == 0 + numerator = numerator[~cond] + denominator = denominator[~cond] + + if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE: + # a class is not present if there exists no TPs, no FPs, and no FNs + meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu() + numerator[meaningless_indeces, ...] = -1 + denominator[meaningless_indeces, ...] = -1 + + return _reduce_stat_scores( + numerator=numerator, + denominator=denominator, + weights=None if average != "weighted" else tp + fn, + average=average, + mdmc_average=mdmc_average, + ) + + def precision( preds: Tensor, target: Tensor, diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index e69d46275e9..068f03b8fdb 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -19,6 +19,7 @@ from torch.nn import functional as F from typing_extensions import Literal +from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.data import _bincount @@ -768,6 +769,174 @@ def multilabel_precision_recall_curve( return _multilabel_precision_recall_curve_compute(state, num_labels, thresholds, ignore_index) +def _precision_recall_curve_update( + preds: Tensor, + target: Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, +) -> Tuple[Tensor, Tensor, int, Optional[int]]: + """Updates and returns variables required to compute the precision-recall pairs for different thresholds. + + """ + + if len(preds.shape) == len(target.shape): + if pos_label is None: + pos_label = 1 + if num_classes is not None and num_classes != 1: + # multilabel problem + if num_classes != preds.shape[1]: + raise ValueError( + f"Argument `num_classes` was set to {num_classes} in" + f" metric `precision_recall_curve` but detected {preds.shape[1]}" + " number of classes from predictions" + ) + preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1) + target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1) + else: + # binary problem + preds = preds.flatten() + target = target.flatten() + num_classes = 1 + + # multi class problem + elif len(preds.shape) == len(target.shape) + 1: + if pos_label is not None: + rank_zero_warn( + "Argument `pos_label` should be `None` when running" + f" multiclass precision recall curve. Got {pos_label}" + ) + if num_classes != preds.shape[1]: + raise ValueError( + f"Argument `num_classes` was set to {num_classes} in" + f" metric `precision_recall_curve` but detected {preds.shape[1]}" + " number of classes from predictions" + ) + preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1) + target = target.flatten() + + else: + raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") + + return preds, target, num_classes, pos_label + + +def _precision_recall_curve_compute_single_class( + preds: Tensor, + target: Tensor, + pos_label: int, + sample_weights: Optional[Sequence] = None, +) -> Tuple[Tensor, Tensor, Tensor]: + """Computes precision-recall pairs for single class inputs. + + """ + + fps, tps, thresholds = _binary_clf_curve( + preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label + ) + precision = tps / (tps + fps) + recall = tps / tps[-1] + + # stop when full recall attained and reverse the outputs so recall is decreasing + last_ind = torch.where(tps == tps[-1])[0][0] + sl = slice(0, last_ind.item() + 1) + + # need to call reversed explicitly, since including that to slice would + # introduce negative strides that are not yet supported in pytorch + precision = torch.cat([reversed(precision[sl]), torch.ones(1, dtype=precision.dtype, device=precision.device)]) + + recall = torch.cat([reversed(recall[sl]), torch.zeros(1, dtype=recall.dtype, device=recall.device)]) + + thresholds = reversed(thresholds[sl]).detach().clone() # type: ignore + + return precision, recall, thresholds + + +def _precision_recall_curve_compute_multi_class( + preds: Tensor, + target: Tensor, + num_classes: int, + sample_weights: Optional[Sequence] = None, +) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]: + """Computes precision-recall pairs for multiclass inputs. + + """ + + # Recursively call per class + precision, recall, thresholds = [], [], [] + for cls in range(num_classes): + preds_cls = preds[:, cls] + + prc_args = dict( + preds=preds_cls, + target=target, + num_classes=1, + pos_label=cls, + sample_weights=sample_weights, + ) + if target.ndim > 1: + prc_args.update( + dict( + target=target[:, cls], + pos_label=1, + ) + ) + res = precision_recall_curve(**prc_args) + precision.append(res[0]) + recall.append(res[1]) + thresholds.append(res[2]) + + return precision, recall, thresholds + + +def _precision_recall_curve_compute( + preds: Tensor, + target: Tensor, + num_classes: int, + pos_label: Optional[int] = None, + sample_weights: Optional[Sequence] = None, +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + """Computes precision-recall pairs based on the number of classes. + + Example: + >>> # binary case + >>> preds = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> pos_label = 1 + >>> preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, pos_label=pos_label) + >>> precision, recall, thresholds = _precision_recall_curve_compute(preds, target, num_classes, pos_label) + >>> precision + tensor([0.6667, 0.5000, 0.0000, 1.0000]) + >>> recall + tensor([1.0000, 0.5000, 0.0000, 0.0000]) + >>> thresholds + tensor([1, 2, 3]) + + >>> # multiclass case + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> num_classes = 5 + >>> preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes) + >>> precision, recall, thresholds = _precision_recall_curve_compute(preds, target, num_classes) + >>> precision + [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), + tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])] + >>> recall + [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])] + >>> thresholds + [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] + """ + + with torch.no_grad(): + if num_classes == 1: + if pos_label is None: + pos_label = 1 + return _precision_recall_curve_compute_single_class(preds, target, pos_label, sample_weights) + return _precision_recall_curve_compute_multi_class(preds, target, num_classes, sample_weights) + + def precision_recall_curve( preds: Tensor, target: Tensor, diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py index 0726069e67b..0df65887812 100644 --- a/src/torchmetrics/functional/classification/roc.py +++ b/src/torchmetrics/functional/classification/roc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Sequence, Tuple, Union import torch from torch import Tensor @@ -31,6 +31,7 @@ _multilabel_precision_recall_curve_format, _multilabel_precision_recall_curve_tensor_validation, _multilabel_precision_recall_curve_update, + _precision_recall_curve_update, ) from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.compute import _safe_divide @@ -420,6 +421,146 @@ def multilabel_roc( return _multilabel_roc_compute(state, num_labels, thresholds, ignore_index) +def _roc_update( + preds: Tensor, + target: Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, +) -> Tuple[Tensor, Tensor, int, Optional[int]]: + """Updates and returns variables required to compute the Receiver Operating Characteristic. + + """ + + return _precision_recall_curve_update(preds, target, num_classes, pos_label) + + +def _roc_compute_single_class( + preds: Tensor, + target: Tensor, + pos_label: int, + sample_weights: Optional[Sequence] = None, +) -> Tuple[Tensor, Tensor, Tensor]: + """Computes Receiver Operating Characteristic for single class inputs. Returns tensor with false positive + rates, tensor with true positive rates, tensor with thresholds used for computing false- and true-postive + rates. + + """ + + fps, tps, thresholds = _binary_clf_curve( + preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label + ) + # Add an extra threshold position to make sure that the curve starts at (0, 0) + tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps]) + fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps]) + thresholds = torch.cat([thresholds[0][None] + 1, thresholds]) + + if fps[-1] <= 0: + rank_zero_warn( + "No negative samples in targets, false positive value should be meaningless." + " Returning zero tensor in false positive score", + UserWarning, + ) + fpr = torch.zeros_like(thresholds) + else: + fpr = fps / fps[-1] + + if tps[-1] <= 0: + rank_zero_warn( + "No positive samples in targets, true positive value should be meaningless." + " Returning zero tensor in true positive score", + UserWarning, + ) + tpr = torch.zeros_like(thresholds) + else: + tpr = tps / tps[-1] + + return fpr, tpr, thresholds + + +def _roc_compute_multi_class( + preds: Tensor, + target: Tensor, + num_classes: int, + sample_weights: Optional[Sequence] = None, +) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]: + """Computes Receiver Operating Characteristic for multi class inputs. Returns tensor with false positive rates, + tensor with true positive rates, tensor with thresholds used for computing false- and true-postive rates. + + """ + + fpr, tpr, thresholds = [], [], [] + for cls in range(num_classes): + if preds.shape == target.shape: + target_cls = target[:, cls] + pos_label = 1 + else: + target_cls = target + pos_label = cls + res = roc( + preds=preds[:, cls], + target=target_cls, + num_classes=1, + pos_label=pos_label, + sample_weights=sample_weights, + ) + fpr.append(res[0]) + tpr.append(res[1]) + thresholds.append(res[2]) + + return fpr, tpr, thresholds + + +def _roc_compute( + preds: Tensor, + target: Tensor, + num_classes: int, + pos_label: Optional[int] = None, + sample_weights: Optional[Sequence] = None, +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + """Computes Receiver Operating Characteristic based on the number of classes. + + Example: + >>> # binary case + >>> preds = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 1, 1]) + >>> pos_label = 1 + >>> preds, target, num_classes, pos_label = _roc_update(preds, target, pos_label=pos_label) + >>> fpr, tpr, thresholds = _roc_compute(preds, target, num_classes, pos_label) + >>> fpr + tensor([0., 0., 0., 0., 1.]) + >>> tpr + tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]) + >>> thresholds + tensor([4, 3, 2, 1, 0]) + + >>> # multiclass case + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05], + ... [0.05, 0.05, 0.05, 0.75]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> num_classes = 4 + >>> preds, target, num_classes, pos_label = _roc_update(preds, target, num_classes) + >>> fpr, tpr, thresholds = _roc_compute(preds, target, num_classes) + >>> fpr + [tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])] + >>> tpr + [tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])] + >>> thresholds + [tensor([1.7500, 0.7500, 0.0500]), + tensor([1.7500, 0.7500, 0.0500]), + tensor([1.7500, 0.7500, 0.0500]), + tensor([1.7500, 0.7500, 0.0500])] + """ + + with torch.no_grad(): + if num_classes == 1 and preds.ndim == 1: # binary + if pos_label is None: + pos_label = 1 + return _roc_compute_single_class(preds, target, pos_label, sample_weights) + return _roc_compute_multi_class(preds, target, num_classes, sample_weights) + + def roc( preds: Tensor, target: Tensor, diff --git a/src/torchmetrics/functional/classification/specificity.py b/src/torchmetrics/functional/classification/specificity.py index b9be98e100e..ba2e0e226f5 100644 --- a/src/torchmetrics/functional/classification/specificity.py +++ b/src/torchmetrics/functional/classification/specificity.py @@ -30,8 +30,11 @@ _multilabel_stat_scores_format, _multilabel_stat_scores_tensor_validation, _multilabel_stat_scores_update, + _reduce_stat_scores, + _stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod def _specificity_reduce( @@ -354,6 +357,44 @@ def multilabel_specificity( return _specificity_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average) +def _specificity_compute( + tp: Tensor, + fp: Tensor, + tn: Tensor, + fn: Tensor, + average: Optional[str], + mdmc_average: Optional[str], +) -> Tensor: + """Computes specificity from the stat scores: true positives, false positives, true negatives, false negatives. + + Example: + >>> from torchmetrics.functional.classification.stat_scores import _stat_scores_update + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='macro', num_classes=3) + >>> _specificity_compute(tp, fp, tn, fn, average='macro', mdmc_average=None) + tensor(0.6111) + >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='micro') + >>> _specificity_compute(tp, fp, tn, fn, average='micro', mdmc_average=None) + tensor(0.6250) + """ + + numerator = tn.clone() + denominator = tn + fp + if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE: + # a class is not present if there exists no TPs, no FPs, and no FNs + meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu() + numerator[meaningless_indeces, ...] = -1 + denominator[meaningless_indeces, ...] = -1 + return _reduce_stat_scores( + numerator=numerator, + denominator=denominator, + weights=None if average != AverageMethod.WEIGHTED else denominator, + average=average, + mdmc_average=mdmc_average, + ) + + def specificity( preds: Tensor, target: Tensor, diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 6571aaccbac..82adee32cce 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -1001,6 +1001,18 @@ def _stat_scores_compute(tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> Tens fp: False positives tn: True negatives fn: False negatives + + Example: + >>> preds = torch.tensor([1, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='macro', num_classes=3) + >>> _stat_scores_compute(tp, fp, tn, fn) + tensor([[0, 1, 2, 1, 1], + [1, 1, 1, 1, 2], + [1, 0, 3, 0, 1]]) + >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='micro') + >>> _stat_scores_compute(tp, fp, tn, fn) + tensor([2, 2, 6, 2, 4]) """ stats = [ tp.unsqueeze(-1), From 1f71c02cdfa4d69a09ee814c4cdd51370fa216bd Mon Sep 17 00:00:00 2001 From: Jirka Date: Sat, 24 Jun 2023 00:36:26 +0200 Subject: [PATCH 02/11] precommit update --- .pre-commit-config.yaml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 80dc07db4f1..d6b47590fb6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.4.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace @@ -38,27 +38,27 @@ repos: - id: detect-private-key - repo: https://github.com/asottile/pyupgrade - rev: v2.38.2 + rev: v3.7.0 hooks: - id: pyupgrade args: [--py36-plus] name: Upgrade code - repo: https://github.com/PyCQA/docformatter - rev: v1.5.0 + rev: v1.7.3 hooks: - id: docformatter args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120] - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort name: imports require_serial: false - repo: https://github.com/psf/black - rev: 22.8.0 + rev: 23.3.0 hooks: - id: black name: Format code @@ -78,12 +78,12 @@ repos: )$ - repo: https://github.com/asottile/yesqa - rev: v1.4.0 + rev: v1.5.0 hooks: - id: yesqa - repo: https://github.com/PyCQA/flake8 - rev: 5.0.4 + rev: 6.0.0 hooks: - id: flake8 name: PEP8 From b5fc411b7635848e9b4724764af17fc407b96ff2 Mon Sep 17 00:00:00 2001 From: Jirka Date: Sat, 24 Jun 2023 00:39:40 +0200 Subject: [PATCH 03/11] fix pep8 config --- setup.cfg | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 22deeda0747..abae64d8402 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,8 +37,10 @@ verbose = 2 format = pylint # see: https://www.flake8rules.com/ ignore = - E731 # Do not assign a lambda expression, use a def - E203 # whitespace before ':' + # Do not assign a lambda expression, use a def + E731 + # whitespace before ':' + E203 # setup.cfg or tox.ini From 34f904637424086c565759b98dcfbf75fc7389d2 Mon Sep 17 00:00:00 2001 From: Jirka Date: Sat, 24 Jun 2023 00:46:16 +0200 Subject: [PATCH 04/11] fix import --- src/torchmetrics/classification/dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py index 6067d7309a1..714fa9de06b 100644 --- a/src/torchmetrics/classification/dice.py +++ b/src/torchmetrics/classification/dice.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, Optional, no_type_check from torch import Tensor from typing_extensions import Literal From 26deece07220cd3d68622a127d2ec74fd5186415 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 23 Jun 2023 22:51:42 +0000 Subject: [PATCH 05/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- README.md | 1 - src/torchmetrics/classification/accuracy.py | 10 +- src/torchmetrics/classification/auroc.py | 3 +- .../classification/average_precision.py | 6 +- .../classification/calibration_error.py | 6 +- .../classification/cohen_kappa.py | 3 +- .../classification/confusion_matrix.py | 6 +- src/torchmetrics/classification/hamming.py | 3 +- src/torchmetrics/classification/jaccard.py | 14 +-- .../classification/matthews_corrcoef.py | 3 +- .../classification/precision_recall.py | 7 +- .../classification/precision_recall_curve.py | 6 +- src/torchmetrics/classification/roc.py | 6 +- .../classification/specificity.py | 3 +- .../classification/stat_scores.py | 7 +- src/torchmetrics/detection/mean_ap.py | 5 - .../functional/classification/accuracy.py | 9 +- .../functional/classification/auroc.py | 2 +- .../classification/average_precision.py | 4 +- .../classification/calibration_error.py | 4 +- .../classification/confusion_matrix.py | 4 +- .../functional/classification/hamming.py | 3 +- .../functional/classification/hinge.py | 7 +- .../functional/classification/jaccard.py | 11 +- .../classification/precision_recall_curve.py | 12 +- .../functional/classification/roc.py | 15 ++- .../functional/regression/spearman.py | 6 +- src/torchmetrics/functional/text/bleu.py | 2 +- src/torchmetrics/functional/text/cer.py | 2 +- src/torchmetrics/functional/text/chrf.py | 2 +- src/torchmetrics/functional/text/eed.py | 1 - src/torchmetrics/functional/text/helper.py | 3 +- .../functional/text/sacre_bleu.py | 4 +- src/torchmetrics/functional/text/ter.py | 2 +- src/torchmetrics/image/fid.py | 2 +- src/torchmetrics/image/lpip.py | 4 +- src/torchmetrics/metric.py | 4 +- src/torchmetrics/multimodal/clip_score.py | 1 - src/torchmetrics/utilities/compute.py | 1 - tests/unittests/audio/test_pit.py | 2 +- tests/unittests/bases/test_aggregation.py | 26 ++--- tests/unittests/bases/test_composition.py | 2 +- tests/unittests/bases/test_ddp.py | 2 - tests/unittests/bases/test_metric.py | 10 +- tests/unittests/classification/test_auroc.py | 2 +- .../classification/test_average_precision.py | 2 +- .../test_precision_recall_curve.py | 2 +- .../test_recall_at_fixed_precision.py | 2 +- tests/unittests/classification/test_roc.py | 2 +- .../classification/test_specificity.py | 2 +- tests/unittests/detection/test_map.py | 104 +++++++++--------- tests/unittests/image/test_fid.py | 4 +- tests/unittests/image/test_inception.py | 2 +- tests/unittests/image/test_kid.py | 4 +- tests/unittests/image/test_lpips.py | 12 +- tests/unittests/image/test_tv.py | 2 +- .../pairwise/test_pairwise_distance.py | 12 +- tests/unittests/regression/test_spearman.py | 2 +- tests/unittests/text/test_cer.py | 8 +- tests/unittests/text/test_infolm.py | 3 +- tests/unittests/text/test_mer.py | 3 - tests/unittests/text/test_wer.py | 3 - tests/unittests/text/test_wil.py | 3 - tests/unittests/text/test_wip.py | 3 - tests/unittests/utilities/test_utilities.py | 2 +- .../unittests/wrappers/test_bootstrapping.py | 3 +- tests/unittests/wrappers/test_minmax.py | 12 +- 67 files changed, 175 insertions(+), 255 deletions(-) diff --git a/README.md b/README.md index c45228afe01..5165b75ca23 100644 --- a/README.md +++ b/README.md @@ -184,7 +184,6 @@ def metric_ddp(rank, world_size): n_epochs = 5 # this shows iteration over multiple training epochs for n in range(n_epochs): - # this will be replaced by a DataLoader with a DistributedSampler n_batches = 10 for i in range(n_batches): diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index 65c0e380f1f..c69ac902475 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -27,14 +27,15 @@ _subset_accuracy_update, ) from torchmetrics.metric import Metric +from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.enums import AverageMethod from torchmetrics.classification.stat_scores import ( # isort:skip BinaryStatScores, MulticlassStatScores, - MultilabelStatScores, StatScores, + MultilabelStatScores, + StatScores, ) -from torchmetrics.utilities import rank_zero_warn -from torchmetrics.utilities.enums import AverageMethod class BinaryAccuracy(BinaryStatScores): @@ -459,8 +460,7 @@ def __init__( self.add_state("total", default=tensor(0), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. - """ + """Update state with predictions and targets.""" mode = _mode(preds, target, self.threshold, self.top_k, self.num_classes, self.multiclass, self.ignore_index) if not self.mode: diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index 910802e78e4..e3a321d3a70 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -439,8 +439,7 @@ def __init__( ) def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. - """ + """Update state with predictions and targets.""" preds, target, mode = _auroc_update(preds, target) self.preds.append(preds) diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index 80eaa84c7ea..be89f6c0404 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -432,8 +432,7 @@ def __init__( ) def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. - """ + """Update state with predictions and targets.""" preds, target, num_classes, pos_label = _average_precision_update( preds, target, self.num_classes, self.pos_label, self.average ) @@ -443,8 +442,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.pos_label = pos_label def compute(self) -> Union[Tensor, List[Tensor]]: - """Compute the average precision score. - """ + """Compute the average precision score.""" preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) if not self.num_classes: diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index 2b0021c5e30..2e7f8bc6cc0 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -316,16 +316,14 @@ def __init__( def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Computes top-level confidences and accuracies for the input probabilities and appends them to internal - state. - """ + state.""" confidences, accuracies = _ce_update(preds, target) self.confidences.append(confidences) self.accuracies.append(accuracies) def compute(self) -> Tensor: - """Computes calibration error across all confidences and accuracies. - """ + """Computes calibration error across all confidences and accuracies.""" confidences = dim_zero_cat(self.confidences) accuracies = dim_zero_cat(self.accuracies) return _ce_compute(confidences, accuracies, self.bin_boundaries.to(self.device), norm=self.norm) diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index 511d7176055..b10a0f01480 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -267,8 +267,7 @@ def __init__( self.add_state("confmat", default=torch.zeros(num_classes, num_classes, dtype=torch.long), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. - """ + """Update state with predictions and targets.""" confmat = _cohen_kappa_update(preds, target, self.num_classes, self.threshold) self.confmat += confmat diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index 598db010bc6..a5a0ac8bb52 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -426,12 +426,10 @@ def __init__( self.add_state("confmat", default=default, dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. - """ + """Update state with predictions and targets.""" confmat = _confusion_matrix_update(preds, target, self.num_classes, self.threshold, self.multilabel) self.confmat += confmat def compute(self) -> Tensor: - """Computes confusion matrix. - """ + """Computes confusion matrix.""" return _confusion_matrix_compute(self.confmat, self.normalize) diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index 163ef938f8c..e836971e7ca 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -408,8 +408,7 @@ def __init__( self.threshold = threshold def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. - """ + """Update state with predictions and targets.""" correct, total = _hamming_distance_update(preds, target, self.threshold) self.correct += correct diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index 1b37c92397e..790445d9de7 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -353,21 +353,11 @@ def compute(self) -> Tensor: if self.multilabel: return torch.stack( [ - _jaccard_from_confmat( - confmat, - 2, - self.average, - self.ignore_index, - self.absent_score - )[1] + _jaccard_from_confmat(confmat, 2, self.average, self.ignore_index, self.absent_score)[1] for confmat in self.confmat ] ) else: return _jaccard_from_confmat( - self.confmat, - self.num_classes, - self.average, - self.ignore_index, - self.absent_score + self.confmat, self.num_classes, self.average, self.ignore_index, self.absent_score ) diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index 5e47b159fae..8bc88b819c0 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -297,8 +297,7 @@ def __init__( self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. - """ + """Update state with predictions and targets.""" confmat = _matthews_corrcoef_update(preds, target, self.num_classes, self.threshold) self.confmat += confmat diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index e64255e06e6..2bca62deb83 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -716,8 +716,7 @@ def __init__( self.average = average def compute(self) -> Tensor: - """Computes the precision score based on inputs passed in to ``update`` previously. - """ + """Computes the precision score based on inputs passed in to ``update`` previously.""" tp, fp, _, fn = self._get_final_stats() return _precision_compute(tp, fp, fn, self.average, self.mdmc_reduce) @@ -835,8 +834,6 @@ def __init__( self.average = average def compute(self) -> Tensor: - """Computes the recall score based on inputs passed in to ``update`` previously. - - """ + """Computes the recall score based on inputs passed in to ``update`` previously.""" tp, fp, _, fn = self._get_final_stats() return _recall_compute(tp, fp, fn, self.average, self.mdmc_reduce) diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index f22b47cd068..f0192e03dd7 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -528,8 +528,7 @@ def __init__( ) def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. - """ + """Update state with predictions and targets.""" preds, target, num_classes, pos_label = _precision_recall_curve_update( preds, target, self.num_classes, self.pos_label ) @@ -539,8 +538,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.pos_label = pos_label def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: - """Compute the precision-recall curve. - """ + """Compute the precision-recall curve.""" preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) if not self.num_classes: diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index 6b138166823..87d9b9cee8c 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -437,8 +437,7 @@ def __init__( ) def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. - """ + """Update state with predictions and targets.""" preds, target, num_classes, pos_label = _roc_update(preds, target, self.num_classes, self.pos_label) self.preds.append(preds) self.target.append(target) @@ -446,8 +445,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.pos_label = pos_label def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: - """Compute the receiver operating characteristic. - """ + """Compute the receiver operating characteristic.""" preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) if not self.num_classes: diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index 927fb25564f..00d0d0cddb7 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -406,7 +406,6 @@ def __init__( self.average = average def compute(self) -> Tensor: - """Computes the specificity score based on inputs passed in to ``update`` previously. - """ + """Computes the specificity score based on inputs passed in to ``update`` previously.""" tp, fp, tn, fn = self._get_final_stats() return _specificity_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce) diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index b01d4e08451..a0b75313c08 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -627,8 +627,7 @@ def __init__( self.add_state(s, default=default(), dist_reduce_fx=reduce_fn) def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. - """ + """Update state with predictions and targets.""" tp, fp, tn, fn = _stat_scores_update( preds, target, @@ -662,8 +661,6 @@ def _get_final_stats(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: return tp, fp, tn, fn def compute(self) -> Tensor: - """Computes the stat scores based on inputs passed in to ``update`` previously. - - """ + """Computes the stat scores based on inputs passed in to ``update`` previously.""" tp, fp, tn, fn = self._get_final_stats() return _stat_scores_compute(tp, fp, tn, fn) diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 69cc0edf574..012931358cb 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -43,13 +43,11 @@ def compute_area(input: List[Any], iou_type: str = "bbox") -> Tensor: Default output for empty input is torch.Tensor([]) """ if len(input) == 0: - return torch.Tensor([]) if iou_type == "bbox": return box_area(torch.stack(input)) elif iou_type == "segm": - input = [{"size": i[0], "counts": i[1]} for i in input] area = torch.tensor(mask_utils.area(input).astype("float")) @@ -389,7 +387,6 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]] _input_validator(preds, target, iou_type=self.iou_type) for item in preds: - detections = self._get_safe_item_values(item) self.detections.append(detections) @@ -416,7 +413,6 @@ def _move_list_states_to_cpu(self) -> None: setattr(self, key, current_to_cpu) def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]: - if self.iou_type == "bbox": boxes = _fix_empty_tensors(item["boxes"]) if boxes.numel() > 0: @@ -860,7 +856,6 @@ def __calculate_recall_precision_scores( diff_zero = torch.zeros((1,), device=pr.device) diff = torch.ones((1,), device=pr.device) while not torch.all(diff == 0): - diff = torch.clamp(torch.cat(((pr[1:] - pr[:-1]), diff_zero), 0), min=0) pr += diff diff --git a/src/torchmetrics/functional/classification/accuracy.py b/src/torchmetrics/functional/classification/accuracy.py index 18a695d1788..7d9dc92e7cf 100644 --- a/src/torchmetrics/functional/classification/accuracy.py +++ b/src/torchmetrics/functional/classification/accuracy.py @@ -433,8 +433,7 @@ def _accuracy_update( mode: DataType, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Updates and returns stat scores (true positives, false positives, true negatives, false negatives) required - to compute accuracy. - """ + to compute accuracy.""" if mode == DataType.MULTILABEL and top_k: raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.") @@ -527,8 +526,7 @@ def _subset_accuracy_update( top_k: Optional[int], ignore_index: Optional[int] = None, ) -> Tuple[Tensor, Tensor]: - """Updates and returns variables required to compute subset accuracy. - """ + """Updates and returns variables required to compute subset accuracy.""" preds, target = _input_squeeze(preds, target) preds, target, mode = _input_format_classification( @@ -555,8 +553,7 @@ def _subset_accuracy_update( def _subset_accuracy_compute(correct: Tensor, total: Tensor) -> Tensor: - """Computes subset accuracy from number of correct observations and total number of observations. - """ + """Computes subset accuracy from number of correct observations and total number of observations.""" return correct.float() / total diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py index 53cf335e404..ea50adf8192 100644 --- a/src/torchmetrics/functional/classification/auroc.py +++ b/src/torchmetrics/functional/classification/auroc.py @@ -420,8 +420,8 @@ def multilabel_auroc( def _auroc_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, DataType]: """Updates and returns variables required to compute Area Under the Receiver Operating Characteristic Curve. - Validates the inputs and returns the mode of the inputs. + Validates the inputs and returns the mode of the inputs. """ # use _input_format_classification for validating the input and get the mode of data diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index 62d23824477..5d7529d6c75 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -408,9 +408,7 @@ def _average_precision_update( pos_label: Optional[int] = None, average: Optional[str] = "macro", ) -> Tuple[Tensor, Tensor, int, Optional[int]]: - """Format the predictions and target based on the ``num_classes``, ``pos_label`` and ``average`` parameter. - - """ + """Format the predictions and target based on the ``num_classes``, ``pos_label`` and ``average`` parameter.""" preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label) if average == "micro" and preds.ndim != target.ndim: raise ValueError("Cannot use `micro` average with multi-class input") diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py index 0a0b901aaa9..511d4817ae0 100644 --- a/src/torchmetrics/functional/classification/calibration_error.py +++ b/src/torchmetrics/functional/classification/calibration_error.py @@ -317,9 +317,7 @@ def multiclass_calibration_error( def _ce_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: """Given a predictions and targets tensor, computes the confidences of the top-1 prediction and records their - correctness. - - """ + correctness.""" _, _, mode = _input_format_classification(preds, target) if mode == DataType.BINARY: diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index ecb50a036e5..bb928021819 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -595,9 +595,7 @@ def multilabel_confusion_matrix( def _confusion_matrix_update( preds: Tensor, target: Tensor, num_classes: int, threshold: float = 0.5, multilabel: bool = False ) -> Tensor: - """Updates and returns confusion matrix (without any normalization) based on the mode of the input. - - """ + """Updates and returns confusion matrix (without any normalization) based on the mode of the input.""" preds, target, mode = _input_format_classification(preds, target, threshold) if mode not in (DataType.BINARY, DataType.MULTILABEL): diff --git a/src/torchmetrics/functional/classification/hamming.py b/src/torchmetrics/functional/classification/hamming.py index f1539325d1f..5f4dca56fb1 100644 --- a/src/torchmetrics/functional/classification/hamming.py +++ b/src/torchmetrics/functional/classification/hamming.py @@ -392,8 +392,7 @@ def _hamming_distance_update( target: Tensor, threshold: float = 0.5, ) -> Tuple[Tensor, int]: - """Returns the number of positions where prediction equals target, and number of predictions. - """ + """Returns the number of positions where prediction equals target, and number of predictions.""" preds, target, _ = _input_format_classification(preds, target, threshold=threshold) diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py index 9a2f7a6f4d0..c1c4b162889 100644 --- a/src/torchmetrics/functional/classification/hinge.py +++ b/src/torchmetrics/functional/classification/hinge.py @@ -53,7 +53,6 @@ def _binary_hinge_loss_update( target: Tensor, squared: bool, ) -> Tuple[Tensor, Tensor]: - target = target.bool() margin = torch.zeros_like(preds) margin[target] = preds[target] @@ -257,8 +256,7 @@ def _check_shape_and_type_consistency_hinge( preds: Tensor, target: Tensor, ) -> DataType: - """Checks shape and type of ``preds`` and ``target`` and returns mode of the input tensors. - """ + """Checks shape and type of ``preds`` and ``target`` and returns mode of the input tensors.""" if target.ndim > 1: raise ValueError( @@ -290,8 +288,7 @@ def _hinge_update( squared: bool = False, multiclass_mode: Optional[Union[str, MulticlassMode]] = None, ) -> Tuple[Tensor, Tensor]: - """Updates and returns sum over Hinge loss scores for each observation and the total number of observations. - """ + """Updates and returns sum over Hinge loss scores for each observation and the total number of observations.""" preds, target = _input_squeeze(preds, target) mode = _check_shape_and_type_consistency_hinge(preds, target) diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index 6c08eab0c83..d4395912692 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -303,9 +303,7 @@ def _jaccard_from_confmat( ignore_index: Optional[int] = None, absent_score: float = 0.0, ) -> Tensor: - """Computes the intersection over union from confusion matrix. - - """ + """Computes the intersection over union from confusion matrix.""" allowed_average = ["micro", "macro", "weighted", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") @@ -323,12 +321,7 @@ def _jaccard_from_confmat( scores = scores.where(union != 0, torch.tensor(absent_score, dtype=scores.dtype, device=scores.device)) if ignore_index is not None and 0 <= ignore_index < num_classes: - scores = torch.cat( - [ - scores[:ignore_index], - scores[ignore_index + 1 :] - ] - ) + scores = torch.cat([scores[:ignore_index], scores[ignore_index + 1 :]]) return scores if average == "macro": diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index 068f03b8fdb..fa3187cf108 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -775,9 +775,7 @@ def _precision_recall_curve_update( num_classes: Optional[int] = None, pos_label: Optional[int] = None, ) -> Tuple[Tensor, Tensor, int, Optional[int]]: - """Updates and returns variables required to compute the precision-recall pairs for different thresholds. - - """ + """Updates and returns variables required to compute the precision-recall pairs for different thresholds.""" if len(preds.shape) == len(target.shape): if pos_label is None: @@ -826,9 +824,7 @@ def _precision_recall_curve_compute_single_class( pos_label: int, sample_weights: Optional[Sequence] = None, ) -> Tuple[Tensor, Tensor, Tensor]: - """Computes precision-recall pairs for single class inputs. - - """ + """Computes precision-recall pairs for single class inputs.""" fps, tps, thresholds = _binary_clf_curve( preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label @@ -857,9 +853,7 @@ def _precision_recall_curve_compute_multi_class( num_classes: int, sample_weights: Optional[Sequence] = None, ) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]: - """Computes precision-recall pairs for multiclass inputs. - - """ + """Computes precision-recall pairs for multiclass inputs.""" # Recursively call per class precision, recall, thresholds = [], [], [] diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py index 0df65887812..37e3673366d 100644 --- a/src/torchmetrics/functional/classification/roc.py +++ b/src/torchmetrics/functional/classification/roc.py @@ -427,9 +427,7 @@ def _roc_update( num_classes: Optional[int] = None, pos_label: Optional[int] = None, ) -> Tuple[Tensor, Tensor, int, Optional[int]]: - """Updates and returns variables required to compute the Receiver Operating Characteristic. - - """ + """Updates and returns variables required to compute the Receiver Operating Characteristic.""" return _precision_recall_curve_update(preds, target, num_classes, pos_label) @@ -440,10 +438,10 @@ def _roc_compute_single_class( pos_label: int, sample_weights: Optional[Sequence] = None, ) -> Tuple[Tensor, Tensor, Tensor]: - """Computes Receiver Operating Characteristic for single class inputs. Returns tensor with false positive - rates, tensor with true positive rates, tensor with thresholds used for computing false- and true-postive - rates. + """Computes Receiver Operating Characteristic for single class inputs. + Returns tensor with false positive rates, tensor with true positive rates, tensor with thresholds used for computing + false- and true-postive rates. """ fps, tps, thresholds = _binary_clf_curve( @@ -483,9 +481,10 @@ def _roc_compute_multi_class( num_classes: int, sample_weights: Optional[Sequence] = None, ) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]: - """Computes Receiver Operating Characteristic for multi class inputs. Returns tensor with false positive rates, - tensor with true positive rates, tensor with thresholds used for computing false- and true-postive rates. + """Computes Receiver Operating Characteristic for multi class inputs. + Returns tensor with false positive rates, tensor with true positive rates, tensor with thresholds used for computing + false- and true-postive rates. """ fpr, tpr, thresholds = [], [], [] diff --git a/src/torchmetrics/functional/regression/spearman.py b/src/torchmetrics/functional/regression/spearman.py index d604893d4d9..29ec4f4ec4e 100644 --- a/src/torchmetrics/functional/regression/spearman.py +++ b/src/torchmetrics/functional/regression/spearman.py @@ -21,7 +21,7 @@ def _find_repeats(data: Tensor) -> Tensor: - """find and return values which have repeats i.e. the same value are more than once in the tensor.""" + """Find and return values which have repeats i.e. the same value are more than once in the tensor.""" temp = data.detach().clone() temp = temp.sort()[0] @@ -36,8 +36,8 @@ def _find_repeats(data: Tensor) -> Tensor: def _rank_data(data: Tensor) -> Tensor: """Calculate the rank for each element of a tensor. - The rank refers to the indices of an element in the corresponding sorted tensor (starting from 1). - Duplicates of the same value will be assigned the mean of their rank. + The rank refers to the indices of an element in the corresponding sorted tensor (starting from 1). Duplicates of the + same value will be assigned the mean of their rank. Adopted from `Rank of element tensor`_ """ diff --git a/src/torchmetrics/functional/text/bleu.py b/src/torchmetrics/functional/text/bleu.py index 93423c091bf..5d302435c7c 100644 --- a/src/torchmetrics/functional/text/bleu.py +++ b/src/torchmetrics/functional/text/bleu.py @@ -81,7 +81,7 @@ def _bleu_score_update( target_: Sequence[Sequence[Sequence[str]]] = [[tokenizer(line) if line else [] for line in t] for t in target] preds_: Sequence[Sequence[str]] = [tokenizer(line) if line else [] for line in preds] - for (pred, targets) in zip(preds_, target_): + for pred, targets in zip(preds_, target_): preds_len += len(pred) target_len_list = [len(tgt) for tgt in targets] target_len_diff = [abs(len(pred) - x) for x in target_len_list] diff --git a/src/torchmetrics/functional/text/cer.py b/src/torchmetrics/functional/text/cer.py index 7e66208c30a..680e9c6c9ca 100644 --- a/src/torchmetrics/functional/text/cer.py +++ b/src/torchmetrics/functional/text/cer.py @@ -62,7 +62,7 @@ def _cer_compute(errors: Tensor, total: Tensor) -> Tensor: def char_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: - """character error rate is a common metric of the performance of an automatic speech recognition system. This + """Character error rate is a common metric of the performance of an automatic speech recognition system. This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better the performance of the ASR system with a CER of 0 being a perfect score. diff --git a/src/torchmetrics/functional/text/chrf.py b/src/torchmetrics/functional/text/chrf.py index 49068cdccb2..805d378bbd1 100644 --- a/src/torchmetrics/functional/text/chrf.py +++ b/src/torchmetrics/functional/text/chrf.py @@ -432,7 +432,7 @@ def _chrf_score_update( """ target_corpus, preds = _validate_inputs(target, preds) - for (pred, targets) in zip(preds, target_corpus): + for pred, targets in zip(preds, target_corpus): ( pred_char_n_grams_counts, pred_word_n_grams_counts, diff --git a/src/torchmetrics/functional/text/eed.py b/src/torchmetrics/functional/text/eed.py index d01a5c57a19..2b3d1f29bb0 100644 --- a/src/torchmetrics/functional/text/eed.py +++ b/src/torchmetrics/functional/text/eed.py @@ -144,7 +144,6 @@ def _eed_function( for w in range(1, len(ref) + 1): for i in range(0, len(hyp) + 1): - if i > 0: next_row[i] = min( next_row[i - 1] + deletion, diff --git a/src/torchmetrics/functional/text/helper.py b/src/torchmetrics/functional/text/helper.py index dbe6620fc83..26f66a7e054 100644 --- a/src/torchmetrics/functional/text/helper.py +++ b/src/torchmetrics/functional/text/helper.py @@ -66,7 +66,8 @@ class _LevenshteinEditDistance: values to hasten the calculation. The implementation follows the implemenation from - https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/lib_ter.py, where the most of this implementation + https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/lib_ter.py, + where the most of this implementation is adapted and copied from. """ diff --git a/src/torchmetrics/functional/text/sacre_bleu.py b/src/torchmetrics/functional/text/sacre_bleu.py index b9c166cd8da..1875494b6ea 100644 --- a/src/torchmetrics/functional/text/sacre_bleu.py +++ b/src/torchmetrics/functional/text/sacre_bleu.py @@ -143,7 +143,7 @@ def _tokenize_regex(cls, line: str) -> str: Return: the tokenized line """ - for (_re, repl) in cls._REGEX: + for _re, repl in cls._REGEX: line = _re.sub(repl, line) # no leading or trailing spaces, single space within words return " ".join(line.split()) @@ -252,7 +252,7 @@ def _tokenize_international(cls, line: str) -> str: Return: The tokenized string. """ - for (_re, repl) in cls._INT_REGEX: + for _re, repl in cls._INT_REGEX: line = _re.sub(repl, line) return " ".join(line.split()) diff --git a/src/torchmetrics/functional/text/ter.py b/src/torchmetrics/functional/text/ter.py index 2d99924073e..f0d2fa70d88 100644 --- a/src/torchmetrics/functional/text/ter.py +++ b/src/torchmetrics/functional/text/ter.py @@ -497,7 +497,7 @@ def _ter_update( """ target, preds = _validate_inputs(target, preds) - for (pred, tgt) in zip(preds, target): + for pred, tgt in zip(preds, target): tgt_words_: List[List[str]] = [_preprocess_sentence(_tgt, tokenizer).split() for _tgt in tgt] pred_words_: List[str] = _preprocess_sentence(pred, tokenizer).split() num_edits, tgt_length = _compute_sentence_statistics(pred_words_, tgt_words_) diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py index afd56e46600..4ee51aa4738 100644 --- a/src/torchmetrics/image/fid.py +++ b/src/torchmetrics/image/fid.py @@ -50,7 +50,7 @@ def __init__( self.eval() def train(self, mode: bool) -> "NoTrainInceptionV3": - """the inception network should not be able to be switched away from evaluation mode.""" + """The inception network should not be able to be switched away from evaluation mode.""" return super().train(False) def forward(self, x: Tensor) -> Tensor: diff --git a/src/torchmetrics/image/lpip.py b/src/torchmetrics/image/lpip.py index 84ff4868c5a..b82666f3311 100644 --- a/src/torchmetrics/image/lpip.py +++ b/src/torchmetrics/image/lpip.py @@ -33,12 +33,12 @@ class _LPIPS(Module): # type: ignore class NoTrainLpips(_LPIPS): def train(self, mode: bool) -> "NoTrainLpips": - """the network should not be able to be switched away from evaluation mode.""" + """The network should not be able to be switched away from evaluation mode.""" return super().train(False) def _valid_img(img: Tensor, normalize: bool) -> bool: - """check that input is a valid image to the network.""" + """Check that input is a valid image to the network.""" value_check = img.max() <= 1.0 and img.min() >= 0.0 if normalize else img.min() >= -1 return img.ndim == 4 and img.shape[1] == 3 and value_check diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index e11b19c50ba..1d3f21d512c 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -699,7 +699,7 @@ def _load_from_state_dict( ) def _filter_kwargs(self, **kwargs: Any) -> Dict[str, Any]: - """filter kwargs such that they match the update signature of the metric.""" + """Filter kwargs such that they match the update signature of the metric.""" # filter all parameters based on update signature except those of # type VAR_POSITIONAL (*args) and VAR_KEYWORD (**kwargs) @@ -899,7 +899,6 @@ def update(self, *args: Any, **kwargs: Any) -> None: self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs)) def compute(self) -> Any: - # also some parsing for kwargs? if isinstance(self.metric_a, Metric): val_a = self.metric_a.compute() @@ -918,7 +917,6 @@ def compute(self) -> Any: @torch.jit.unused def forward(self, *args: Any, **kwargs: Any) -> Any: - val_a = ( self.metric_a(*args, **self.metric_a._filter_kwargs(**kwargs)) if isinstance(self.metric_a, Metric) diff --git a/src/torchmetrics/multimodal/clip_score.py b/src/torchmetrics/multimodal/clip_score.py index 234874b931b..14104f7c718 100644 --- a/src/torchmetrics/multimodal/clip_score.py +++ b/src/torchmetrics/multimodal/clip_score.py @@ -77,7 +77,6 @@ def __init__( ] = "openai/clip-vit-large-patch14", **kwargs: Any, ) -> None: - super().__init__(**kwargs) self.model, self.processor = _get_model_and_processor(model_name_or_path) self.add_state("score", torch.tensor(0.0), dist_reduce_fx="sum") diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py index cc9dd06af7b..8dd218399ae 100644 --- a/src/torchmetrics/utilities/compute.py +++ b/src/torchmetrics/utilities/compute.py @@ -37,7 +37,6 @@ def _safe_xlogy(x: Tensor, y: Tensor) -> Tensor: >>> x = torch.zeros(1) >>> _safe_xlogy(x, 1/x) tensor([0.]) - """ res = x * torch.log(y) res[x == 0] = 0.0 diff --git a/tests/unittests/audio/test_pit.py b/tests/unittests/audio/test_pit.py index 48446dd7670..34513ed3bb0 100644 --- a/tests/unittests/audio/test_pit.py +++ b/tests/unittests/audio/test_pit.py @@ -86,7 +86,7 @@ def naive_implementation_pit_scipy( def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tensor: - """average the metric values. + """Average the metric values. Args: preds: predictions, shape[batch, spk, time] diff --git a/tests/unittests/bases/test_aggregation.py b/tests/unittests/bases/test_aggregation.py index 0a5a45260ac..8fb8d460d60 100644 --- a/tests/unittests/bases/test_aggregation.py +++ b/tests/unittests/bases/test_aggregation.py @@ -7,22 +7,22 @@ def compare_mean(values, weights): - """reference implementation for mean aggregation.""" + """Reference implementation for mean aggregation.""" return np.average(values.numpy(), weights=weights) def compare_sum(values, weights): - """reference implementation for sum aggregation.""" + """Reference implementation for sum aggregation.""" return np.sum(values.numpy()) def compare_min(values, weights): - """reference implementation for min aggregation.""" + """Reference implementation for min aggregation.""" return np.min(values.numpy()) def compare_max(values, weights): - """reference implementation for max aggregation.""" + """Reference implementation for max aggregation.""" return np.max(values.numpy()) @@ -32,7 +32,7 @@ class WrappedMinMetric(MinMetric): """Wrapped min metric.""" def update(self, values, weights): - """only pass values on.""" + """Only pass values on.""" super().update(values) @@ -40,7 +40,7 @@ class WrappedMaxMetric(MaxMetric): """Wrapped max metric.""" def update(self, values, weights): - """only pass values on.""" + """Only pass values on.""" super().update(values) @@ -48,7 +48,7 @@ class WrappedSumMetric(SumMetric): """Wrapped min metric.""" def update(self, values, weights): - """only pass values on.""" + """Only pass values on.""" super().update(values) @@ -56,7 +56,7 @@ class WrappedCatMetric(CatMetric): """Wrapped cat metric.""" def update(self, values, weights): - """only pass values on.""" + """Only pass values on.""" super().update(values) @@ -83,7 +83,7 @@ class TestAggregation(MetricTester): @pytest.mark.parametrize("ddp", [False, True]) @pytest.mark.parametrize("dist_sync_on_step", [False]) def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, values, weights): - """test modular implementation.""" + """Test modular implementation.""" self.run_class_metric_test( ddp=ddp, dist_sync_on_step=dist_sync_on_step, @@ -104,7 +104,7 @@ def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, va @pytest.mark.parametrize("nan_strategy", ["error", "warn"]) @pytest.mark.parametrize("metric_class", [MinMetric, MaxMetric, SumMetric, MeanMetric, CatMetric]) def test_nan_error(value, nan_strategy, metric_class): - """test correct errors are raised.""" + """Test correct errors are raised.""" metric = metric_class(nan_strategy=nan_strategy) if nan_strategy == "error": with pytest.raises(RuntimeError, match="Encounted `nan` values in tensor"): @@ -141,7 +141,7 @@ def test_nan_error(value, nan_strategy, metric_class): ], ) def test_nan_expected(metric_class, nan_strategy, value, expected): - """test that nan values are handled correctly.""" + """Test that nan values are handled correctly.""" metric = metric_class(nan_strategy=nan_strategy) metric.update(value.clone()) out = metric.compute() @@ -150,7 +150,7 @@ def test_nan_expected(metric_class, nan_strategy, value, expected): @pytest.mark.parametrize("metric_class", [MinMetric, MaxMetric, SumMetric, MeanMetric, CatMetric]) def test_error_on_wrong_nan_strategy(metric_class): - """test error raised on wrong nan_strategy argument.""" + """Test error raised on wrong nan_strategy argument.""" with pytest.raises(ValueError, match="Arg `nan_strategy` should either .*"): metric_class(nan_strategy=[]) @@ -160,7 +160,7 @@ def test_error_on_wrong_nan_strategy(metric_class): "weights, expected", [(1, 11.5), (torch.ones(2, 1, 1), 11.5), (torch.tensor([1, 2]).reshape(2, 1, 1), 13.5)] ) def test_mean_metric_broadcasting(weights, expected): - """check that weight broadcasting works for mean metric.""" + """Check that weight broadcasting works for mean metric.""" values = torch.arange(24).reshape(2, 3, 4) avg = MeanMetric() diff --git a/tests/unittests/bases/test_composition.py b/tests/unittests/bases/test_composition.py index f028bcd7d8f..9a85984288d 100644 --- a/tests/unittests/bases/test_composition.py +++ b/tests/unittests/bases/test_composition.py @@ -533,7 +533,7 @@ def test_metrics_getitem(value, idx, expected_result): def test_compositional_metrics_update(): - """test update method for compositional metrics.""" + """Test update method for compositional metrics.""" compos = DummyMetric(5) + DummyMetric(4) assert isinstance(compos, CompositionalMetric) diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index d2262204602..009ad1c9838 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -168,9 +168,7 @@ def verify_metric(metric, i, world_size): steps = 5 for i in range(steps): - if metric._is_synced: - with pytest.raises(TorchMetricsUserError, match="The Metric shouldn't be synced when performing"): metric(i) diff --git a/tests/unittests/bases/test_metric.py b/tests/unittests/bases/test_metric.py index 46ea1ca988b..47f3c74287d 100644 --- a/tests/unittests/bases/test_metric.py +++ b/tests/unittests/bases/test_metric.py @@ -224,7 +224,7 @@ def test_pickle(tmpdir): def test_state_dict(tmpdir): - """test that metric states can be removed and added to state dict.""" + """Test that metric states can be removed and added to state dict.""" metric = DummyMetric() assert metric.state_dict() == OrderedDict() metric.persistent(True) @@ -234,7 +234,7 @@ def test_state_dict(tmpdir): def test_load_state_dict(tmpdir): - """test that metric states can be loaded with state dict.""" + """Test that metric states can be loaded with state dict.""" metric = DummyMetricSum() metric.persistent(True) metric.update(5) @@ -244,7 +244,7 @@ def test_load_state_dict(tmpdir): def test_child_metric_state_dict(): - """test that child metric states will be added to parent state dict.""" + """Test that child metric states will be added to parent state dict.""" class TestModule(Module): def __init__(self): @@ -286,7 +286,7 @@ def test_device_and_dtype_transfer(tmpdir): def test_warning_on_compute_before_update(): - """test that an warning is raised if user tries to call compute before update.""" + """Test that an warning is raised if user tries to call compute before update.""" metric = DummyMetricSum() # make sure everything is fine with forward @@ -309,7 +309,7 @@ def test_warning_on_compute_before_update(): def test_metric_scripts(): - """test that metrics are scriptable.""" + """Test that metrics are scriptable.""" torch.jit.script(DummyMetric()) torch.jit.script(DummyMetricSum()) diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py index 2b344ff3721..1d866719e89 100644 --- a/tests/unittests/classification/test_auroc.py +++ b/tests/unittests/classification/test_auroc.py @@ -363,7 +363,7 @@ def test_multilabel_auroc_threshold_arg(self, input, average): ) @pytest.mark.parametrize("thresholds", [None, 100, [0.3, 0.5, 0.7, 0.9], torch.linspace(0, 1, 10)]) def test_valid_input_thresholds(metric, thresholds): - """test valid formats of the threshold argument.""" + """Test valid formats of the threshold argument.""" with pytest.warns(None) as record: metric(thresholds=thresholds) assert len(record) == 0 diff --git a/tests/unittests/classification/test_average_precision.py b/tests/unittests/classification/test_average_precision.py index 411c773a488..24a364f984e 100644 --- a/tests/unittests/classification/test_average_precision.py +++ b/tests/unittests/classification/test_average_precision.py @@ -368,7 +368,7 @@ def test_multilabel_average_precision_threshold_arg(self, input, average): ) @pytest.mark.parametrize("thresholds", [None, 100, [0.3, 0.5, 0.7, 0.9], torch.linspace(0, 1, 10)]) def test_valid_input_thresholds(metric, thresholds): - """test valid formats of the threshold argument.""" + """Test valid formats of the threshold argument.""" with pytest.warns(None) as record: metric(thresholds=thresholds) assert len(record) == 0 diff --git a/tests/unittests/classification/test_precision_recall_curve.py b/tests/unittests/classification/test_precision_recall_curve.py index 932efb7e88d..d9762c11809 100644 --- a/tests/unittests/classification/test_precision_recall_curve.py +++ b/tests/unittests/classification/test_precision_recall_curve.py @@ -356,7 +356,7 @@ def test_multilabel_precision_recall_curve_threshold_arg(self, input, threshold_ ) @pytest.mark.parametrize("thresholds", [None, 100, [0.3, 0.5, 0.7, 0.9], torch.linspace(0, 1, 10)]) def test_valid_input_thresholds(metric, thresholds): - """test valid formats of the threshold argument.""" + """Test valid formats of the threshold argument.""" with pytest.warns(None) as record: metric(thresholds=thresholds) assert len(record) == 0 diff --git a/tests/unittests/classification/test_recall_at_fixed_precision.py b/tests/unittests/classification/test_recall_at_fixed_precision.py index a1242fe1fd9..a40df60a369 100644 --- a/tests/unittests/classification/test_recall_at_fixed_precision.py +++ b/tests/unittests/classification/test_recall_at_fixed_precision.py @@ -392,7 +392,7 @@ def test_multilabel_recall_at_fixed_precision_threshold_arg(self, input, min_pre ) @pytest.mark.parametrize("thresholds", [None, 100, [0.3, 0.5, 0.7, 0.9], torch.linspace(0, 1, 10)]) def test_valid_input_thresholds(metric, thresholds): - """test valid formats of the threshold argument.""" + """Test valid formats of the threshold argument.""" with pytest.warns(None) as record: metric(min_precision=0.5, thresholds=thresholds) assert len(record) == 0 diff --git a/tests/unittests/classification/test_roc.py b/tests/unittests/classification/test_roc.py index d0328aa5e2e..3cfe9c443a2 100644 --- a/tests/unittests/classification/test_roc.py +++ b/tests/unittests/classification/test_roc.py @@ -345,7 +345,7 @@ def test_multilabel_roc_threshold_arg(self, input, threshold_fn): ) @pytest.mark.parametrize("thresholds", [None, 100, [0.3, 0.5, 0.7, 0.9], torch.linspace(0, 1, 10)]) def test_valid_input_thresholds(metric, thresholds): - """test valid formats of the threshold argument.""" + """Test valid formats of the threshold argument.""" with pytest.warns(None) as record: metric(thresholds=thresholds) assert len(record) == 0 diff --git a/tests/unittests/classification/test_specificity.py b/tests/unittests/classification/test_specificity.py index 26343e2ecdb..966573f7221 100644 --- a/tests/unittests/classification/test_specificity.py +++ b/tests/unittests/classification/test_specificity.py @@ -34,7 +34,7 @@ def _calc_specificity(tn, fp): - """safely calculate specificity.""" + """Safely calculate specificity.""" denom = tn + fp if np.isscalar(tn): denom = 1.0 if denom == 0 else denom diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index bd63e9676ed..fabe2aaf1ac 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -228,44 +228,45 @@ def _compare_fn(preds, target) -> dict: """Comparison function for map implementation. - Official pycocotools results calculated from a subset of https://github.com/cocodataset/cocoapi/tree/master/results - All classes - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.637 - Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.859 - Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.761 - Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.622 - Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.800 - Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.635 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.432 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.652 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.652 - Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.673 - Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.800 - Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.633 - - Class 0 - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.725 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.780 - - Class 1 - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.800 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.800 - - Class 2 - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.454 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.450 - - Class 3 - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = -1.000 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = -1.000 - - Class 4 - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.650 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.650 - - Class 49 - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.556 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.580 + Official pycocotools results calculated from a subset of + https://github.com/cocodataset/cocoapi/tree/master/results + All classes + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.637 + Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.859 + Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.761 + Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.622 + Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.800 + Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.635 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.432 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.652 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.652 + Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.673 + Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.800 + Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.633 + + Class 0 + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.725 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.780 + + Class 1 + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.800 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.800 + + Class 2 + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.454 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.450 + + Class 3 + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = -1.000 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = -1.000 + + Class 4 + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.650 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.650 + + Class 49 + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.556 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.580 """ return { "map": Tensor([0.637]), @@ -288,19 +289,20 @@ def _compare_fn(preds, target) -> dict: def _compare_fn_segm(preds, target) -> dict: """Comparison function for map implementation for instance segmentation. - Official pycocotools results calculated from a subset of https://github.com/cocodataset/cocoapi/tree/master/results - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.352 - Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.752 - Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.252 - Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000 - Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000 - Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.352 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.350 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.350 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.350 - Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000 - Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000 - Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.350 + Official pycocotools results calculated from a subset of + https://github.com/cocodataset/cocoapi/tree/master/results + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.352 + Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.752 + Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.252 + Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000 + Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000 + Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.352 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.350 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.350 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.350 + Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000 + Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000 + Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.350 """ return { "map": Tensor([0.352]), diff --git a/tests/unittests/image/test_fid.py b/tests/unittests/image/test_fid.py index 2e658a394ff..79d77b9cf53 100644 --- a/tests/unittests/image/test_fid.py +++ b/tests/unittests/image/test_fid.py @@ -28,7 +28,7 @@ @pytest.mark.parametrize("matrix_size", [2, 10, 100, 500]) def test_matrix_sqrt(matrix_size): - """test that metrix sqrt function works as expected.""" + """Test that metrix sqrt function works as expected.""" def generate_cov(n): data = torch.randn(2 * n, n) @@ -123,7 +123,7 @@ def __len__(self): @pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") @pytest.mark.parametrize("equal_size", [False, True]) def test_compare_fid(tmpdir, equal_size, feature=2048): - """check that the hole pipeline give the same result as torch-fidelity.""" + """Check that the hole pipeline give the same result as torch-fidelity.""" from torch_fidelity import calculate_metrics metric = FrechetInceptionDistance(feature=feature).cuda() diff --git a/tests/unittests/image/test_inception.py b/tests/unittests/image/test_inception.py index 93198b874c2..7477fb7d3a9 100644 --- a/tests/unittests/image/test_inception.py +++ b/tests/unittests/image/test_inception.py @@ -108,7 +108,7 @@ def __len__(self): @pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") @pytest.mark.parametrize("compute_on_cpu", [True, False]) def test_compare_is(tmpdir, compute_on_cpu): - """check that the hole pipeline give the same result as torch-fidelity.""" + """Check that the hole pipeline give the same result as torch-fidelity.""" from torch_fidelity import calculate_metrics metric = InceptionScore(splits=1, compute_on_cpu=compute_on_cpu).cuda() diff --git a/tests/unittests/image/test_kid.py b/tests/unittests/image/test_kid.py index 69ebc320766..7f78e181cc8 100644 --- a/tests/unittests/image/test_kid.py +++ b/tests/unittests/image/test_kid.py @@ -105,7 +105,7 @@ def test_kid_extra_parameters(): @pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") @pytest.mark.parametrize("feature", [64, 192, 768, 2048]) def test_kid_same_input(feature): - """test that the metric works.""" + """Test that the metric works.""" metric = KernelInceptionDistance(feature=feature, subsets=5, subset_size=2) for _ in range(2): @@ -134,7 +134,7 @@ def __len__(self): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu") @pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") def test_compare_kid(tmpdir, feature=2048): - """check that the hole pipeline give the same result as torch-fidelity.""" + """Check that the hole pipeline give the same result as torch-fidelity.""" from torch_fidelity import calculate_metrics metric = KernelInceptionDistance(feature=feature, subsets=1, subset_size=100).cuda() diff --git a/tests/unittests/image/test_lpips.py b/tests/unittests/image/test_lpips.py index e3194b605a5..675ad39e15a 100644 --- a/tests/unittests/image/test_lpips.py +++ b/tests/unittests/image/test_lpips.py @@ -35,7 +35,7 @@ def _compare_fn(img1: Tensor, img2: Tensor, net_type: str, normalize: bool, reduction: str = "mean") -> Tensor: - """comparison function for tm implementation.""" + """Comparison function for tm implementation.""" ref = LPIPS_reference(net=net_type) res = ref(img1, img2, normalize=normalize).detach().cpu().numpy() if reduction == "mean": @@ -51,7 +51,7 @@ class TestLPIPS(MetricTester): @pytest.mark.parametrize("normalize", [False, True]) @pytest.mark.parametrize("ddp", [True, False]) def test_lpips(self, net_type, normalize, ddp): - """test modular implementation for correctness.""" + """Test modular implementation for correctness.""" self.run_class_metric_test( ddp=ddp, preds=_inputs.img1, @@ -65,7 +65,7 @@ def test_lpips(self, net_type, normalize, ddp): ) def test_lpips_differentiability(self): - """test for differentiability of LPIPS metric.""" + """Test for differentiability of LPIPS metric.""" self.run_differentiability_test( preds=_inputs.img1, target=_inputs.img2, metric_module=LearnedPerceptualImagePatchSimilarity ) @@ -73,12 +73,12 @@ def test_lpips_differentiability(self): # LPIPS half + cpu does not work due to missing support in torch.min @pytest.mark.xfail(reason="LPIPS metric does not support cpu + half precision") def test_lpips_half_cpu(self): - """test for half + cpu support.""" + """Test for half + cpu support.""" self.run_precision_test_cpu(_inputs.img1, _inputs.img2, LearnedPerceptualImagePatchSimilarity) @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") def test_lpips_half_gpu(self): - """test for half + gpu support.""" + """Test for half + gpu support.""" self.run_precision_test_gpu(_inputs.img1, _inputs.img2, LearnedPerceptualImagePatchSimilarity) @@ -103,7 +103,7 @@ def test_error_on_wrong_init(): ], ) def test_error_on_wrong_update(inp1, inp2): - """test error is raised on wrong input to update method.""" + """Test error is raised on wrong input to update method.""" metric = LearnedPerceptualImagePatchSimilarity() with pytest.raises(ValueError, match="Expected both input arguments to be normalized tensors .*"): metric(inp1, inp2) diff --git a/tests/unittests/image/test_tv.py b/tests/unittests/image/test_tv.py index 6a65a516277..2d934f4cc22 100644 --- a/tests/unittests/image/test_tv.py +++ b/tests/unittests/image/test_tv.py @@ -108,7 +108,7 @@ def test_sam_half_gpu(self, preds, target, reduction): def test_correct_args(): - """that that arguments have the right type and sizes.""" + """That that arguments have the right type and sizes.""" with pytest.raises(ValueError, match="Expected argument `reduction`.*"): _ = TotalVariation(reduction="diff") diff --git a/tests/unittests/pairwise/test_pairwise_distance.py b/tests/unittests/pairwise/test_pairwise_distance.py index 61dbac6714f..ae6b3a50687 100644 --- a/tests/unittests/pairwise/test_pairwise_distance.py +++ b/tests/unittests/pairwise/test_pairwise_distance.py @@ -47,7 +47,7 @@ def _sk_metric(x, y, sk_fn, reduction): - """comparison function.""" + """Comparison function.""" x = x.view(-1, extra_dim).numpy() y = y.view(-1, extra_dim).numpy() res = sk_fn(x, y) @@ -76,12 +76,12 @@ def _sk_metric(x, y, sk_fn, reduction): ) @pytest.mark.parametrize("reduction", ["sum", "mean", None]) class TestPairwise(MetricTester): - """test pairwise implementations.""" + """Test pairwise implementations.""" atol = 1e-4 def test_pairwise_functional(self, x, y, metric_functional, sk_fn, reduction): - """test functional pairwise implementations.""" + """Test functional pairwise implementations.""" self.run_functional_metric_test( preds=x, target=y, @@ -91,14 +91,14 @@ def test_pairwise_functional(self, x, y, metric_functional, sk_fn, reduction): ) def test_pairwise_half_cpu(self, x, y, metric_functional, sk_fn, reduction): - """test half precision support on cpu.""" + """Test half precision support on cpu.""" if metric_functional == pairwise_euclidean_distance: pytest.xfail("pairwise_euclidean_distance metric does not support cpu + half precision") self.run_precision_test_cpu(x, y, None, metric_functional, metric_args={"reduction": reduction}) @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") def test_pairwise_half_gpu(self, x, y, metric_functional, sk_fn, reduction): - """test half precision support on gpu.""" + """Test half precision support on gpu.""" self.run_precision_test_gpu(x, y, None, metric_functional, metric_args={"reduction": reduction}) @@ -127,7 +127,7 @@ def test_error_on_wrong_shapes(metric): ], ) def test_precison_case(metric_functional, sk_fn): - """test that metrics are robust towars cases where high precision is needed.""" + """Test that metrics are robust towars cases where high precision is needed.""" x = torch.tensor([[772.0, 112.0], [772.20001, 112.0]]) res1 = metric_functional(x, zero_diagonal=False) res2 = sk_fn(x) diff --git a/tests/unittests/regression/test_spearman.py b/tests/unittests/regression/test_spearman.py index 2fb2e1cb992..a883161f706 100644 --- a/tests/unittests/regression/test_spearman.py +++ b/tests/unittests/regression/test_spearman.py @@ -62,7 +62,7 @@ ], ) def test_ranking(preds, target): - """test that ranking function works as expected.""" + """Test that ranking function works as expected.""" for p, t in zip(preds, target): scipy_ranking = [rankdata(p.numpy()), rankdata(t.numpy())] tm_ranking = [_rank_data(p), _rank_data(t)] diff --git a/tests/unittests/text/test_cer.py b/tests/unittests/text/test_cer.py index 8765f50d32c..8906db02c20 100644 --- a/tests/unittests/text/test_cer.py +++ b/tests/unittests/text/test_cer.py @@ -28,12 +28,12 @@ def compare_fn(preds: Union[str, List[str]], target: Union[str, List[str]]): ], ) class TestCharErrorRate(TextTester): - """test class for character error rate.""" + """Test class for character error rate.""" @pytest.mark.parametrize("ddp", [False, True]) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_cer_class(self, ddp, dist_sync_on_step, preds, targets): - """test modular version of cer.""" + """Test modular version of cer.""" self.run_class_metric_test( ddp=ddp, preds=preds, @@ -44,7 +44,7 @@ def test_cer_class(self, ddp, dist_sync_on_step, preds, targets): ) def test_cer_functional(self, preds, targets): - """test functional version of cer.""" + """Test functional version of cer.""" self.run_functional_metric_test( preds, targets, @@ -53,7 +53,7 @@ def test_cer_functional(self, preds, targets): ) def test_cer_differentiability(self, preds, targets): - """test differentiability of cer metric.""" + """Test differentiability of cer metric.""" self.run_differentiability_test( preds=preds, targets=targets, diff --git a/tests/unittests/text/test_infolm.py b/tests/unittests/text/test_infolm.py index 1fd190cde00..75c23eaf58f 100644 --- a/tests/unittests/text/test_infolm.py +++ b/tests/unittests/text/test_infolm.py @@ -31,7 +31,8 @@ def reference_infolm_score(preds, target, model_name, information_measure, idf, """Currently, a reference package is not available. We, therefore, are enforced to relied on hard-coded results for now. The results below were generated using scripts - in https://github.com/stancld/infolm-docker. + in + https://github.com/stancld/infolm-docker. """ if model_name != "google/bert_uncased_L-2_H-128_A-2": raise ValueError( diff --git a/tests/unittests/text/test_mer.py b/tests/unittests/text/test_mer.py index 0e2c3a93751..d62303de571 100644 --- a/tests/unittests/text/test_mer.py +++ b/tests/unittests/text/test_mer.py @@ -30,7 +30,6 @@ class TestMatchErrorRate(TextTester): @pytest.mark.parametrize("ddp", [False, True]) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_mer_class(self, ddp, dist_sync_on_step, preds, targets): - self.run_class_metric_test( ddp=ddp, preds=preds, @@ -41,7 +40,6 @@ def test_mer_class(self, ddp, dist_sync_on_step, preds, targets): ) def test_mer_functional(self, preds, targets): - self.run_functional_metric_test( preds, targets, @@ -50,7 +48,6 @@ def test_mer_functional(self, preds, targets): ) def test_mer_differentiability(self, preds, targets): - self.run_differentiability_test( preds=preds, targets=targets, diff --git a/tests/unittests/text/test_wer.py b/tests/unittests/text/test_wer.py index 4f80c0457d8..96af3d515b2 100644 --- a/tests/unittests/text/test_wer.py +++ b/tests/unittests/text/test_wer.py @@ -30,7 +30,6 @@ class TestWER(TextTester): @pytest.mark.parametrize("ddp", [False, True]) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_wer_class(self, ddp, dist_sync_on_step, preds, targets): - self.run_class_metric_test( ddp=ddp, preds=preds, @@ -41,7 +40,6 @@ def test_wer_class(self, ddp, dist_sync_on_step, preds, targets): ) def test_wer_functional(self, preds, targets): - self.run_functional_metric_test( preds, targets, @@ -50,7 +48,6 @@ def test_wer_functional(self, preds, targets): ) def test_wer_differentiability(self, preds, targets): - self.run_differentiability_test( preds=preds, targets=targets, diff --git a/tests/unittests/text/test_wil.py b/tests/unittests/text/test_wil.py index b0fa5473dca..8e19bb8c44c 100644 --- a/tests/unittests/text/test_wil.py +++ b/tests/unittests/text/test_wil.py @@ -26,7 +26,6 @@ class TestWordInfoLost(TextTester): @pytest.mark.parametrize("ddp", [False, True]) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_wil_class(self, ddp, dist_sync_on_step, preds, targets): - self.run_class_metric_test( ddp=ddp, preds=preds, @@ -37,7 +36,6 @@ def test_wil_class(self, ddp, dist_sync_on_step, preds, targets): ) def test_wil_functional(self, preds, targets): - self.run_functional_metric_test( preds, targets, @@ -46,7 +44,6 @@ def test_wil_functional(self, preds, targets): ) def test_wil_differentiability(self, preds, targets): - self.run_differentiability_test( preds=preds, targets=targets, diff --git a/tests/unittests/text/test_wip.py b/tests/unittests/text/test_wip.py index 2f0720d224c..8bdaf0b9878 100644 --- a/tests/unittests/text/test_wip.py +++ b/tests/unittests/text/test_wip.py @@ -26,7 +26,6 @@ class TestWordInfoPreserved(TextTester): @pytest.mark.parametrize("ddp", [False, True]) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_wip_class(self, ddp, dist_sync_on_step, preds, targets): - self.run_class_metric_test( ddp=ddp, preds=preds, @@ -37,7 +36,6 @@ def test_wip_class(self, ddp, dist_sync_on_step, preds, targets): ) def test_wip_functional(self, preds, targets): - self.run_functional_metric_test( preds, targets, @@ -46,7 +44,6 @@ def test_wip_functional(self, preds, targets): ) def test_wip_differentiability(self, preds, targets): - self.run_differentiability_test( preds=preds, targets=targets, diff --git a/tests/unittests/utilities/test_utilities.py b/tests/unittests/utilities/test_utilities.py index b3b4163bbfe..2f25f0a9476 100644 --- a/tests/unittests/utilities/test_utilities.py +++ b/tests/unittests/utilities/test_utilities.py @@ -110,7 +110,7 @@ def test_flatten_dict(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires gpu") def test_bincount(): - """test that bincount works in deterministic setting on GPU.""" + """Test that bincount works in deterministic setting on GPU.""" torch.use_deterministic_algorithms(True) x = torch.randint(10, size=(100,)) diff --git a/tests/unittests/wrappers/test_bootstrapping.py b/tests/unittests/wrappers/test_bootstrapping.py index fa57afe6885..242c6681bec 100644 --- a/tests/unittests/wrappers/test_bootstrapping.py +++ b/tests/unittests/wrappers/test_bootstrapping.py @@ -58,7 +58,7 @@ def _sample_checker(old_samples, new_samples, op: operator, threshold: int): @pytest.mark.parametrize("sampling_strategy", ["poisson", "multinomial"]) def test_bootstrap_sampler(sampling_strategy): - """make sure that the bootstrap sampler works as intended.""" + """Make sure that the bootstrap sampler works as intended.""" old_samples = torch.randn(20, 2) # make sure that the new samples are only made up of old samples @@ -102,7 +102,6 @@ def test_bootstrap(device, sampling_strategy, metric, sk_metric): bootstrapper.update(p, t) for i, o in enumerate(bootstrapper.out): - collected_preds[i].append(o[0]) collected_target[i].append(o[1]) diff --git a/tests/unittests/wrappers/test_minmax.py b/tests/unittests/wrappers/test_minmax.py index 8be6b288194..10cc018b647 100644 --- a/tests/unittests/wrappers/test_minmax.py +++ b/tests/unittests/wrappers/test_minmax.py @@ -15,10 +15,10 @@ class TestingMinMaxMetric(MinMaxMetric): - """wrap metric to fit testing framework.""" + """Wrap metric to fit testing framework.""" def compute(self): - """instead of returning dict, return as list.""" + """Instead of returning dict, return as list.""" output_dict = super().compute() return [output_dict["raw"], output_dict["min"], output_dict["max"]] @@ -28,7 +28,7 @@ def forward(self, *args, **kwargs): def compare_fn(preds, target, base_fn): - """comparing function for minmax wrapper.""" + """Comparing function for minmax wrapper.""" v_min, v_max = 1e6, -1e6 # pick some very large numbers for comparing for i in range(NUM_BATCHES): val = base_fn(preds[: (i + 1) * BATCH_SIZE], target[: (i + 1) * BATCH_SIZE]).cpu().numpy() @@ -96,7 +96,7 @@ def test_minmax_wrapper(self, preds, target, base_metric, ddp): ], ) def test_basic_example(preds, labels, raws, maxs, mins) -> None: - """tests that both min and max versions of MinMaxMetric operate correctly after calling compute.""" + """Tests that both min and max versions of MinMaxMetric operate correctly after calling compute.""" acc = BinaryAccuracy() min_max_acc = MinMaxMetric(acc) labels = Tensor(labels).long() @@ -111,13 +111,13 @@ def test_basic_example(preds, labels, raws, maxs, mins) -> None: def test_no_base_metric() -> None: - """tests that ValueError is raised when no base_metric is passed.""" + """Tests that ValueError is raised when no base_metric is passed.""" with pytest.raises(ValueError, match=r"Expected base metric to be an instance .*"): MinMaxMetric([]) def test_no_scalar_compute() -> None: - """tests that an assertion error is thrown if the wrapped basemetric gives a non-scalar on compute.""" + """Tests that an assertion error is thrown if the wrapped basemetric gives a non-scalar on compute.""" min_max_nsm = MinMaxMetric(BinaryConfusionMatrix(num_classes=2)) with pytest.raises(RuntimeError, match=r"Returned value from base metric should be a scalar .*"): From f094071a8502623091c56a220e14e3b43663e188 Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 30 Jun 2023 10:03:39 +0200 Subject: [PATCH 06/11] precommit --- src/torchmetrics/classification/accuracy.py | 2 +- src/torchmetrics/functional/classification/precision_recall.py | 1 - src/torchmetrics/functional/classification/specificity.py | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index c69ac902475..97bb6ddc901 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -28,7 +28,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn -from torchmetrics.utilities.enums import AverageMethod +from torchmetrics.utilities.enums import AverageMethod, DataType from torchmetrics.classification.stat_scores import ( # isort:skip BinaryStatScores, diff --git a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py index be1855a4b5e..4bb1966f693 100644 --- a/src/torchmetrics/functional/classification/precision_recall.py +++ b/src/torchmetrics/functional/classification/precision_recall.py @@ -31,7 +31,6 @@ _multilabel_stat_scores_tensor_validation, _multilabel_stat_scores_update, _reduce_stat_scores, - _stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod diff --git a/src/torchmetrics/functional/classification/specificity.py b/src/torchmetrics/functional/classification/specificity.py index ba2e0e226f5..e161d5a0d21 100644 --- a/src/torchmetrics/functional/classification/specificity.py +++ b/src/torchmetrics/functional/classification/specificity.py @@ -31,7 +31,6 @@ _multilabel_stat_scores_tensor_validation, _multilabel_stat_scores_update, _reduce_stat_scores, - _stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod From 2706cdb096452965d0446d80dee29355e4a296ab Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 3 Jul 2023 13:39:50 +0200 Subject: [PATCH 07/11] _recall_compute --- .../classification/precision_recall.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py index 4bb1966f693..4cddbb434df 100644 --- a/src/torchmetrics/functional/classification/precision_recall.py +++ b/src/torchmetrics/functional/classification/precision_recall.py @@ -64,6 +64,57 @@ def _precision_recall_reduce( return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1) +def _recall_compute( + tp: Tensor, + fp: Tensor, + fn: Tensor, + average: Optional[str], + mdmc_average: Optional[str], +) -> Tensor: + """Computes precision from the stat scores: true positives, false positives, true negatives, false negatives. + + Args: + tp: True positives + fp: False positives + fn: False negatives + average: Defines the reduction that is applied + mdmc_average: Defines how averaging is done for multi-dimensional multi-class inputs (on top of the + ``average`` parameter) + + Example: + >>> from torchmetrics.functional.classification.stat_scores import _stat_scores_update + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='macro', num_classes=3) + >>> _recall_compute(tp, fp, fn, average='macro', mdmc_average=None) + tensor(0.3333) + >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='micro') + >>> _recall_compute(tp, fp, fn, average='micro', mdmc_average=None) + tensor(0.2500) + """ + numerator = tp.clone() + denominator = tp + fn + + if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE: + cond = tp + fp + fn == 0 + numerator = numerator[~cond] + denominator = denominator[~cond] + + if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE: + # a class is not present if there exists no TPs, no FPs, and no FNs + meaningless_indeces = ((tp | fn | fp) == 0).nonzero().cpu() + numerator[meaningless_indeces, ...] = -1 + denominator[meaningless_indeces, ...] = -1 + + return _reduce_stat_scores( + numerator=numerator, + denominator=denominator, + weights=None if average != AverageMethod.WEIGHTED else tp + fn, + average=average, + mdmc_average=mdmc_average, + ) + + def binary_precision( preds: Tensor, target: Tensor, From 087c69c6cdc59b98cd48324efe8c7ab1817931be Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+borda@users.noreply.github.com> Date: Tue, 9 May 2023 11:35:38 +0200 Subject: [PATCH 08/11] replace local adjustment script with external (#1758) (cherry picked from commit 0bf4dc0b806c6a7155fbe02ea98faaacf1e6bab1) --- .github/workflows/ci-integrate.yml | 9 +-- .github/workflows/ci-tests-full.yml | 5 +- Makefile | 4 +- requirements/adjust-versions.py | 86 ----------------------------- 4 files changed, 10 insertions(+), 94 deletions(-) delete mode 100644 requirements/adjust-versions.py diff --git a/.github/workflows/ci-integrate.yml b/.github/workflows/ci-integrate.yml index f406d83104c..6261b156de4 100644 --- a/.github/workflows/ci-integrate.yml +++ b/.github/workflows/ci-integrate.yml @@ -63,10 +63,11 @@ jobs: - name: Install all dependencies run: | set -e - pip install -r requirements/test.txt - pip install -r requirements/integrate.txt --find-links $PYTORCH_URL --upgrade-strategy eager - python ./requirements/adjust-versions.py requirements.txt - python ./requirements/adjust-versions.py requirements/image.txt + curl https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/adjust-torch-versions.py -o adjust-torch-versions.py + pip install -r requirements/test.txt -r requirements/integrate.txt \ + --find-links $PYTORCH_URL --upgrade-strategy eager + python adjust-torch-versions.py requirements.txt + python adjust-torch-versions.py requirements/image.txt cat requirements.txt pip install -e . --find-links $PYTORCH_URL pip list diff --git a/.github/workflows/ci-tests-full.yml b/.github/workflows/ci-tests-full.yml index 343115fe037..d44c9b2fab3 100644 --- a/.github/workflows/ci-tests-full.yml +++ b/.github/workflows/ci-tests-full.yml @@ -107,8 +107,9 @@ jobs: - name: Install all dependencies run: | - python ./requirements/adjust-versions.py requirements.txt - python ./requirements/adjust-versions.py requirements/image.txt + curl https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/adjust-torch-versions.py -o adjust-torch-versions.py + python adjust-torch-versions.py requirements.txt + python adjust-torch-versions.py requirements/image.txt pip install --requirement requirements/devel.txt -U --find-links $PYTORCH_URL pip list diff --git a/Makefile b/Makefile index 33aac7b5da0..67621f5dc11 100644 --- a/Makefile +++ b/Makefile @@ -31,8 +31,8 @@ docs: clean python -m sphinx -b html -W --keep-going docs/source docs/build env: - pip install -e . - python ./requirements/adjust-versions.py requirements/image.txt + export FREEZE_REQUIREMENTS=1 + pip install -e . -U pip install -r requirements/devel.txt data: diff --git a/requirements/adjust-versions.py b/requirements/adjust-versions.py deleted file mode 100644 index 437424a24c2..00000000000 --- a/requirements/adjust-versions.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import os -import re -import sys -from typing import Dict, Optional - -from packaging.version import Version - -VERSIONS = [ - dict(torch="1.13.0", torchvision="0.14.0", torchtext="0.14.0"), # nightly - dict(torch="1.12.1", torchvision="0.13.1", torchtext="0.13.1"), # stable - dict(torch="1.12.0", torchvision="0.13.0", torchtext="0.13.0"), - dict(torch="1.11.0", torchvision="0.12.0", torchtext="0.12.0"), - dict(torch="1.10.2", torchvision="0.11.3", torchtext="0.11.2"), - dict(torch="1.10.1", torchvision="0.11.2", torchtext="0.11.1"), - dict(torch="1.10.0", torchvision="0.11.1", torchtext="0.11.0"), - dict(torch="1.9.1", torchvision="0.10.1", torchtext="0.10.1"), - dict(torch="1.9.0", torchvision="0.10.0", torchtext="0.10.0"), - dict(torch="1.8.2", torchvision="0.9.1", torchtext="0.9.1"), - dict(torch="1.8.1", torchvision="0.9.1", torchtext="0.9.1"), - dict(torch="1.8.0", torchvision="0.9.0", torchtext="0.9.0"), -] -VERSIONS.sort(key=lambda v: Version(v["torch"]), reverse=True) - - -def find_latest(ver: str) -> Dict[str, str]: - # drop all except semantic version - ver = re.search(r"([\.\d]+)", ver).groups()[0] - # in case there remaining dot at the end - e.g "1.9.0.dev20210504" - ver = ver[:-1] if ver[-1] == "." else ver - logging.info(f"finding ecosystem versions for: {ver}") - - # find first match - for option in VERSIONS: - if option["torch"].startswith(ver): - return option - - raise ValueError(f"Missing {ver} in {VERSIONS}") - - -def adjust(requires: str, torch_version: Optional[str] = None) -> str: - if not torch_version: - import torch - - torch_version = torch.__version__ - assert torch_version, f"invalid torch: {torch_version}" - - # remove comments and strip whitespace - requires = re.sub(rf"\s*#.*{os.linesep}", os.linesep, requires).strip() - - latest = find_latest(torch_version) - for lib, version in latest.items(): - replace = f"{lib}=={version}" if version else "" - requires = re.sub(rf"\b{lib}(?![-_\w]).*", replace, requires) - - return requires - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - - if len(sys.argv) == 3: - requirements_path, torch_version = sys.argv[1:] - else: - requirements_path, torch_version = sys.argv[1], None - logging.info(f"requirements_path='{requirements_path}' with torch_version='{torch_version}'") - - with open(requirements_path) as fp: - requirements = fp.read() - requirements = adjust(requirements, torch_version) - logging.info(requirements) # on purpose - to debug - with open(requirements_path, "w") as fp: - fp.write(requirements) From 9f714e4fb93d637922b92f9f08d7395b07b1afb5 Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 3 Jul 2023 14:29:44 +0200 Subject: [PATCH 09/11] adjust --- .github/workflows/ci-tests-full.yml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci-tests-full.yml b/.github/workflows/ci-tests-full.yml index d44c9b2fab3..5f801a7dd92 100644 --- a/.github/workflows/ci-tests-full.yml +++ b/.github/workflows/ci-tests-full.yml @@ -79,8 +79,9 @@ jobs: - name: Set PyTorch version if: inputs.requires != 'oldest' run: | - pip install packaging - python ./requirements/adjust-versions.py requirements.txt ${{ matrix.pytorch-version }} + curl https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/adjust-torch-versions.py \ + -o adjust-torch-versions.py + python adjust-torch-versions.py requirements.txt ${{ matrix.pytorch-version }} - uses: ./.github/actions/caching with: @@ -107,8 +108,6 @@ jobs: - name: Install all dependencies run: | - curl https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/adjust-torch-versions.py -o adjust-torch-versions.py - python adjust-torch-versions.py requirements.txt python adjust-torch-versions.py requirements/image.txt pip install --requirement requirements/devel.txt -U --find-links $PYTORCH_URL pip list From 4e43ff4c34b1fd44ede9b9c364f6e9ab1a0a9ecc Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 3 Jul 2023 22:25:11 +0200 Subject: [PATCH 10/11] top_k None --- .../functional/classification/auroc.py | 2 +- .../classification/average_precision.py | 11 +++++----- .../classification/precision_recall_curve.py | 17 +++++--------- .../functional/classification/roc.py | 22 +++++++------------ .../functional/classification/stat_scores.py | 6 ++--- 5 files changed, 23 insertions(+), 35 deletions(-) diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py index ea50adf8192..b04a42ea870 100644 --- a/src/torchmetrics/functional/classification/auroc.py +++ b/src/torchmetrics/functional/classification/auroc.py @@ -519,7 +519,7 @@ def _auroc_compute( num_classes = class_observed.sum() if num_classes == 1: raise ValueError("Found 1 non-empty class in `multiclass` AUROC calculation") - fpr, tpr, _ = roc(preds, target, num_classes, pos_label, sample_weights) + fpr, tpr, _ = roc(preds, target, "binary", num_classes, pos_label, sample_weights) # calculate standard roc auc score if max_fpr is None or max_fpr == 1: diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index 5d7529d6c75..9645c7b1097 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -451,7 +451,7 @@ def _average_precision_compute( target = target.flatten() num_classes = 1 - precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label) + precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes=num_classes, pos_label=pos_label) if average == "weighted": if preds.ndim == target.ndim and target.ndim > 1: weights = target.sum(dim=0).float() @@ -460,7 +460,7 @@ def _average_precision_compute( weights = weights / torch.sum(weights) else: weights = None - return _average_precision_compute_with_precision_recall(precision, recall, num_classes, average, weights) + return _average_precision_compute_with_precision_recall(precision, recall, num_classes=num_classes, average=average, weights=weights) def _average_precision_compute_with_precision_recall( @@ -488,10 +488,9 @@ def _average_precision_compute_with_precision_recall( ... [0.05, 0.05, 0.75, 0.05, 0.05], ... [0.05, 0.05, 0.05, 0.75, 0.05]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> num_classes = 5 - >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes) - >>> precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes) - >>> _average_precision_compute_with_precision_recall(precision, recall, num_classes, average=None) + >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes=5) + >>> precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes=num_classes) + >>> _average_precision_compute_with_precision_recall(precision, recall, num_classes=num_classes, average=None) [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] """ diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index fa3187cf108..8683ee4dbfe 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -851,7 +851,6 @@ def _precision_recall_curve_compute_multi_class( preds: Tensor, target: Tensor, num_classes: int, - sample_weights: Optional[Sequence] = None, ) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]: """Computes precision-recall pairs for multiclass inputs.""" @@ -864,14 +863,11 @@ def _precision_recall_curve_compute_multi_class( preds=preds_cls, target=target, num_classes=1, - pos_label=cls, - sample_weights=sample_weights, ) if target.ndim > 1: prc_args.update( dict( - target=target[:, cls], - pos_label=1, + target=target[:, cls] ) ) res = precision_recall_curve(**prc_args) @@ -911,9 +907,8 @@ def _precision_recall_curve_compute( ... [0.05, 0.05, 0.75, 0.05, 0.05], ... [0.05, 0.05, 0.05, 0.75, 0.05]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> num_classes = 5 - >>> preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes) - >>> precision, recall, thresholds = _precision_recall_curve_compute(preds, target, num_classes) + >>> preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes=4) + >>> precision, recall, thresholds = _precision_recall_curve_compute(preds, target, num_classes=num_classes) >>> precision [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])] @@ -927,8 +922,8 @@ def _precision_recall_curve_compute( if num_classes == 1: if pos_label is None: pos_label = 1 - return _precision_recall_curve_compute_single_class(preds, target, pos_label, sample_weights) - return _precision_recall_curve_compute_multi_class(preds, target, num_classes, sample_weights) + return _precision_recall_curve_compute_single_class(preds, target, pos_label=pos_label, sample_weights=sample_weights) + return _precision_recall_curve_compute_multi_class(preds, target, num_classes=num_classes) def precision_recall_curve( @@ -942,7 +937,7 @@ def precision_recall_curve( validate_args: bool = True, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Computes the precision-recall curve. The curve consist of multiple pairs of precision and recall values - evaluated at different thresholds, such that the tradeoff between the two values can been seen. + evaluated at different thresholds, such that the tradeoff between the two values can be seen. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py index 37e3673366d..9d01b2babfd 100644 --- a/src/torchmetrics/functional/classification/roc.py +++ b/src/torchmetrics/functional/classification/roc.py @@ -479,7 +479,6 @@ def _roc_compute_multi_class( preds: Tensor, target: Tensor, num_classes: int, - sample_weights: Optional[Sequence] = None, ) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]: """Computes Receiver Operating Characteristic for multi class inputs. @@ -491,16 +490,13 @@ def _roc_compute_multi_class( for cls in range(num_classes): if preds.shape == target.shape: target_cls = target[:, cls] - pos_label = 1 else: target_cls = target - pos_label = cls res = roc( + task="multiclass", preds=preds[:, cls], target=target_cls, num_classes=1, - pos_label=pos_label, - sample_weights=sample_weights, ) fpr.append(res[0]) tpr.append(res[1]) @@ -514,7 +510,6 @@ def _roc_compute( target: Tensor, num_classes: int, pos_label: Optional[int] = None, - sample_weights: Optional[Sequence] = None, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Computes Receiver Operating Characteristic based on the number of classes. @@ -538,9 +533,8 @@ def _roc_compute( ... [0.05, 0.05, 0.75, 0.05], ... [0.05, 0.05, 0.05, 0.75]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> num_classes = 4 - >>> preds, target, num_classes, pos_label = _roc_update(preds, target, num_classes) - >>> fpr, tpr, thresholds = _roc_compute(preds, target, num_classes) + >>> preds, target, num_classes, pos_label = _roc_update(preds, target, num_classes=4) + >>> fpr, tpr, thresholds = _roc_compute(preds, target, num_classes=4) >>> fpr [tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])] >>> tpr @@ -556,8 +550,8 @@ def _roc_compute( if num_classes == 1 and preds.ndim == 1: # binary if pos_label is None: pos_label = 1 - return _roc_compute_single_class(preds, target, pos_label, sample_weights) - return _roc_compute_multi_class(preds, target, num_classes, sample_weights) + return _roc_compute_single_class(preds, target, pos_label) + return _roc_compute_multi_class(preds, target, num_classes) def roc( @@ -624,13 +618,13 @@ def roc( tensor([1.0000, 0.1837, 0.1338, 0.1183, 0.1138])] """ if task == "binary": - return binary_roc(preds, target, thresholds, ignore_index, validate_args) + return binary_roc(preds, target, thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args) if task == "multiclass": assert isinstance(num_classes, int) - return multiclass_roc(preds, target, num_classes, thresholds, ignore_index, validate_args) + return multiclass_roc(preds, target, num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args) if task == "multilabel": assert isinstance(num_labels, int) - return multilabel_roc(preds, target, num_labels, thresholds, ignore_index, validate_args) + return multilabel_roc(preds, target, num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 82adee32cce..0bc2188a87a 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -912,7 +912,7 @@ def _stat_scores_update( reduce: Optional[str] = "micro", mdmc_reduce: Optional[str] = None, num_classes: Optional[int] = None, - top_k: Optional[int] = 1, + top_k: Optional[int] = None, threshold: float = 0.5, multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, @@ -1005,12 +1005,12 @@ def _stat_scores_compute(tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> Tens Example: >>> preds = torch.tensor([1, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) - >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='macro', num_classes=3) + >>> tp, fp, tn, fn = _stat_scores_update(preds, target, top_k=1, reduce='macro', num_classes=3) >>> _stat_scores_compute(tp, fp, tn, fn) tensor([[0, 1, 2, 1, 1], [1, 1, 1, 1, 2], [1, 0, 3, 0, 1]]) - >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='micro') + >>> tp, fp, tn, fn = _stat_scores_update(preds, target, top_k=1, reduce='micro') >>> _stat_scores_compute(tp, fp, tn, fn) tensor([2, 2, 6, 2, 4]) """ From 9d7a67cc8ef89d30c2664bbd1b4203d2214c12f1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 3 Jul 2023 20:25:53 +0000 Subject: [PATCH 11/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../classification/average_precision.py | 4 +++- .../classification/precision_recall_curve.py | 10 ++++------ .../functional/classification/roc.py | 18 ++++++++++++++++-- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index 9645c7b1097..dcc5bfde5f7 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -460,7 +460,9 @@ def _average_precision_compute( weights = weights / torch.sum(weights) else: weights = None - return _average_precision_compute_with_precision_recall(precision, recall, num_classes=num_classes, average=average, weights=weights) + return _average_precision_compute_with_precision_recall( + precision, recall, num_classes=num_classes, average=average, weights=weights + ) def _average_precision_compute_with_precision_recall( diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index 8683ee4dbfe..aa63f03b950 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -865,11 +865,7 @@ def _precision_recall_curve_compute_multi_class( num_classes=1, ) if target.ndim > 1: - prc_args.update( - dict( - target=target[:, cls] - ) - ) + prc_args.update(dict(target=target[:, cls])) res = precision_recall_curve(**prc_args) precision.append(res[0]) recall.append(res[1]) @@ -922,7 +918,9 @@ def _precision_recall_curve_compute( if num_classes == 1: if pos_label is None: pos_label = 1 - return _precision_recall_curve_compute_single_class(preds, target, pos_label=pos_label, sample_weights=sample_weights) + return _precision_recall_curve_compute_single_class( + preds, target, pos_label=pos_label, sample_weights=sample_weights + ) return _precision_recall_curve_compute_multi_class(preds, target, num_classes=num_classes) diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py index 9d01b2babfd..5d990e91cc8 100644 --- a/src/torchmetrics/functional/classification/roc.py +++ b/src/torchmetrics/functional/classification/roc.py @@ -621,10 +621,24 @@ def roc( return binary_roc(preds, target, thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args) if task == "multiclass": assert isinstance(num_classes, int) - return multiclass_roc(preds, target, num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args) + return multiclass_roc( + preds, + target, + num_classes=num_classes, + thresholds=thresholds, + ignore_index=ignore_index, + validate_args=validate_args, + ) if task == "multilabel": assert isinstance(num_labels, int) - return multilabel_roc(preds, target, num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args) + return multilabel_roc( + preds, + target, + num_labels=num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + validate_args=validate_args, + ) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" )