Skip to content

Commit

Permalink
Check for nn.Module and nn.Criterion in recursiveType.
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikgrewe committed May 5, 2015
1 parent e555649 commit c5c63d0
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 23 deletions.
3 changes: 2 additions & 1 deletion Criterion.lua
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ function Criterion:clone()
end

function Criterion:type(type)
assert(type, 'Criterion: must provide a type to convert to')
-- find all tensors and convert them
for key,param in pairs(self) do
self[key] = nn._utils.recursiveType(param, type)
self[key] = nn.utils.recursiveType(param, type)
end
return self
end
Expand Down
7 changes: 0 additions & 7 deletions CrossEntropyCriterion.lua
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,4 @@ function CrossEntropyCriterion:updateGradInput(input, target)
return self.gradInput
end

function CrossEntropyCriterion:type(name)
Criterion.type(self, name)
self.lsm:type(name)
self.nll:type(name)
return self
end

return nn.CrossEntropyCriterion
13 changes: 1 addition & 12 deletions Module.lua
Original file line number Diff line number Diff line change
Expand Up @@ -117,18 +117,7 @@ function Module:type(type)
assert(type, 'Module: must provide a type to convert to')
-- find all tensors and convert them
for key,param in pairs(self) do
-- Many modules (like CDivTable) have output or gradInput fields which
-- are table's of tensors. To be general we need to recursively
-- cast fields that may be nested tables.
if key ~= 'modules' then
self[key] = nn._utils.recursiveType(param, type)
end
end
-- find submodules in classic containers 'modules'
if self.modules then
for _,module in ipairs(self.modules) do
module:type(type)
end
self[key] = nn.utils.recursiveType(param, type)
end
return self
end
Expand Down
9 changes: 6 additions & 3 deletions utils.lua
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
nn._utils = {}
nn.utils = {}

function nn._utils.recursiveType(param, type_str)
function nn.utils.recursiveType(param, type_str)
if torch.type(param) == 'table' then
for k, v in pairs(param) do
param[k] = nn._utils.recursiveType(v, type_str)
param[k] = nn.utils.recursiveType(v, type_str)
end
elseif torch.isTypeOf(param, 'nn.Module') or
torch.isTypeOf(param, 'nn.Criterion') then
param:type(type_str)
elseif torch.isTensor(param) then
param = param:type(type_str)
end
Expand Down

0 comments on commit c5c63d0

Please sign in to comment.