forked from Wei-1/Scala-Machine-Learning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
MCTS.scala
44 lines (41 loc) · 1.18 KB
/
MCTS.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
// Wei Chen - Monte Carlo Tree Search
// 2017-08-08
package com.scalaml.algorithm
/** A node of the search tree.
  *
  * Each node wraps one state (`init`), scores it once at construction time via
  * `sim`, and can lazily create one child per successor state returned by `act`.
  *
  * @param sim    scoring function applied to a state
  * @param act    successor function: returns the states reachable from a state
  * @param init   the state represented by this node
  * @param parent the node this one was expanded from; `null` for the root
  */
class MCNode(
  val sim: Array[Double] => Double,
  val act: Array[Double] => Array[Array[Double]],
  val init: Array[Double],
  val parent: MCNode = null
) {
  /** Current score of this node; starts as `sim(init)` and may be lowered by a
    * descendant's `backpropagate`.
    */
  var score: Double = sim(init)

  /** Children of this node; stays `null` until `expand` has been called. */
  var arr: Array[MCNode] = null

  /** Set to `true` by `expand`, i.e. once `arr` has been populated. */
  var check: Boolean = false

  /** Follows the highest-scoring child from this node downwards and returns the
    * node where the descent stops: the first node that is either not expanded
    * yet or has no children (a terminal state).
    *
    * Stopping at an expanded node without children avoids calling `maxBy` on an
    * empty array, which would throw.
    */
  def best: MCNode = {
    @scala.annotation.tailrec
    def descend(node: MCNode): MCNode =
      if (node.check && node.arr.nonEmpty) descend(node.arr.maxBy(_.score))
      else node
    descend(this)
  }

  /** Creates one child per state returned by `act(init)` and marks this node as
    * expanded. Each child is scored immediately by its constructor.
    */
  def expand: Unit = {
    check = true
    arr = act(init).map(a => new MCNode(sim, act, a, this))
  }

  /** Pushes this node's score towards the root: every ancestor whose score is
    * strictly greater gets lowered to the score of the node below it. The walk
    * stops at the root or at the first ancestor whose score is not greater.
    */
  def backpropagate: Unit = {
    @scala.annotation.tailrec
    def climb(child: MCNode): Unit = {
      val up = child.parent
      if (up != null && up.score > child.score) {
        up.score = child.score
        climb(up)
      }
    }
    climb(this)
  }
}
/** Runs the tree search built on top of `MCNode`. */
class MCTS {
  /** Builds a search tree rooted at `init` and returns the most promising
    * first move.
    *
    * Each round selects the node reached by following the best-scoring
    * children from the root, propagates that node's score upwards, and then
    * expands it.
    *
    * @param sim  scoring function applied to a state
    * @param act  successor function: returns the states reachable from a state
    * @param init the starting state
    * @param iter number of select / backpropagate / expand rounds to run
    * @return the root's child with the highest score; the root itself when it
    *         has no children (no rounds were run, or `act(init)` is empty)
    */
  def search(
    sim: Array[Double] => Double,
    act: Array[Double] => Array[Array[Double]],
    init: Array[Double],
    iter: Int
  ): MCNode = {
    val tree: MCNode = new MCNode(sim, act, init)

    // True once the root has been expanded and turned out to have no successors;
    // further rounds could not grow the tree, so they are skipped.
    def rootIsDeadEnd: Boolean = tree.check && tree.arr.isEmpty

    for (_ <- 0 until iter) {
      if (!rootIsDeadEnd) {
        val node = tree.best
        node.backpropagate
        node.expand
      }
    }

    // `tree.arr` is null when no round ran and empty when the start state has
    // no successors; fall back to the root instead of failing in either case.
    Option(tree.arr)
      .filter(_.nonEmpty)
      .fold(tree)(children => children.maxBy(_.score))
  }
}