Skip to content

Commit

Permalink
Add nn.Sequential.remove([index])
Browse files Browse the repository at this point in the history
  • Loading branch information
eliemichel committed Jun 9, 2015
1 parent 8929c5c commit bf8fa17
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
12 changes: 11 additions & 1 deletion Sequential.lua
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,24 @@ end

function Sequential:insert(module, index)
index = index or (#self.modules + 1)
if index > (#self.modules + 1) then
if index > (#self.modules + 1) or index < 1 then
error"index should be contiguous to existing modules"
end
table.insert(self.modules, index, module)
self.output = self.modules[#self.modules].output
self.gradInput = self.modules[1].gradInput
end

function Sequential:remove(index)
index = index or #self.modules
if index > #self.modules or index < 1 then
error"index out of range"
end
table.remove(self.modules, index)
self.output = self.modules[#self.modules].output
self.gradInput = self.modules[1].gradInput
end

function Sequential:updateOutput(input)
local currentOutput = input
for i=1,#self.modules do
Expand Down
24 changes: 23 additions & 1 deletion doc/containers.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,27 @@ which gives the output:
[torch.Tensor of dimension 1]
```

<a name="nn.Sequential.insert"/>
<a name="nn.Sequential.remove"/>
### remove([index]) ###

Remove the module at the given `index`. If `index` is not specified, remove the last layer.

```lua
model = nn.Sequential()
model:add(nn.Linear(10, 20))
model:add(nn.Linear(20, 20))
model:add(nn.Linear(20, 30))
model:remove(2)
> model
nn.Sequential {
[input -> (1) -> (2) -> output]
(1): nn.Linear(10 -> 20)
(2): nn.Linear(20 -> 30)
}
```


<a name="nn.Sequential.remove"/>
### insert(module, [index]) ###

Inserts the given `module` at the given `index`. If `index` is not specified, the incremented length of the sequence is used and so this is equivalent to use `add(module)`.
Expand All @@ -70,6 +90,8 @@ nn.Sequential {
}
```



<a name="nn.Parallel"/>
## Parallel ##

Expand Down

0 comments on commit bf8fa17

Please sign in to comment.