-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathweightedrandtree.h
122 lines (97 loc) · 2.34 KB
/
weightedrandtree.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#pragma once
/*
Given weights for indexes, returns a random index according to the weights.
This is useful for softmax and similar, used in the rollout policy.
Most operations are O(log n).
*/
#include "xorshift.h"
//rounds to power of 2 sizes, completely arbitrary weights
// O(log n) updates, O(log n) choose
class WeightedRandTree {
mutable XORShift_float unitrand;
unsigned int size, allocsize;
float * weights;
public:
WeightedRandTree() : size(0), allocsize(0), weights(NULL) { }
WeightedRandTree(unsigned int s) : allocsize(0), weights(NULL) { resize(s); }
~WeightedRandTree(){
if(weights)
delete[] weights;
weights = NULL;
}
//round a number up to the nearest power of 2
static unsigned int roundup(unsigned int v) {
v--;
v |= v >> 1;
v |= v >> 2;
v |= v >> 4;
v |= v >> 8;
v |= v >> 16;
v++;
return v;
}
//resize and clear the tree
void resize(unsigned int s){
if(s < 2) s = 2;
size = roundup(s);
if(size > allocsize){
if(weights)
delete[] weights;
allocsize = size;
weights = new float[size*2];
}
clear();
}
//reset all weights to 0, O(s)
void clear(){
for(unsigned int i = 0; i < size*2; i++)
weights[i] = 0;
}
//get an individual weight, O(1)
float get_weight(unsigned int i) const {
return weights[i + size];
}
//get the sum of the weights in the tree, O(1)
float sum_weight() const {
return weights[1];
}
//rebuilds the tree based on the weights, O(s)
void rebuild_tree(){
for(unsigned int i = size-1; i >= 1; i--)
weights[i] = weights[2*i] + weights[2*i + 1];
}
//sets the weight but doesn't update the tree, needs to be fixed with rebuild_tree(), O(1)
void set_weight_fast(unsigned int i, float w){
weights[i+size] = w;
}
//sets the weight and updates the tree, O(log s)
void set_weight(unsigned int i, float w){
i += size;
if(weights[i] == w)
return;
weights[i] = w;
while(i /= 2)
weights[i] = weights[i*2] + weights[i*2 + 1];
}
//return a weighted random index, O(log s), infinite loop if sum = 0
unsigned int choose() const {
if(weights[1] <= 0.00001f)
return -1;
float r;
unsigned int i;
do{
r = unitrand() * weights[1];
i = 2;
while(i < size){
if(r > weights[i]){
r -= weights[i];
i++;
}
i *= 2;
}
if(r > weights[i])
i++;
}while(weights[i] <= 0.0001f);
return i - size;
}
};