Skip to content

Commit 0e28c54

Browse files
authored
Remove backpropagation beta/gamma. (#203)
1 parent 6f5a3c1 commit 0e28c54

File tree

4 files changed

+13
-26
lines changed

4 files changed

+13
-26
lines changed

src/mcts/node.cc

+2-6
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,9 @@ bool Node::TryStartScoreUpdate() {
176176

177177
void Node::CancelScoreUpdate() { --n_in_flight_; }
178178

179-
void Node::FinalizeScoreUpdate(float v, float gamma, float beta) {
179+
void Node::FinalizeScoreUpdate(float v) {
180180
// Recompute Q.
181-
if (gamma == 1.0f && beta == 1.0f) {
182-
q_ += (v - q_) / (n_ + 1);
183-
} else {
184-
q_ += (v - q_) / (std::pow(static_cast<float>(n_), gamma) * beta + 1);
185-
}
181+
q_ += (v - q_) / (n_ + 1);
186182
// Increment N.
187183
++n_;
188184
// Decrement virtual loss.

src/mcts/node.h

+7-3
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ class Node {
166166
// * Q (weighted average of all V in a subtree)
167167
// * N (+=1)
168168
// * N-in-flight (-=1)
169-
void FinalizeScoreUpdate(float v, float gamma, float beta);
169+
void FinalizeScoreUpdate(float v);
170170

171171
// Updates max depth, if new depth is larger.
172172
void UpdateMaxDepth(int depth);
@@ -243,8 +243,12 @@ class EdgeAndNode {
243243
EdgeAndNode() = default;
244244
EdgeAndNode(Edge* edge, Node* node) : edge_(edge), node_(node) {}
245245
explicit operator bool() const { return edge_ != nullptr; }
246-
bool operator==(const EdgeAndNode& other) const { return edge_ == other.edge_; }
247-
bool operator!=(const EdgeAndNode& other) const { return edge_ != other.edge_; }
246+
bool operator==(const EdgeAndNode& other) const {
247+
return edge_ == other.edge_;
248+
}
249+
bool operator!=(const EdgeAndNode& other) const {
250+
return edge_ != other.edge_;
251+
}
248252
bool HasNode() const { return node_ != nullptr; }
249253
Edge* edge() const { return edge_; }
250254
Node* node() const { return node_; }

src/mcts/search.cc

+4-13
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,6 @@ const char* Search::kCacheHistoryLengthStr =
4747
const char* Search::kPolicySoftmaxTempStr = "Policy softmax temperature";
4848
const char* Search::kAllowedNodeCollisionsStr =
4949
"Allowed node collisions, per batch";
50-
const char* Search::kBackPropagateBetaStr = "Backpropagation gamma";
51-
const char* Search::kBackPropagateGammaStr = "Backpropagation beta";
5250

5351
namespace {
5452
const int kSmartPruningToleranceNodes = 100;
@@ -79,10 +77,6 @@ void Search::PopulateUciParams(OptionsParser* options) {
7977
"policy-softmax-temp") = 1.0f;
8078
options->Add<IntOption>(kAllowedNodeCollisionsStr, 0, 1024,
8179
"allowed-node-collisions") = 0;
82-
options->Add<FloatOption>(kBackPropagateBetaStr, 0.0f, 100.0f,
83-
"backpropagate-beta") = 1.0f;
84-
options->Add<FloatOption>(kBackPropagateGammaStr, -100.0f, 100.0f,
85-
"backpropagate-gamma") = 1.0f;
8680
}
8781

8882
Search::Search(const NodeTree& tree, Network* network,
@@ -109,9 +103,7 @@ Search::Search(const NodeTree& tree, Network* network,
109103
kFpuReduction(options.Get<float>(kFpuReductionStr)),
110104
kCacheHistoryLength(options.Get<int>(kCacheHistoryLengthStr)),
111105
kPolicySoftmaxTemp(options.Get<float>(kPolicySoftmaxTempStr)),
112-
kAllowedNodeCollisions(options.Get<int>(kAllowedNodeCollisionsStr)),
113-
kBackPropagateBeta(options.Get<float>(kBackPropagateBetaStr)),
114-
kBackPropagateGamma(options.Get<float>(kBackPropagateGammaStr)) {}
106+
kAllowedNodeCollisions(options.Get<int>(kAllowedNodeCollisionsStr)) {}
115107

116108
namespace {
117109
void ApplyDirichletNoise(Node* node, float eps, double alpha) {
@@ -889,8 +881,7 @@ void SearchWorker::DoBackupUpdate() {
889881
for (Node* n = node; n != search_->root_node_->GetParent();
890882
n = n->GetParent()) {
891883
++depth;
892-
n->FinalizeScoreUpdate(v, search_->kBackPropagateGamma,
893-
search_->kBackPropagateBeta);
884+
n->FinalizeScoreUpdate(v);
894885
// Q will be flipped for opponent.
895886
v = -v;
896887

@@ -903,8 +894,8 @@ void SearchWorker::DoBackupUpdate() {
903894
// Best move.
904895
if (n->GetParent() == search_->root_node_ &&
905896
search_->best_move_edge_.GetN() <= n->GetN()) {
906-
search_->best_move_edge_ =
907-
search_->GetBestChildNoTemperature(search_->root_node_);
897+
search_->best_move_edge_ =
898+
search_->GetBestChildNoTemperature(search_->root_node_);
908899
}
909900
}
910901
++search_->total_playouts_;

src/mcts/search.h

-4
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,6 @@ class Search {
9393
static const char* kCacheHistoryLengthStr;
9494
static const char* kPolicySoftmaxTempStr;
9595
static const char* kAllowedNodeCollisionsStr;
96-
static const char* kBackPropagateBetaStr;
97-
static const char* kBackPropagateGammaStr;
9896

9997
private:
10098
// Returns the best move, maybe with temperature (according to the settings).
@@ -167,8 +165,6 @@ class Search {
167165
const bool kCacheHistoryLength;
168166
const float kPolicySoftmaxTemp;
169167
const int kAllowedNodeCollisions;
170-
const float kBackPropagateBeta;
171-
const float kBackPropagateGamma;
172168

173169
friend class SearchWorker;
174170
};

0 commit comments

Comments
 (0)