AliasMultinomial.lua
-- ref.: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
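-- Alias-method sampler for a discrete (multinomial) distribution:
-- O(K) setup, then O(1) work per sample, regardless of the number of outcomes K.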
require 'torch'

local AM = torch.class("torch.AliasMultinomial")

function AM:__init(probs)
   self.J, self.q = self:setup(probs)
end
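
-- Builds the alias table from a 1D tensor of probabilities:
-- J[k] is the alias outcome for bin k, and q[k] is the probability of
-- keeping k itself rather than returning its alias J[k].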
function AM:setup(probs)
   assert(probs:dim() == 1, "expecting a 1D tensor of probabilities")
   local K = probs:nElement()
   local q = probs.new(K):zero()
   local J = torch.LongTensor(K):zero()

   -- Sort the outcomes into those with probabilities
   -- larger and smaller than 1/K.
   local smaller, larger = {}, {}
   for kk = 1,K do
      local prob = probs[kk]
      q[kk] = K*prob
      if q[kk] < 1 then
         table.insert(smaller, kk)
      else
         table.insert(larger, kk)
      end
   end

   -- Loop through and create little binary mixtures that
   -- appropriately allocate the larger outcomes over the
   -- overall uniform mixture.
   while #smaller > 0 and #larger > 0 do
      local small = table.remove(smaller)
      local large = table.remove(larger)

      J[small] = large
      q[large] = q[large] - (1.0 - q[small])

      if q[large] < 1.0 then
         table.insert(smaller, large)
      else
         table.insert(larger, large)
      end
   end

   assert(q:min() >= 0)
   if q:max() > 1 then
      -- numerical safety: q can drift slightly above 1 due to floating-point error
      q:div(q:max())
   end
   assert(q:max() <= 1)

   if J:min() <= 0 then
      -- Sometimes a large index isn't added to J.
      -- Fix it by setting the keep-probability to 1 so that J is never indexed there.
      local i = 0
      J:apply(function(x)
         i = i + 1
         if x <= 0 then
            q[i] = 1
         end
      end)
   end

   return J, q
end
-- Draws a single sample; returns an index in [1, K].
function AM:draw()
   local J = self.J
   local q = self.q
   local K = J:nElement()

   -- Draw from the overall uniform mixture.
   local kk = math.random(1,K)

   -- Draw from the binary mixture: either keep the
   -- small one, or choose the associated larger one.
   if math.random() < q[kk] then
      return kk
   else
      return J[kk]
   end
end
-- Fills the LongTensor `output` with independent samples using
-- batched tensor operations instead of a per-sample loop.
function AM:batchdraw(output)
   assert(torch.type(output) == 'torch.LongTensor', "expecting LongTensor output")
   assert(output:nElement() > 0)
   local J = self.J
   local K = J:nElement()

   -- Draw from the overall uniform mixture.
   self._kk = self._kk or output.new()
   self._kk:resizeAs(output):random(1,K)

   -- Gather the keep-probabilities q[kk[i]] for each drawn bin.
   self._q = self._q or self.q.new()
   self._q:index(self.q, 1, self._kk:view(-1))

   -- mask[i] = 1 with probability q[kk[i]], else 0.
   self._mask = self._mask or torch.LongTensor()
   self._mask:resize(self._q:size()):bernoulli(self._q)

   self.__kk = self.__kk or output.new()
   self.__kk:resize(self._kk:size()):copy(self._kk)
   self.__kk:cmul(self._mask)

   -- where mask == 0: output[i] = J[kk[i]], otherwise 0 for now
   self._mask:add(-1):mul(-1) -- flip the mask: (1,0) -> (0,1)
   output:view(-1):index(J, 1, self._kk:view(-1))
   output:cmul(self._mask)

   -- where mask == 1: output[i] = kk[i]
   output:add(self.__kk)

   return output
end
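
-- A minimal usage sketch, assuming this file has been loaded (e.g. via the
-- package that ships it) so that torch.AliasMultinomial is registered; the
-- probability values below are only illustrative:
--
--   local probs = torch.Tensor{0.5, 0.3, 0.1, 0.1}
--   local sampler = torch.AliasMultinomial(probs)
--   local one  = sampler:draw()                             -- a single index in [1, 4]
--   local many = sampler:batchdraw(torch.LongTensor(1000))  -- 1000 indices, filled in-place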