-
Notifications
You must be signed in to change notification settings - Fork 13
/
group.lua
53 lines (43 loc) · 1.47 KB
/
group.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
--- Group the elements of a 1-D tensor into runs of same-group values.
--
-- Two calling forms are accepted:
--   torch.group(sorted, index, tensor [, samegrp [, desc]])
--   torch.group(tensor [, samegrp [, desc]])
-- The short form is detected by the third argument not being a tensor.
--
-- `tensor` is sorted (ascending unless `desc` is true) into `sorted`, with
-- the sort permutation written to `index`. Consecutive sorted values are then
-- placed in the same group for as long as `samegrp(start_val, val)` returns
-- true, where `start_val` is the first (sorted) value of the current group.
-- `samegrp` is assumed to return true for `samegrp(x, x)`.
--
-- @param sorted  optional output tensor of the same type as `tensor`
-- @param index   optional torch.LongTensor receiving the sort indices
-- @param tensor  1-D torch.Tensor whose values are grouped
-- @param samegrp optional function(start_val, val) -> boolean; defaults to `==`
-- @param desc    optional boolean; sort descending when true (default false)
-- @return groups table keyed by each group's first sorted value; each entry is
--   `{idx = ..., val = ...}`, holding `narrow` slices of `index` and `sorted`
--   covering that group's run
-- @return sorted
-- @return index
function torch.group(sorted, index, tensor, samegrp, desc)
   -- Short form (tensor, samegrp, desc): shift arguments into their named slots.
   if not torch.isTensor(tensor) then
      desc = tensor
      samegrp = index
      tensor = sorted
      index = nil
      sorted = nil
   end
   assert(torch.isTensor(tensor), 'expecting torch.Tensor for arg 3')
   -- Reassign the parameter rather than shadowing it with a new local.
   sorted = sorted or tensor.new()
   assert(torch.type(tensor) == torch.type(sorted), 'expecting torch.Tensor for arg 1 as same type as arg 2')
   samegrp = samegrp or function(start_val, val)
      return start_val == val
   end
   assert(torch.type(samegrp) == 'function', 'expecting function for arg 4')
   index = index or torch.LongTensor()
   assert(torch.type(index) == 'torch.LongTensor', 'expecting torch.LongTensor for arg 2')
   if desc == nil then
      desc = false
   end

   -- Nothing to group. Without this guard, sort() and sorted[1] raise
   -- cryptic errors on an empty tensor.
   if tensor:nElement() == 0 then
      sorted:resize(0)
      index:resize(0)
      return {}, sorted, index
   end
   -- The run detection below indexes sorted[i] and narrows along dim 1, which
   -- is only meaningful for vectors; fail loudly instead of returning garbage.
   assert(tensor:dim() == 1, 'expecting a 1-D tensor for arg 3')

   sorted:sort(index, tensor, desc)

   local groups = {}
   local n = sorted:size(1)
   local start_idx, start_val = 1, sorted[1]

   -- Record the run [start_idx, stop_idx) under its first value.
   local function close_group(stop_idx)
      local len = stop_idx - start_idx
      groups[start_val] = {
         idx = index:narrow(1, start_idx, len),
         val = sorted:narrow(1, start_idx, len)
      }
   end

   for i = 2, n do
      local val = sorted[i]
      if not samegrp(start_val, val) then
         close_group(i)
         start_idx, start_val = i, val
      end
   end
   -- n >= 1 here, so there is always a final open run to record.
   close_group(n + 1)

   return groups, sorted, index
end