-
Notifications
You must be signed in to change notification settings - Fork 3
/
discrete-distribution.cc
239 lines (199 loc) · 6.45 KB
/
discrete-distribution.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
// C++ implementation of a fast algorithm for generating samples from a
// discrete distribution.
//
// David Pal, December 2015
//
// To compile the program run:
//
// g++ -Wall -Wextra -Werror -std=c++11 discrete-distribution.cc
#include <cmath>
#include <cassert>
#include <initializer_list>
#include <iostream>
#include <iterator>
#include <numeric>
#include <random>
#include <tuple>
#include <vector>
using std::cout;
using std::endl;
namespace {

// Stack that does not own the underlying storage.
//
// The stack lives inside a range provided by the caller and is addressed
// through an iterator: push() writes at the current top and then advances the
// top, pop() moves the top back one step and returns the element found there.
// Nothing is bounds-checked, so the caller has to guarantee that the range has
// room for every push() and that pop() is never called on an empty stack.
// When BidirectionalIterator is a reverse iterator, advancing the top moves
// towards the front of the underlying container.
template<typename T, typename BidirectionalIterator>
class stack_view {
 public:
  // Creates an empty stack whose first element will be stored at `base`.
  explicit stack_view(const BidirectionalIterator base)
      : base_(base), top_(base) {}

  // Copies `element` onto the top of the stack.
  void push(const T& element) {
    *top_ = element;
    ++top_;
  }

  // Removes the top element and returns a copy of it.
  T pop() {
    --top_;
    return *top_;
  }

  // Returns true when no elements are currently on the stack.
  bool empty() const {
    return top_ == base_;
  }

 private:
  // Position of the bottom of the stack; never changes after construction.
  const BidirectionalIterator base_;
  // Position where the next pushed element will be written.
  BidirectionalIterator top_;
};

}  // namespace
// Distribution over the integers 0, 1, ..., N-1, where N is the number of
// weights given to the constructor and integer k is drawn with probability
// weights[k] / sum(weights).
//
// Construction builds a table of N buckets, each covering an interval of
// length 1/N of [0,1) and holding at most two outcomes. Drawing a sample then
// takes one uniform random number, one table lookup and one comparison.
//
// With an empty weight vector the distribution always returns 0.
template<typename IntType = int>
class fast_discrete_distribution {
 public:
  typedef IntType result_type;

  // Builds the distribution from `weights`. The weights are normalized by
  // their sum; if the sum is zero, negative or not finite the weights cannot
  // be normalized and every outcome is given the same probability instead.
  fast_discrete_distribution(const std::vector<double>& weights)
      : uniform_distribution_(0.0, 1.0) {
    normalize_weights(weights);
    create_buckets();
  }

  // Draws one sample. `generator` may be any random number engine accepted by
  // std::uniform_real_distribution (e.g. std::default_random_engine).
  template<typename Generator>
  result_type operator()(Generator& generator) {
    const double number = uniform_distribution_(generator);
    // `number` is non-negative, so the conversion truncates like floor().
    size_t index = static_cast<size_t>(buckets_.size() * number);
    // Defensive clamp: keeps the lookup in range should the draw ever come
    // out as exactly 1.0.
    if (index >= buckets_.size()) index = buckets_.size() - 1;
    const Bucket& bucket = buckets_[index];
    // Below the split point the bucket belongs to its first outcome, at or
    // above it to the second one.
    if (number < std::get<2>(bucket))
      return std::get<0>(bucket);
    else
      return std::get<1>(bucket);
  }

  // Smallest value that can be returned.
  result_type min() const {
    return static_cast<result_type>(0);
  }

  // Largest value that can be returned (0 for an empty weight vector).
  result_type max() const {
    return probabilities_.empty()
        ? static_cast<result_type>(0)
        : static_cast<result_type>(probabilities_.size() - 1);
  }

  // Returns a copy of the normalized probabilities, one per weight.
  std::vector<double> probabilities() const {
    return probabilities_;
  }

  // Present for interface parity with the standard distributions; there is
  // nothing to reset.
  void reset() {
    // Empty
  }

  // Prints the bucket table to standard output: the number of buckets, then
  // one line per bucket with its two outcomes and its split point.
  void PrintBuckets() const {
    cout << "buckets.size() = " << buckets_.size() << endl;
    for (const Bucket& bucket : buckets_) {
      cout << std::get<0>(bucket) << " "
           << std::get<1>(bucket) << " "
           << std::get<2>(bucket) << " "
           << endl;
    }
  }

 private:
  // A piece of probability mass (first) that belongs to outcome (second).
  // TODO: Figure out how to replace size_t in Segment with result_type.
  // GCC 4.8.4 refuses to compile it.
  typedef std::pair<double, size_t> Segment;
  // (first outcome, second outcome, split point). The split point is an
  // absolute position in [0,1): draws below it map to the first outcome.
  typedef std::tuple<result_type, result_type, double> Bucket;

  // Fills probabilities_ with weights scaled to sum to one.
  void normalize_weights(const std::vector<double>& weights) {
    const double sum = std::accumulate(weights.begin(), weights.end(), 0.0);
    // Dividing by a zero or non-finite sum would store NaNs; fall back to the
    // uniform distribution in that case.
    const bool can_normalize = std::isfinite(sum) && sum > 0.0;
    probabilities_.reserve(weights.size());
    for (const double weight : weights) {
      probabilities_.push_back(
          can_normalize ? weight / sum : 1.0 / weights.size());
    }
  }

