diff --git a/utils/criterion.py b/utils/criterion.py index 37bd7036..0081c9c9 100644 --- a/utils/criterion.py +++ b/utils/criterion.py @@ -70,7 +70,10 @@ def _ohem_forward(self, score, target, **kwargs): tmp_target[tmp_target == self.ignore_label] = 0 pred = pred.gather(1, tmp_target.unsqueeze(1)) pred, ind = pred.contiguous().view(-1,)[mask].contiguous().sort() - min_value = pred[min(self.min_kept, pred.numel() - 1)] + if pred.numel() > 0: + min_value = pred[min(self.min_kept, pred.numel() - 1)] + else: + return score.new_tensor(0.0) threshold = max(min_value, self.thresh) pixel_losses = pixel_losses[mask][ind]