diff --git a/SpatialSubtractiveNormalization.lua b/SpatialSubtractiveNormalization.lua index e2da2c6a2..7fa440267 100644 --- a/SpatialSubtractiveNormalization.lua +++ b/SpatialSubtractiveNormalization.lua @@ -63,7 +63,8 @@ end function SpatialSubtractiveNormalization:updateOutput(input) -- compute side coefficients local dim = input:dim() - if input:dim()+1 ~= self.coef:dim() or (input:size(dim) ~= self.coef:size(dim)) or (input:size(dim-1) ~= self.coef:size(dim-1)) then + if not self._inpsz or not input:isSize(self._inpsz) then + self._inpsz = input:size() self.ones = self.ones or input.new() self._coef = self._coef or self.coef.new() if dim == 4 then