Skip to content

Commit

Permalink
Fix missing attributes in some metric (#2028)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Aug 28, 2023
1 parent 71177fd commit cc6f6cc
Show file tree
Hide file tree
Showing 13 changed files with 45 additions and 24 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed support for pixelwise MSE ([#2017](https://github.com/Lightning-AI/torchmetrics/pull/2017)


- Fixed missing attributes `higher_is_better`, `is_differentiable` for some metrics ([#2028](https://github.com/Lightning-AI/torchmetrics/pull/2028)

## [1.1.0] - 2023-08-22

### Added
Expand Down
12 changes: 6 additions & 6 deletions src/torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ class BinaryAccuracy(BinaryStatScores):
tensor([0.3333, 0.1667])
"""
is_differentiable = False
higher_is_better = True
is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
Expand Down Expand Up @@ -240,8 +240,8 @@ class MulticlassAccuracy(MulticlassStatScores):
[0.0000, 0.3333, 0.5000]])
"""
is_differentiable = False
higher_is_better = True
is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
Expand Down Expand Up @@ -389,8 +389,8 @@ class MultilabelAccuracy(MultilabelStatScores):
[0.0000, 0.0000, 0.5000]])
"""
is_differentiable = False
higher_is_better = True
is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/classification/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class BinaryAUROC(BinaryPrecisionRecallCurve):
"""
is_differentiable: bool = False
higher_is_better: Optional[bool] = None
higher_is_better: bool = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
Expand Down Expand Up @@ -241,7 +241,7 @@ class MulticlassAUROC(MulticlassPrecisionRecallCurve):
"""

is_differentiable: bool = False
higher_is_better: Optional[bool] = None
higher_is_better: bool = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
Expand Down Expand Up @@ -390,7 +390,7 @@ class MultilabelAUROC(MultilabelPrecisionRecallCurve):
"""
is_differentiable: bool = False
higher_is_better: Optional[bool] = None
higher_is_better: bool = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/classification/average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class BinaryAveragePrecision(BinaryPrecisionRecallCurve):
"""
is_differentiable: bool = False
higher_is_better: Optional[bool] = None
higher_is_better: bool = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
Expand Down Expand Up @@ -239,7 +239,7 @@ class MulticlassAveragePrecision(MulticlassPrecisionRecallCurve):
"""

is_differentiable: bool = False
higher_is_better: Optional[bool] = None
higher_is_better: bool = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
Expand Down Expand Up @@ -393,7 +393,7 @@ class MultilabelAveragePrecision(MultilabelPrecisionRecallCurve):
"""
is_differentiable: bool = False
higher_is_better: Optional[bool] = None
higher_is_better: bool = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
Expand Down
8 changes: 4 additions & 4 deletions src/torchmetrics/classification/exact_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ class MulticlassExactMatch(Metric):
tensor([1., 0.])
"""
is_differentiable = False
higher_is_better = True
is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
Expand Down Expand Up @@ -257,8 +257,8 @@ class MultilabelExactMatch(Metric):
"""

is_differentiable = False
higher_is_better = True
is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
Expand Down
8 changes: 4 additions & 4 deletions src/torchmetrics/classification/group_fairness.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ class BinaryGroupStatRates(_AbstractGroupStatScores):
{'group_0': tensor([0., 0., 1., 0.]), 'group_1': tensor([1., 0., 0., 0.])}
"""
is_differentiable = False
higher_is_better = False
is_differentiable: bool = False
higher_is_better: bool = False
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
Expand Down Expand Up @@ -207,8 +207,8 @@ class BinaryFairness(_AbstractGroupStatScores):
{'DP_0_1': tensor(0.), 'EO_0_1': tensor(0.)}
"""
is_differentiable = False
higher_is_better = False
is_differentiable: bool = False
higher_is_better: bool = False
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/clustering/mutual_info_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ class MutualInfoScore(Metric):
"""

is_differentiable = True
higher_is_better = None
is_differentiable: bool = True
higher_is_better: bool = True
full_state_update: bool = True
plot_lower_bound: float = 0.0
preds: List[Tensor]
Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/detection/ciou.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ class CompleteIntersectionOverUnion(IntersectionOverUnion):
If torchvision is not installed with version 0.13.0 or newer.
"""
is_differentiable: bool = False
higher_is_better: Optional[bool] = True
full_state_update: bool = True

_iou_type: str = "ciou"
_invalid_val: float = -2.0 # unsure, min val could be just -1.5 as well

Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/detection/diou.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ class DistanceIntersectionOverUnion(IntersectionOverUnion):
If torchvision is not installed with version 0.13.0 or newer.
"""
is_differentiable: bool = False
higher_is_better: Optional[bool] = True
full_state_update: bool = True

_iou_type: str = "diou"
_invalid_val: float = -1.0

Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/detection/giou.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ class GeneralizedIntersectionOverUnion(IntersectionOverUnion):
If torchvision is not installed with version 0.8.0 or newer.
"""
is_differentiable: bool = False
higher_is_better: Optional[bool] = True
full_state_update: bool = True

_iou_type: str = "giou"
_invalid_val: float = -1.0

Expand Down
3 changes: 3 additions & 0 deletions src/torchmetrics/image/perceptual_path_length.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ class PerceptualPathLength(Metric):
tensor([0.3502, 0.1362, 0.2535, 0.0902, 0.1784, 0.0769, 0.5871, 0.0691, 0.3921]))
"""
is_differentiable: bool = False
higher_is_better: Optional[bool] = True
full_state_update: bool = True

def __init__(
self,
Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/regression/concordance.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ class ConcordanceCorrCoef(PearsonCorrCoef):
tensor([0.7273, 0.9887])
"""
is_differentiable: bool = True
higher_is_better: bool = True
full_state_update: bool = True

plot_lower_bound: float = -1.0
plot_upper_bound: float = 1.0

Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/regression/pearson.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ class PearsonCorrCoef(Metric):
tensor([1., 1.])
"""
is_differentiable = True
higher_is_better = None # both -1 and 1 are optimal
is_differentiable: bool = True
higher_is_better: Optional[bool] = None # both -1 and 1 are optimal
full_state_update: bool = True
plot_lower_bound: float = -1.0
plot_upper_bound: float = 1.0
Expand Down

0 comments on commit cc6f6cc

Please sign in to comment.