From fe4080eed56006ce4d183f0a47a03a061cd7f9cc Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Sun, 22 May 2022 21:13:59 +0000
Subject: [PATCH] Fixed issue with total_weight in nll_loss_forward_decomposition

---
 functorch/csrc/BatchRulesLoss.cpp | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/functorch/csrc/BatchRulesLoss.cpp b/functorch/csrc/BatchRulesLoss.cpp
index 043b5bb91..3a0d18756 100644
--- a/functorch/csrc/BatchRulesLoss.cpp
+++ b/functorch/csrc/BatchRulesLoss.cpp
@@ -203,16 +203,22 @@ std::tuple nll_loss_forward_decomposition(
   // target can be [N, 1, ...] or [1]
 
   auto result = -at::gather(self_, channel_dim, target_).squeeze(channel_dim);
-  auto total_weight = at::full(
-      {}, result.numel(), self_.scalar_type(),
-      self_.layout(), self_.device(), nullopt);
 
   bool has_ignore_index = ignore_index >= 0;
-  Tensor ignore_index_mask;
+  Tensor ignore_index_mask, total_weight;
   if (has_ignore_index) {
     ignore_index_mask = target != ignore_index;
     result = result * ignore_index_mask;
-    total_weight = ignore_index_mask.sum().to(self_);
+    if (!(reduction == Reduction::None && self.dim() >= 2)) {
+      total_weight = ignore_index_mask.sum().to(self_);
+    }
+  }
+
+  if (!total_weight.defined()) {
+    auto init_value = (reduction == Reduction::None && self.dim() >= 2) ? 0.0 : 1.0 * result.numel();
+    total_weight = at::full(
+        {}, init_value, self_.scalar_type(),
+        self_.layout(), self_.device(), nullopt);
   }
 
   // Apply the reduction