diff --git a/ParallelCriterion.lua b/ParallelCriterion.lua index bee1f9ca9..95bd6ccc9 100644 --- a/ParallelCriterion.lua +++ b/ParallelCriterion.lua @@ -17,31 +17,19 @@ end function ParallelCriterion:updateOutput(input, target) self.output = 0 - if not self.repeatTarget then - for i,criterion in ipairs(self.criterions) do - self.output = self.output + self.weights[i]*criterion:updateOutput(input[i],target[i]) - end - else - for i,criterion in ipairs(self.criterions) do - self.output = self.output + self.weights[i]*criterion:updateOutput(input[i],target) - end + for i,criterion in ipairs(self.criterions) do + local target = self.repeatTarget and target or target[i] + self.output = self.output + self.weights[i]*criterion:updateOutput(input[i],target) end return self.output end function ParallelCriterion:updateGradInput(input, target) - if not self.repeatTarget then - for i,criterion in ipairs(self.criterions) do - self.gradInput[i] = input[i].new() or self.gradInput[i] - self.gradInput[i]:resizeAs(input[i]):zero() - self.gradInput[i]:add(self.weights[i], criterion:updateGradInput(input[i],target[i])) - end - else - for i,criterion in ipairs(self.criterions) do - self.gradInput[i] = input[i].new() or self.gradInput[i] - self.gradInput[i]:resizeAs(input[i]):zero() - self.gradInput[i]:add(self.weights[i], criterion:updateGradInput(input[i],target)) - end + for i,criterion in ipairs(self.criterions) do + local target = self.repeatTarget and target or target[i] + self.gradInput[i] = self.gradInput[i] or input[i].new() + self.gradInput[i]:resizeAs(input[i]):zero() + self.gradInput[i]:add(self.weights[i], criterion:updateGradInput(input[i],target)) end return self.gradInput end