diff --git a/CHANGELOG.md b/CHANGELOG.md index b23ce58d355..ef9681f1e79 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed support for pixelwise MSE ([#2017](https://github.com/Lightning-AI/torchmetrics/pull/2017) +- Fixed missing attributes `higher_is_better`, `is_differentiable` for some metrics ([#2028](https://github.com/Lightning-AI/torchmetrics/pull/2028) + ## [1.1.0] - 2023-08-22 ### Added diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index 8e7e6eb3455..8deb4fe5544 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -92,8 +92,8 @@ class BinaryAccuracy(BinaryStatScores): tensor([0.3333, 0.1667]) """ - is_differentiable = False - higher_is_better = True + is_differentiable: bool = False + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 @@ -240,8 +240,8 @@ class MulticlassAccuracy(MulticlassStatScores): [0.0000, 0.3333, 0.5000]]) """ - is_differentiable = False - higher_is_better = True + is_differentiable: bool = False + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 @@ -389,8 +389,8 @@ class MultilabelAccuracy(MultilabelStatScores): [0.0000, 0.0000, 0.5000]]) """ - is_differentiable = False - higher_is_better = True + is_differentiable: bool = False + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index f42c4b942f9..fb3465f4010 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -99,7 +99,7 @@ class BinaryAUROC(BinaryPrecisionRecallCurve): """ is_differentiable: bool = False - higher_is_better: Optional[bool] = None + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 @@ -241,7 +241,7 @@ class MulticlassAUROC(MulticlassPrecisionRecallCurve): """ is_differentiable: bool = False - higher_is_better: Optional[bool] = None + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 @@ -390,7 +390,7 @@ class MultilabelAUROC(MultilabelPrecisionRecallCurve): """ is_differentiable: bool = False - higher_is_better: Optional[bool] = None + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index 8ade1884d73..2d29c30e05b 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -106,7 +106,7 @@ class BinaryAveragePrecision(BinaryPrecisionRecallCurve): """ is_differentiable: bool = False - higher_is_better: Optional[bool] = None + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 @@ -239,7 +239,7 @@ class MulticlassAveragePrecision(MulticlassPrecisionRecallCurve): """ is_differentiable: bool = False - higher_is_better: Optional[bool] = None + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 @@ -393,7 +393,7 @@ class MultilabelAveragePrecision(MultilabelPrecisionRecallCurve): """ is_differentiable: bool = False - higher_is_better: Optional[bool] = None + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py index 075c3430524..3d9bf724ccc 100644 --- a/src/torchmetrics/classification/exact_match.py +++ b/src/torchmetrics/classification/exact_match.py @@ -94,8 +94,8 @@ class MulticlassExactMatch(Metric): tensor([1., 0.]) """ - is_differentiable = False - higher_is_better = True + is_differentiable: bool = False + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 @@ -257,8 +257,8 @@ class MultilabelExactMatch(Metric): """ - is_differentiable = False - higher_is_better = True + is_differentiable: bool = False + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 diff --git a/src/torchmetrics/classification/group_fairness.py b/src/torchmetrics/classification/group_fairness.py index 33ea8950fb7..9e46dd36891 100644 --- a/src/torchmetrics/classification/group_fairness.py +++ b/src/torchmetrics/classification/group_fairness.py @@ -101,8 +101,8 @@ class BinaryGroupStatRates(_AbstractGroupStatScores): {'group_0': tensor([0., 0., 1., 0.]), 'group_1': tensor([1., 0., 0., 0.])} """ - is_differentiable = False - higher_is_better = False + is_differentiable: bool = False + higher_is_better: bool = False full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 @@ -207,8 +207,8 @@ class BinaryFairness(_AbstractGroupStatScores): {'DP_0_1': tensor(0.), 'EO_0_1': tensor(0.)} """ - is_differentiable = False - higher_is_better = False + is_differentiable: bool = False + higher_is_better: bool = False full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index 86118daf41c..504d12c2718 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -62,8 +62,8 @@ class MutualInfoScore(Metric): """ - is_differentiable = True - higher_is_better = None + is_differentiable: bool = True + higher_is_better: bool = True full_state_update: bool = True plot_lower_bound: float = 0.0 preds: List[Tensor] diff --git a/src/torchmetrics/detection/ciou.py b/src/torchmetrics/detection/ciou.py index 0adc57af4f6..5b62679a396 100644 --- a/src/torchmetrics/detection/ciou.py +++ b/src/torchmetrics/detection/ciou.py @@ -93,6 +93,10 @@ class CompleteIntersectionOverUnion(IntersectionOverUnion): If torchvision is not installed with version 0.13.0 or newer. """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = True + _iou_type: str = "ciou" _invalid_val: float = -2.0 # unsure, min val could be just -1.5 as well diff --git a/src/torchmetrics/detection/diou.py b/src/torchmetrics/detection/diou.py index 3508d80fee1..6778979b1c0 100644 --- a/src/torchmetrics/detection/diou.py +++ b/src/torchmetrics/detection/diou.py @@ -93,6 +93,10 @@ class DistanceIntersectionOverUnion(IntersectionOverUnion): If torchvision is not installed with version 0.13.0 or newer. """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = True + _iou_type: str = "diou" _invalid_val: float = -1.0 diff --git a/src/torchmetrics/detection/giou.py b/src/torchmetrics/detection/giou.py index d53d3e88777..e4ec9aee65c 100644 --- a/src/torchmetrics/detection/giou.py +++ b/src/torchmetrics/detection/giou.py @@ -93,6 +93,10 @@ class GeneralizedIntersectionOverUnion(IntersectionOverUnion): If torchvision is not installed with version 0.8.0 or newer. """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = True + _iou_type: str = "giou" _invalid_val: float = -1.0 diff --git a/src/torchmetrics/image/perceptual_path_length.py b/src/torchmetrics/image/perceptual_path_length.py index 829a5249bef..bb7561f7c4b 100644 --- a/src/torchmetrics/image/perceptual_path_length.py +++ b/src/torchmetrics/image/perceptual_path_length.py @@ -118,6 +118,9 @@ class PerceptualPathLength(Metric): tensor([0.3502, 0.1362, 0.2535, 0.0902, 0.1784, 0.0769, 0.5871, 0.0691, 0.3921])) """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = True def __init__( self, diff --git a/src/torchmetrics/regression/concordance.py b/src/torchmetrics/regression/concordance.py index d45a06cb943..2a52e8a8036 100644 --- a/src/torchmetrics/regression/concordance.py +++ b/src/torchmetrics/regression/concordance.py @@ -67,6 +67,10 @@ class ConcordanceCorrCoef(PearsonCorrCoef): tensor([0.7273, 0.9887]) """ + is_differentiable: bool = True + higher_is_better: bool = True + full_state_update: bool = True + plot_lower_bound: float = -1.0 plot_upper_bound: float = 1.0 diff --git a/src/torchmetrics/regression/pearson.py b/src/torchmetrics/regression/pearson.py index 0fc575f7449..a25fa72ff7e 100644 --- a/src/torchmetrics/regression/pearson.py +++ b/src/torchmetrics/regression/pearson.py @@ -110,8 +110,8 @@ class PearsonCorrCoef(Metric): tensor([1., 1.]) """ - is_differentiable = True - higher_is_better = None # both -1 and 1 are optimal + is_differentiable: bool = True + higher_is_better: Optional[bool] = None # both -1 and 1 are optimal full_state_update: bool = True plot_lower_bound: float = -1.0 plot_upper_bound: float = 1.0