  // Builds buckets_ from probabilities_. Every bucket covers an interval of
  // length 1/N; probabilities below 1/N are topped up with mass taken from
  // probabilities of at least 1/N so that each bucket holds at most two
  // outcomes.
  void create_buckets() {
    const size_t N = probabilities_.size();
    if (N == 0) {
      // No outcomes: a single bucket that always yields 0.
      buckets_.emplace_back(0, 0, 0.0);
      return;
    }
    const double bucket_width = 1.0 / N;

    // Two stacks in one vector. First stack grows from the beginning of the
    // vector. The second stack grows from the end of the vector. Together
    // they never hold more than N segments, so they cannot collide.
    std::vector<Segment> segments(N);
    stack_view<Segment, std::vector<Segment>::iterator>
        small(segments.begin());
    stack_view<Segment, std::vector<Segment>::reverse_iterator>
        large(segments.rbegin());

    // Split probabilities into small and large
    size_t outcome = 0;
    for (const double probability : probabilities_) {
      if (probability < bucket_width) {
        small.push(Segment(probability, outcome));
      } else {
        large.push(Segment(probability, outcome));
      }
      ++outcome;
    }

    buckets_.reserve(N);
    // Index of the bucket being created; mixed buckets are created first, so
    // this is also the bucket's position in buckets_.
    size_t bucket_index = 0;
    while (!small.empty() && !large.empty()) {
      const Segment s = small.pop();
      const Segment l = large.pop();
      // Create a mixed bucket: the small segment occupies the start of the
      // bucket's interval, the large segment fills the remainder.
      buckets_.emplace_back(static_cast<result_type>(s.second),
                            static_cast<result_type>(l.second),
                            s.first + static_cast<double>(bucket_index) / N);
      // Calculate the length of the left-over segment
      const double left_over = s.first + l.first - bucket_width;
      // Re-insert the left-over segment
      if (left_over < bucket_width)
        small.push(Segment(left_over, l.second));
      else
        large.push(Segment(left_over, l.second));
      ++bucket_index;
    }

    // Create pure buckets
    while (!large.empty()) {
      const Segment l = large.pop();
      // The last argument is irrelevant as long it's not a NaN.
      buckets_.emplace_back(static_cast<result_type>(l.second),
                            static_cast<result_type>(l.second), 0.0);
    }

    // This loop can be executed only due to numerical inaccuracies.
    // TODO: Find an example when it actually happens.
    while (!small.empty()) {
      const Segment s = small.pop();
      // The last argument is irrelevant as long it's not a NaN.
      buckets_.emplace_back(static_cast<result_type>(s.second),
                            static_cast<result_type>(s.second), 0.0);
    }
  }

  // Uniform distribution over interval [0,1).
  std::uniform_real_distribution<double> uniform_distribution_;
  // List of probabilities
  std::vector<double> probabilities_;
  // Lookup table used by operator(); see create_buckets().
  std::vector<Bucket> buckets_;
};
// Draws `num_samples` samples from a fast_discrete_distribution built from
// `weights`, checks that each sample is a valid index into `weights`, and
// prints the bucket table followed by a histogram with one '*' per sample.
void Test(const std::vector<double>& weights, const size_t num_samples) {
  fast_discrete_distribution<int> distribution(weights);
  std::default_random_engine generator;
  distribution.PrintBuckets();

  const int num_outcomes = static_cast<int>(weights.size());
  std::vector<size_t> histogram(weights.size(), 0);
  size_t drawn = 0;
  while (drawn < num_samples) {
    const int sample = distribution(generator);
    assert(sample >= 0);
    assert(sample < num_outcomes);
    histogram[sample] += 1;
    ++drawn;
  }

  std::cout << "counts:" << std::endl;
  size_t outcome = 0;
  for (const size_t hits : histogram) {
    const std::string bar(hits, '*');
    cout << outcome << " (" << weights[outcome] << ") : " << bar << endl;
    ++outcome;
  }
  cout << endl;
}
// Builds a distribution from an empty weight list, prints its bucket table
// and checks that each of `num_samples` draws returns 0.
void TestEmpty(const size_t num_samples) {
  fast_discrete_distribution<int> distribution({});
  std::default_random_engine generator;
  distribution.PrintBuckets();

  size_t remaining = num_samples;
  while (remaining > 0) {
    const int sample = distribution(generator);
    assert(sample == 0);
    --remaining;
  }
}
// Runs the empty-distribution check and then one sampling experiment per
// weight vector, and finally prints a std::discrete_distribution.
int main() {
  // One sampling experiment: the weights and the number of samples to draw.
  struct Case {
    std::vector<double> weights;
    size_t num_samples;
  };
  const Case cases[] = {
    {{0}, 100},
    {{1}, 100},
    {{1, 1}, 200},
    {{1, 1, 1}, 300},
    {{1, 1, 2}, 300},
    {{1, 0, 2}, 300},
    {{20, 10, 30}, 300},
    {{0, 1e-20, 0}, 100},
    {{1 - 1e-10, 1 - 1e-10, 1 - 1e-10}, 100},
  };

  TestEmpty(100);
  for (const Case& experiment : cases) {
    Test(experiment.weights, experiment.num_samples);
  }

  std::discrete_distribution<int> reference({10.0, 20.0, 30.0});
  cout << reference << endl;
  return 0;
}