diff --git a/Criterion.lua b/Criterion.lua index 0f6e41b71..f6e0d82ab 100644 --- a/Criterion.lua +++ b/Criterion.lua @@ -29,9 +29,10 @@ function Criterion:clone() end function Criterion:type(type) + assert(type, 'Criterion: must provide a type to convert to') -- find all tensors and convert them for key,param in pairs(self) do - self[key] = nn._utils.recursiveType(param, type) + self[key] = nn.utils.recursiveType(param, type) end return self end diff --git a/CrossEntropyCriterion.lua b/CrossEntropyCriterion.lua index 2b3c78c28..d4d19e5cb 100644 --- a/CrossEntropyCriterion.lua +++ b/CrossEntropyCriterion.lua @@ -25,11 +25,4 @@ function CrossEntropyCriterion:updateGradInput(input, target) return self.gradInput end -function CrossEntropyCriterion:type(name) - Criterion.type(self, name) - self.lsm:type(name) - self.nll:type(name) - return self -end - return nn.CrossEntropyCriterion diff --git a/Module.lua b/Module.lua index d6b16fbca..be7896c8a 100644 --- a/Module.lua +++ b/Module.lua @@ -117,18 +117,7 @@ function Module:type(type) assert(type, 'Module: must provide a type to convert to') -- find all tensors and convert them for key,param in pairs(self) do - -- Many modules (like CDivTable) have output or gradInput fields which - -- are table's of tensors. To be general we need to recursively - -- cast fields that may be nested tables. - if key ~= 'modules' then - self[key] = nn._utils.recursiveType(param, type) - end - end - -- find submodules in classic containers 'modules' - if self.modules then - for _,module in ipairs(self.modules) do - module:type(type) - end + self[key] = nn.utils.recursiveType(param, type) end return self end diff --git a/utils.lua b/utils.lua index 887b82d5f..f5f4b73b9 100644 --- a/utils.lua +++ b/utils.lua @@ -1,10 +1,13 @@ -nn._utils = {} +nn.utils = {} -function nn._utils.recursiveType(param, type_str) +function nn.utils.recursiveType(param, type_str) if torch.type(param) == 'table' then for k, v in pairs(param) do - param[k] = nn._utils.recursiveType(v, type_str) + param[k] = nn.utils.recursiveType(v, type_str) end + elseif torch.isTypeOf(param, 'nn.Module') or + torch.isTypeOf(param, 'nn.Criterion') then + param:type(type_str) elseif torch.isTensor(param) then param = param:type(type_str) end