-
Notifications
You must be signed in to change notification settings - Fork 13
/
concat.lua
47 lines (41 loc) · 1.27 KB
/
concat.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
--- Concatenate a table of tensors along one dimension.
-- Usage: torch.concat([res], tensors, [dim], [index])
--
-- If the first argument is a Lua table it is taken to be `tensors`, the
-- remaining arguments shift one position to the left, and the result is
-- created with `tensors[index].new()`.
--
-- @param result  optional tensor; it is resized and overwritten
-- @param tensors array-like table of tensors to join
-- @param dim     dimension to concatenate along (default 1)
-- @param index   which entry of `tensors` supplies `new()` for the result
--                when no `result` is passed (default 1)
-- @return `result`, resized to the combined size and holding every input
--         in order along `dim`
--
-- Errors are raised when `tensors` is not a table, is empty, contains a
-- non-tensor, when `dim` is outside the dimensions of an input, or when
-- the inputs disagree on the size of a dimension other than `dim`.
function torch.concat(result, tensors, dim, index)
   -- `unpack` is a global in Lua 5.1/LuaJIT and lives in `table` from 5.2 on.
   local unpack = unpack or table.unpack

   index = index or 1
   if type(result) == 'table' then
      -- Called without an explicit result: shift the arguments left.
      index = dim or 1
      dim = tensors
      tensors = result
      local prototype = tensors[index]
      if prototype == nil then
         error("Cannot concat: no tensor at index "..tostring(index), 2)
      end
      result = prototype.new()
   end
   assert(type(tensors) == 'table', "expecting table at arg 2")
   dim = dim or 1

   -- First pass: validate the inputs and compute the size of the result
   -- before touching `result`, so a bad call leaves it unmodified.
   local size
   for i, tensor in ipairs(tensors) do
      assert(torch.isTensor(tensor), "Expecting table of torch.Tensors at arg 2 : "..torch.type(tensor))
      local dims = tensor:size():totable()
      if type(dim) == 'number' and (dim < 1 or dim > #dims) then
         error(
            "Cannot concat along dim "..dim..": tensor at index "..i..
            " has "..#dims.." dimension(s)", 2
         )
      end
      if not size then
         -- The first tensor fixes every dimension except `dim`, which is
         -- accumulated below starting from zero.
         size = tensor:size():totable()
         size[dim] = 0
      end
      for j, v in ipairs(dims) do
         if j == dim then
            size[j] = (size[j] or 0) + v
         elseif size[j] and size[j] ~= v then
            -- Dimensions the first tensor does not have (size[j] == nil)
            -- are intentionally not compared, as in the original code.
            error(
               "Cannot concat dim "..j.." with different sizes: "..
               (size[j] or 'nil').." ~= "..(v or 'nil')..
               " for tensor at index "..i, 2
            )
         end
      end
   end
   if not size then
      error("Cannot concat an empty table of tensors", 2)
   end

   -- Second pass: copy each tensor into its slice of the result.
   result:resize(unpack(size))
   local start = 1
   for _, tensor in ipairs(tensors) do
      local length = tensor:size(dim)
      result:narrow(dim, start, length):copy(tensor)
      start = start + length
   end
   return result
end
-- Register the same function under torchx.Tensor as well.
-- NOTE(review): presumably torchx installs the entries of torchx.Tensor as
-- tensor methods, in which case a method-style call passes the receiver as
-- `result` -- confirm where torchx.Tensor is consumed.
torchx.Tensor.concat = torch.concat