-
Notifications
You must be signed in to change notification settings - Fork 4
/
main.go
60 lines (52 loc) · 1.21 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
package main
import (
"aixigo/agent/aixi"
"aixigo/env/grid"
"aixigo/mcts"
"aixigo/x"
"fmt"
)
// main runs an AIµ agent (MCTS planner with the true environment model)
// on a small grid world and reports its average reward per cycle next to
// a rough optimal baseline.
func main() {
	// Grid layout; presumably 0 = open tile, 1 = wall, 2 = reward
	// dispenser — confirm against the grid package's documentation.
	spec := [][]int{
		{0, 0, 1, 1, 1},
		{1, 0, 0, 1, 2},
		{0, 1, 0, 1, 0},
		{0, 1, 0, 1, 0},
		{0, 1, 0, 0, 0},
	}
	env := grid.New(spec)
	// 10000 is the MCTS sample budget per planning step.
	meta := mcts.NewMeta(grid.Meta, grid.NewModel(spec), 10000)
	agent := &aixi.AImu{Meta: meta}
	cycles := 100
	fmt.Printf("Running for %d cycles with %d samples, using horizon %d\n",
		cycles, meta.Samples, meta.Horizon)
	// BUG FIX: previously called run(agent, env, 100) with a literal,
	// so editing cycles above would silently desynchronize the actual
	// run length from every figure reported here.
	log := run(agent, env, cycles)
	fmt.Printf("Agent's avg reward per cycle: %f\n", averageReward(log))
	// NOTE(review): the fixed 10-cycle discount looks like an allowance
	// for the agent's initial travel/exploration — confirm against the
	// grid environment's geometry.
	fmt.Printf("Optimal avg reward per cycle: %f\n",
		float64(meta.MaxReward)*(float64(cycles)-10.0)/float64(cycles))
}
// trace records a single agent/environment interaction cycle:
// the action the agent emitted, and the observation and reward the
// environment returned for it.
type trace struct {
	Action x.Action // action chosen by the agent this cycle
	Observation x.Observation // percept returned by env.Perform
	Reward x.Reward // reward returned by env.Perform
}
// run drives the agent/environment loop for the given number of cycles
// and returns the complete interaction history, one trace per cycle.
// Each cycle: the agent picks an action, the environment responds with
// an observation and reward, and the agent updates on the triple.
func run(agent x.Agent, env x.Environment, cycles int) []trace {
	// Pre-size the history to its exact final length. The redundant
	// capacity argument in make([]trace, n, n) was dropped
	// (staticcheck S1019), and the loop variables are now declared at
	// first use instead of hoisted vars.
	history := make([]trace, cycles)
	for i := 0; i < cycles; i++ {
		a := agent.GetAction()
		o, r := env.Perform(a)
		agent.Update(a, o, r)
		history[i] = trace{a, o, r}
	}
	return history
}
// averageReward returns the mean reward over all logged cycles.
// An empty log yields 0 rather than NaN — the original divided by
// len(log) unconditionally, which is 0/0 for an empty slice.
func averageReward(log []trace) float64 {
	if len(log) == 0 {
		return 0
	}
	var sum float64
	for _, t := range log {
		sum += float64(t.Reward)
	}
	return sum / float64(len(log))
}