Skip to content

Commit

Permalink
Add support for negative indices in nn.SplitTable
Browse files Browse the repository at this point in the history
This module can now be used when the total number of dimensions is unknown.
  • Loading branch information
sergomezcol committed Jun 3, 2015
1 parent fdd6659 commit 350de82
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 8 deletions.
21 changes: 13 additions & 8 deletions SplitTable.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,21 @@ function SplitTable:__init(dimension, nInputDims)
self.nInputDims = nInputDims
end

function SplitTable:updateOutput(input)
function SplitTable:_getPositiveDimension(input)
local dimension = self.dimension
if self.nInputDims and input:dim()==(self.nInputDims+1) then
dimension = dimension + 1
if dimension < 0 then
dimension = input:dim() + dimension + 1
elseif self.nInputDims and input:dim()==(self.nInputDims+1) then
dimension = dimension + 1
end
local currentOutput= {}
return dimension
end

function SplitTable:updateOutput(input)
local dimension = self:_getPositiveDimension(input)
local slices = input:size(dimension)

local currentOutput= {}
for i=1,slices do
currentOutput[#currentOutput+1] = input:select(dimension,i)
end
Expand All @@ -21,10 +29,7 @@ function SplitTable:updateOutput(input)
end

function SplitTable:updateGradInput(input, gradOutput)
local dimension = self.dimension
if self.nInputDims and input:dim()==(self.nInputDims+1) then
dimension = dimension + 1
end
local dimension = self:_getPositiveDimension(input)
local slices = input:size(dimension)
self.gradInput:resizeAs(input)

Expand Down
47 changes: 47 additions & 0 deletions doc/table.md
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,53 @@ gives the output:
[torch.DoubleTensor of dimension 3]
```

The module also supports indexing from the end using negative dimensions. This allows to use this module when the number of dimensions of the input is unknown.

### Example

```lua
m = nn.SplitTable(-2)
out = m:forward(torch.randn(3, 2))
for i, k in ipairs(out) do print(i, k) end
out = m:forward(torch.randn(1, 3, 2))
for i, k in ipairs(out) do print(i, k) end
```

gives the output:

```
1
0.1420
-0.5698
[torch.DoubleTensor of size 2]
2
0.1663
0.1197
[torch.DoubleTensor of size 2]
3
0.4198
-1.1394
[torch.DoubleTensor of size 2]
1
-2.4941
-1.4541
[torch.DoubleTensor of size 1x2]
2
0.4594
1.1946
[torch.DoubleTensor of size 1x2]
3
-2.3322
-0.7383
[torch.DoubleTensor of size 1x2]
```

### A more complicated example

```lua
Expand Down
7 changes: 7 additions & 0 deletions test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2928,6 +2928,13 @@ function nntest.SplitTable()
module = nn.SplitTable(d, 2)
mytester:asserteq(#module:forward(input), input:size(d+1), "dimension " .. d)
end

-- Negative indices
local module = nn.SplitTable(-3)
local input = torch.randn(3,4,5)
mytester:asserteq(#module:forward(input), 3, "negative index")
local input = torch.randn(2,3,4,5)
mytester:asserteq(#module:forward(input), 3, "negative index (minibatch)")
end

function nntest.SelectTable()
Expand Down

0 comments on commit 350de82

Please sign in to comment.