diff --git a/Sequential.lua b/Sequential.lua index b08f7df37..359a764d1 100644 --- a/Sequential.lua +++ b/Sequential.lua @@ -15,7 +15,7 @@ end function Sequential:insert(module, index) index = index or (#self.modules + 1) - if index > (#self.modules + 1) then + if index > (#self.modules + 1) or index < 1 then error"index should be contiguous to existing modules" end table.insert(self.modules, index, module) @@ -23,6 +23,16 @@ function Sequential:insert(module, index) self.gradInput = self.modules[1].gradInput end +function Sequential:remove(index) + index = index or #self.modules + if index > #self.modules or index < 1 then + error"index out of range" + end + table.remove(self.modules, index) + self.output = self.modules[#self.modules].output + self.gradInput = self.modules[1].gradInput +end + function Sequential:updateOutput(input) local currentOutput = input for i=1,#self.modules do diff --git a/doc/containers.md b/doc/containers.md index 0b621db97..ae4f8bec3 100644 --- a/doc/containers.md +++ b/doc/containers.md @@ -51,7 +51,27 @@ which gives the output: [torch.Tensor of dimension 1] ``` - + +### remove([index]) ### + +Remove the module at the given `index`. If `index` is not specified, remove the last layer. + +```lua +model = nn.Sequential() +model:add(nn.Linear(10, 20)) +model:add(nn.Linear(20, 20)) +model:add(nn.Linear(20, 30)) +model:remove(2) +> model +nn.Sequential { + [input -> (1) -> (2) -> output] + (1): nn.Linear(10 -> 20) + (2): nn.Linear(20 -> 30) +} +``` + + + ### insert(module, [index]) ### Inserts the given `module` at the given `index`. If `index` is not specified, the incremented length of the sequence is used and so this is equivalent to use `add(module)`. @@ -70,6 +90,8 @@ nn.Sequential { } ``` + + ## Parallel ##