From a406b7cbfb97d7969f73edab138b8a35f1d0c3f4 Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Wed, 1 Oct 2025 01:55:17 +0530 Subject: [PATCH 01/14] add zero_division --- src/torchmetrics/segmentation/dice.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index 2723a2c3ebd..cf61c764a32 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -86,6 +86,8 @@ class DiceScore(Metric): If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"`` or ``None`` ValueError: If ``input_format`` is not one of ``"one-hot"``, ``"index"`` or ``"mixed"`` + ValueError: + If ``zero_division`` is not one of 0.0, 1.0, "warn", or "nan" Example: >>> from torch import randint @@ -99,6 +101,15 @@ class DiceScore(Metric): >>> dice_score(preds, target) tensor([0.4860, 0.4999, 0.5014, 0.4885, 0.4915]) + Example with zero_division: + >>> from torch import zeros + >>> from torchmetrics.segmentation import DiceScore + >>> preds = zeros(2, 3, 16, 16) # Empty predictions + >>> target = zeros(2, 3, 16, 16) # Empty targets + >>> dice_score = DiceScore(num_classes=3, zero_division=1.0) + >>> dice_score(preds, target) + tensor(1.0000) + """ full_state_update: bool = False @@ -118,6 +129,7 @@ def __init__( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", aggregation_level: Optional[Literal["samplewise", "global"]] = "samplewise", input_format: Literal["one-hot", "index", "mixed"] = "one-hot", + zero_division: Union[float, Literal["warn", "nan"]] = "nan", **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -128,12 +140,13 @@ def __init__( " If you've explicitly set this parameter, you can ignore this warning.", UserWarning, ) - _dice_score_validate_args(num_classes, include_background, average, input_format, aggregation_level) + _dice_score_validate_args(num_classes, include_background, average, input_format, aggregation_level, zero_division) self.num_classes = num_classes self.include_background = include_background self.average = average self.aggregation_level = aggregation_level self.input_format = input_format + self.zero_division = zero_division num_classes = num_classes - 1 if not include_background else num_classes self.add_state("numerator", [], dist_reduce_fx="cat") @@ -157,6 +170,7 @@ def compute(self) -> Tensor: self.average, self.aggregation_level, support=dim_zero_cat(self.support) if self.average == "weighted" else None, + zero_division=self.zero_division, ).nanmean(dim=0) def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: From d3b3f0d98a4e5e6eb43a565630e42cbe318f4002 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Sep 2025 20:29:19 +0000 Subject: [PATCH 02/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/segmentation/dice.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index cf61c764a32..398e9b15e61 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -140,7 +140,9 @@ def __init__( " If you've explicitly set this parameter, you can ignore this warning.", UserWarning, ) - _dice_score_validate_args(num_classes, include_background, average, input_format, aggregation_level, zero_division) + _dice_score_validate_args( + num_classes, include_background, average, input_format, aggregation_level, zero_division + ) self.num_classes = num_classes self.include_background = include_background self.average = average From 6861fa799995d0af494e8e0f773551aed98c9d99 Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Wed, 1 Oct 2025 02:00:00 +0530 Subject: [PATCH 03/14] Update functional interface --- .../functional/segmentation/dice.py | 43 ++++++++++++++++--- 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index d402ca67246..e847c73439f 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Union import torch from torch import Tensor @@ -28,6 +28,7 @@ def _dice_score_validate_args( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", input_format: Literal["one-hot", "index", "mixed"] = "one-hot", aggregation_level: Optional[Literal["samplewise", "global"]] = "samplewise", + zero_division: Union[float, Literal["warn", "nan"]] = "nan", ) -> None: """Validate the arguments of the metric.""" if not isinstance(num_classes, int) or num_classes <= 0: @@ -45,6 +46,10 @@ def _dice_score_validate_args( raise ValueError( f"Expected argument `aggregation_level` to be one of `samplewise`, `global`, but got {aggregation_level}" ) + if zero_division not in (0.0, 1.0, "warn", "nan"): + raise ValueError( + f"Expected argument `zero_division` to be one of 0.0, 1.0, 'warn', or 'nan', but got {zero_division}." + ) def _dice_score_update( @@ -74,6 +79,7 @@ def _dice_score_compute( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", aggregation_level: Optional[Literal["samplewise", "global"]] = "samplewise", support: Optional[Tensor] = None, + zero_division: Union[float, Literal["warn", "nan"]] = "nan", ) -> Tensor: """Compute the Dice score from the numerator and denominator.""" if aggregation_level == "global": @@ -81,18 +87,26 @@ def _dice_score_compute( denominator = torch.sum(denominator, dim=0).unsqueeze(0) support = torch.sum(support, dim=0) if support is not None else None + # Determine the zero_division value to use + if zero_division == "warn": + zero_div_value = "warn" + elif zero_division == "nan": + zero_div_value = "nan" + else: + zero_div_value = float(zero_division) + if average == "micro": numerator = torch.sum(numerator, dim=-1) denominator = torch.sum(denominator, dim=-1) - return _safe_divide(numerator, denominator, zero_division="nan") + return _safe_divide(numerator, denominator, zero_division=zero_div_value) - dice = _safe_divide(numerator, denominator, zero_division="nan") + dice = _safe_divide(numerator, denominator, zero_division=zero_div_value) if average == "macro": return torch.nanmean(dice, dim=-1) if average == "weighted": if not isinstance(support, torch.Tensor): raise ValueError(f"Expected argument `support` to be a tensor, got: {type(support)}.") - weights = _safe_divide(support, torch.sum(support, dim=-1, keepdim=True), zero_division="nan") + weights = _safe_divide(support, torch.sum(support, dim=-1, keepdim=True), zero_division=zero_div_value) nan_mask = dice.isnan().all(dim=-1) dice = torch.nansum(dice * weights, dim=-1) dice[nan_mask] = torch.nan @@ -110,6 +124,7 @@ def dice_score( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", input_format: Literal["one-hot", "index", "mixed"] = "one-hot", aggregation_level: Optional[Literal["samplewise", "global"]] = "samplewise", + zero_division: Union[float, Literal["warn", "nan"]] = "nan", ) -> Tensor: """Compute the Dice score for semantic segmentation. @@ -126,6 +141,8 @@ def dice_score( aggregation_level: The level at which to aggregate the dice score. Options are ``"samplewise"`` or ``"global"``. For ``"samplewise"`` the dice score is computed for each sample and then averaged. For ``"global"`` the dice score is computed globally over all samples. + zero_division: The value to return when there is a division by zero. Options are 1.0, 0.0, "warn" or "nan". + Setting it to "warn" behaves like 0.0 but will also create a warning. Returns: The Dice score. @@ -168,6 +185,18 @@ def dice_score( >>> dice_score(preds, target, num_classes=5, average="macro", aggregation_level="global", input_format="index") tensor([0.1965]) + Example (with zero_division parameter): + >>> from torch import randint, zeros + >>> from torchmetrics.functional.segmentation import dice_score + >>> preds = zeros(2, 3, 16, 16) # Empty predictions + >>> target = zeros(2, 3, 16, 16) # Empty targets + >>> # Using zero_division=1.0 + >>> dice_score(preds, target, num_classes=3, zero_division=1.0) + tensor([1., 1.]) + >>> # Using zero_division=0.0 + >>> dice_score(preds, target, num_classes=3, zero_division=0.0) + tensor([0., 0.]) + """ if average == "micro": rank_zero_warn( @@ -176,6 +205,8 @@ def dice_score( " If you've explicitly set this parameter, you can ignore this warning.", UserWarning, ) - _dice_score_validate_args(num_classes, include_background, average, input_format, aggregation_level) + _dice_score_validate_args(num_classes, include_background, average, input_format, aggregation_level, zero_division) numerator, denominator, support = _dice_score_update(preds, target, num_classes, include_background, input_format) - return _dice_score_compute(numerator, denominator, average, aggregation_level=aggregation_level, support=support) + return _dice_score_compute( + numerator, denominator, average, aggregation_level=aggregation_level, support=support, zero_division=zero_division + ) From 9adc17b25218e853d215ea7771c036c31110e037 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Sep 2025 20:30:25 +0000 Subject: [PATCH 04/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/segmentation/dice.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index e847c73439f..030ad8ee789 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -208,5 +208,10 @@ def dice_score( _dice_score_validate_args(num_classes, include_background, average, input_format, aggregation_level, zero_division) numerator, denominator, support = _dice_score_update(preds, target, num_classes, include_background, input_format) return _dice_score_compute( - numerator, denominator, average, aggregation_level=aggregation_level, support=support, zero_division=zero_division + numerator, + denominator, + average, + aggregation_level=aggregation_level, + support=support, + zero_division=zero_division, ) From f0086002273812bec782df78762302934f7e139a Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Wed, 1 Oct 2025 02:48:46 +0530 Subject: [PATCH 05/14] Update dice.py --- src/torchmetrics/functional/segmentation/dice.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index 030ad8ee789..81170704bf9 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -191,10 +191,10 @@ def dice_score( >>> preds = zeros(2, 3, 16, 16) # Empty predictions >>> target = zeros(2, 3, 16, 16) # Empty targets >>> # Using zero_division=1.0 - >>> dice_score(preds, target, num_classes=3, zero_division=1.0) + >>> dice_score(preds, target, num_classes=3, zero_division=1.0, average="micro") tensor([1., 1.]) - >>> # Using zero_division=0.0 - >>> dice_score(preds, target, num_classes=3, zero_division=0.0) + >>> # Using zero_division=0.0 + >>> dice_score(preds, target, num_classes=3, zero_division=0.0, average="micro") tensor([0., 0.]) """ From be245b70c11fe11214ab84ec154a488afc54e05d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Sep 2025 21:19:07 +0000 Subject: [PATCH 06/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/segmentation/dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index 81170704bf9..2e834b13173 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -193,7 +193,7 @@ def dice_score( >>> # Using zero_division=1.0 >>> dice_score(preds, target, num_classes=3, zero_division=1.0, average="micro") tensor([1., 1.]) - >>> # Using zero_division=0.0 + >>> # Using zero_division=0.0 >>> dice_score(preds, target, num_classes=3, zero_division=0.0, average="micro") tensor([0., 0.]) From 96d935c114662b4811432bcb95a3d4a5abb322ec Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Wed, 1 Oct 2025 02:49:30 +0530 Subject: [PATCH 07/14] Update dice.py --- src/torchmetrics/segmentation/dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index 398e9b15e61..c01b2cf5462 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -108,7 +108,7 @@ class DiceScore(Metric): >>> target = zeros(2, 3, 16, 16) # Empty targets >>> dice_score = DiceScore(num_classes=3, zero_division=1.0) >>> dice_score(preds, target) - tensor(1.0000) + tensor(1.0) """ From 4c161146ae86b7cef5840cb93adefe285cd3b427 Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Wed, 1 Oct 2025 02:57:13 +0530 Subject: [PATCH 08/14] Update dice.py --- src/torchmetrics/segmentation/dice.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index c01b2cf5462..fcff807b208 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -101,15 +101,6 @@ class DiceScore(Metric): >>> dice_score(preds, target) tensor([0.4860, 0.4999, 0.5014, 0.4885, 0.4915]) - Example with zero_division: - >>> from torch import zeros - >>> from torchmetrics.segmentation import DiceScore - >>> preds = zeros(2, 3, 16, 16) # Empty predictions - >>> target = zeros(2, 3, 16, 16) # Empty targets - >>> dice_score = DiceScore(num_classes=3, zero_division=1.0) - >>> dice_score(preds, target) - tensor(1.0) - """ full_state_update: bool = False From ceed33af79179f50668753a4f38d7e9da4d6c306 Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Wed, 1 Oct 2025 02:57:46 +0530 Subject: [PATCH 09/14] Update dice.py --- src/torchmetrics/functional/segmentation/dice.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index 2e834b13173..98d8ebe9f86 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -185,18 +185,6 @@ def dice_score( >>> dice_score(preds, target, num_classes=5, average="macro", aggregation_level="global", input_format="index") tensor([0.1965]) - Example (with zero_division parameter): - >>> from torch import randint, zeros - >>> from torchmetrics.functional.segmentation import dice_score - >>> preds = zeros(2, 3, 16, 16) # Empty predictions - >>> target = zeros(2, 3, 16, 16) # Empty targets - >>> # Using zero_division=1.0 - >>> dice_score(preds, target, num_classes=3, zero_division=1.0, average="micro") - tensor([1., 1.]) - >>> # Using zero_division=0.0 - >>> dice_score(preds, target, num_classes=3, zero_division=0.0, average="micro") - tensor([0., 0.]) - """ if average == "micro": rank_zero_warn( From 2270794e3a7684951a2c4f48e2a607e094f3d6bd Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Sat, 4 Oct 2025 03:39:49 +0530 Subject: [PATCH 10/14] Update dice.py --- src/torchmetrics/segmentation/dice.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index fcff807b208..66819994976 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -73,8 +73,6 @@ class DiceScore(Metric): input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors, ``"index"`` for index tensors or ``"mixed"`` for one one-hot encoded and one index tensor - zero_division: The value to return when there is a division by zero. Options are 1.0, 0.0, "warn" or "nan". - Setting it to "warn" behaves like 0.0 but will also create a warning. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Raises: From 5fde24a1a9f21df3e994c28b7f94076d2224cd86 Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Sat, 4 Oct 2025 03:41:53 +0530 Subject: [PATCH 11/14] Update dice.py --- src/torchmetrics/segmentation/dice.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index 66819994976..3a70453b480 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -84,8 +84,6 @@ class DiceScore(Metric): If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"`` or ``None`` ValueError: If ``input_format`` is not one of ``"one-hot"``, ``"index"`` or ``"mixed"`` - ValueError: - If ``zero_division`` is not one of 0.0, 1.0, "warn", or "nan" Example: >>> from torch import randint @@ -118,7 +116,6 @@ def __init__( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", aggregation_level: Optional[Literal["samplewise", "global"]] = "samplewise", input_format: Literal["one-hot", "index", "mixed"] = "one-hot", - zero_division: Union[float, Literal["warn", "nan"]] = "nan", **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -130,14 +127,13 @@ def __init__( UserWarning, ) _dice_score_validate_args( - num_classes, include_background, average, input_format, aggregation_level, zero_division + num_classes, include_background, average, input_format, aggregation_level ) self.num_classes = num_classes self.include_background = include_background self.average = average self.aggregation_level = aggregation_level self.input_format = input_format - self.zero_division = zero_division num_classes = num_classes - 1 if not include_background else num_classes self.add_state("numerator", [], dist_reduce_fx="cat") @@ -161,7 +157,6 @@ def compute(self) -> Tensor: self.average, self.aggregation_level, support=dim_zero_cat(self.support) if self.average == "weighted" else None, - zero_division=self.zero_division, ).nanmean(dim=0) def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: From 7319f4921124045ddd2af7af8febd9de05ddb767 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Oct 2025 22:12:13 +0000 Subject: [PATCH 12/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/segmentation/dice.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index 3a70453b480..148dc2e15c2 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -126,9 +126,7 @@ def __init__( " If you've explicitly set this parameter, you can ignore this warning.", UserWarning, ) - _dice_score_validate_args( - num_classes, include_background, average, input_format, aggregation_level - ) + _dice_score_validate_args(num_classes, include_background, average, input_format, aggregation_level) self.num_classes = num_classes self.include_background = include_background self.average = average From f613926473cd4a7043205ab6b0a5ac8e67072d0e Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Sat, 4 Oct 2025 03:44:16 +0530 Subject: [PATCH 13/14] Update dice.py --- .../functional/segmentation/dice.py | 36 ++++--------------- 1 file changed, 6 insertions(+), 30 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index 98d8ebe9f86..d402ca67246 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import Optional import torch from torch import Tensor @@ -28,7 +28,6 @@ def _dice_score_validate_args( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", input_format: Literal["one-hot", "index", "mixed"] = "one-hot", aggregation_level: Optional[Literal["samplewise", "global"]] = "samplewise", - zero_division: Union[float, Literal["warn", "nan"]] = "nan", ) -> None: """Validate the arguments of the metric.""" if not isinstance(num_classes, int) or num_classes <= 0: @@ -46,10 +45,6 @@ def _dice_score_validate_args( raise ValueError( f"Expected argument `aggregation_level` to be one of `samplewise`, `global`, but got {aggregation_level}" ) - if zero_division not in (0.0, 1.0, "warn", "nan"): - raise ValueError( - f"Expected argument `zero_division` to be one of 0.0, 1.0, 'warn', or 'nan', but got {zero_division}." - ) def _dice_score_update( @@ -79,7 +74,6 @@ def _dice_score_compute( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", aggregation_level: Optional[Literal["samplewise", "global"]] = "samplewise", support: Optional[Tensor] = None, - zero_division: Union[float, Literal["warn", "nan"]] = "nan", ) -> Tensor: """Compute the Dice score from the numerator and denominator.""" if aggregation_level == "global": @@ -87,26 +81,18 @@ def _dice_score_compute( denominator = torch.sum(denominator, dim=0).unsqueeze(0) support = torch.sum(support, dim=0) if support is not None else None - # Determine the zero_division value to use - if zero_division == "warn": - zero_div_value = "warn" - elif zero_division == "nan": - zero_div_value = "nan" - else: - zero_div_value = float(zero_division) - if average == "micro": numerator = torch.sum(numerator, dim=-1) denominator = torch.sum(denominator, dim=-1) - return _safe_divide(numerator, denominator, zero_division=zero_div_value) + return _safe_divide(numerator, denominator, zero_division="nan") - dice = _safe_divide(numerator, denominator, zero_division=zero_div_value) + dice = _safe_divide(numerator, denominator, zero_division="nan") if average == "macro": return torch.nanmean(dice, dim=-1) if average == "weighted": if not isinstance(support, torch.Tensor): raise ValueError(f"Expected argument `support` to be a tensor, got: {type(support)}.") - weights = _safe_divide(support, torch.sum(support, dim=-1, keepdim=True), zero_division=zero_div_value) + weights = _safe_divide(support, torch.sum(support, dim=-1, keepdim=True), zero_division="nan") nan_mask = dice.isnan().all(dim=-1) dice = torch.nansum(dice * weights, dim=-1) dice[nan_mask] = torch.nan @@ -124,7 +110,6 @@ def dice_score( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", input_format: Literal["one-hot", "index", "mixed"] = "one-hot", aggregation_level: Optional[Literal["samplewise", "global"]] = "samplewise", - zero_division: Union[float, Literal["warn", "nan"]] = "nan", ) -> Tensor: """Compute the Dice score for semantic segmentation. @@ -141,8 +126,6 @@ def dice_score( aggregation_level: The level at which to aggregate the dice score. Options are ``"samplewise"`` or ``"global"``. For ``"samplewise"`` the dice score is computed for each sample and then averaged. For ``"global"`` the dice score is computed globally over all samples. - zero_division: The value to return when there is a division by zero. Options are 1.0, 0.0, "warn" or "nan". - Setting it to "warn" behaves like 0.0 but will also create a warning. Returns: The Dice score. @@ -193,13 +176,6 @@ def dice_score( " If you've explicitly set this parameter, you can ignore this warning.", UserWarning, ) - _dice_score_validate_args(num_classes, include_background, average, input_format, aggregation_level, zero_division) + _dice_score_validate_args(num_classes, include_background, average, input_format, aggregation_level) numerator, denominator, support = _dice_score_update(preds, target, num_classes, include_background, input_format) - return _dice_score_compute( - numerator, - denominator, - average, - aggregation_level=aggregation_level, - support=support, - zero_division=zero_division, - ) + return _dice_score_compute(numerator, denominator, average, aggregation_level=aggregation_level, support=support) From 451bd606697d0e1de551e4b68b185fc48806986b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 5 Oct 2025 16:54:03 +0000 Subject: [PATCH 14/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- README.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index fd00bcaa6bd..c77320656fd 100644 --- a/README.md +++ b/README.md @@ -39,13 +39,15 @@ ______________________________________________________________________ # Looking for GPUs? -Over 340,000 developers use [Lightning Cloud](https://lightning.ai/?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) - purpose-built for PyTorch and PyTorch Lightning. -- [GPUs](https://lightning.ai/pricing?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) from $0.19. -- [Clusters](https://lightning.ai/clusters?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): frontier-grade training/inference clusters. + +Over 340,000 developers use [Lightning Cloud](https://lightning.ai/?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) - purpose-built for PyTorch and PyTorch Lightning. + +- [GPUs](https://lightning.ai/pricing?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) from $0.19. +- [Clusters](https://lightning.ai/clusters?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): frontier-grade training/inference clusters. - [AI Studio (vibe train)](https://lightning.ai/studios?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): workspaces where AI helps you debug, tune and vibe train. -- [AI Studio (vibe deploy)](https://lightning.ai/studios?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): workspaces where AI helps you optimize, and deploy models. +- [AI Studio (vibe deploy)](https://lightning.ai/studios?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): workspaces where AI helps you optimize, and deploy models. - [Notebooks](https://lightning.ai/notebooks?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): Persistent GPU workspaces where AI helps you code and analyze. -- [Inference](https://lightning.ai/deploy?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): Deploy models as inference APIs. +- [Inference](https://lightning.ai/deploy?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): Deploy models as inference APIs. # Installation