From ec12e408b524643852c04b536bb0b94f6a989570 Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Wed, 20 Nov 2024 11:02:25 +0100 Subject: [PATCH] removed redundant line in dice loss --- src/eva/vision/losses/dice.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/eva/vision/losses/dice.py b/src/eva/vision/losses/dice.py index 8e6133b3..d5d31d17 100644 --- a/src/eva/vision/losses/dice.py +++ b/src/eva/vision/losses/dice.py @@ -45,9 +45,6 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index) targets = _to_one_hot(targets, num_classes=inputs.shape[1]) - if targets.ndim == 3: - targets = one_hot(targets[:, None, ...], num_classes=inputs.shape[1]) - return super().forward(inputs, targets)