-
Notifications
You must be signed in to change notification settings - Fork 2
/
trainer.go
42 lines (38 loc) · 951 Bytes
/
trainer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
package automata
// Trainer runs supervised training of Network over a slice of TrainSet
// samples. Network and CostFunction must be non-nil for training to do any
// work: a pass calls methods on both, and a method call on a nil interface
// panics.
type Trainer struct {
	// Network is the model being trained. For each sample, Activate is
	// called with the sample's Input, then Propagate is called with
	// LearnRate and the sample's Output.
	Network Networker
	// LearnRate is the rate handed to Network.Propagate for every sample.
	LearnRate float64
	// Iterations is the maximum number of full passes over the training set.
	Iterations int
	// MaxErrorRate is the early-stopping threshold: training ends as soon as
	// the mean per-sample cost of a pass is strictly below it.
	// NOTE(review): with the zero value, and assuming the cost is never
	// negative (not enforced here), every one of Iterations passes runs.
	MaxErrorRate float64
	// CostFunction scores each sample. It is called as
	// Cost(expected output, actual network output).
	CostFunction Coster
}
// TrainSet is a single supervised example (despite the name, not a set)
// consumed by Trainer.Train.
//
// Input is what the network is activated on. Output is the expected result:
// it is the target handed to the network's Propagate and the first argument
// passed to the trainer's cost function.
type TrainSet struct {
	Input, Output []float64
}
// Train fits t.Network to trainingSet using up to t.Iterations full passes.
//
// Each pass goes through trainSet, with t.LearnRate and t.CostFunction, and
// yields the summed cost over all samples. If a pass's mean cost per sample
// is strictly below t.MaxErrorRate, training stops early and Train returns
// nil. An error from trainSet aborts training and is returned unchanged.
// Otherwise Train returns nil once all passes have run.
//
// An empty trainingSet is a no-op that returns nil. Without this guard the
// mean would be 0/0 = NaN. NaN never compares below MaxErrorRate, so every
// iteration would run a pass over nothing.
func (t *Trainer) Train(trainingSet []TrainSet) error {
	// TODO: Cross-validation support
	if len(trainingSet) == 0 {
		return nil
	}
	// Loop-invariant divisor for the per-pass mean cost.
	samples := float64(len(trainingSet))
	for i := 0; i < t.Iterations; i++ {
		errorSum, err := t.trainSet(trainingSet, t.LearnRate, t.CostFunction)
		if err != nil {
			return err
		}
		if errorSum/samples < t.MaxErrorRate {
			return nil
		}
	}
	return nil
}
// trainSet performs one training pass over set and returns the summed cost.
//
// Samples are handled in order. The network is activated on the sample's
// Input, then Propagate is called with rate and the sample's Output. Finally,
// coster scores the expected Output against the activation result. The first
// Activate error ends the pass and is returned with a zero sum. Samples
// before it have already been propagated.
func (t *Trainer) trainSet(set []TrainSet, rate float64, coster Coster) (float64, error) {
	total := 0.0
	for i := range set {
		sample := set[i]
		got, err := t.Network.Activate(sample.Input)
		if err != nil {
			return 0, err
		}
		// The cost is computed after Propagate, on the output captured before it.
		// NOTE(review): this only matters if Activate returns a buffer that
		// Propagate later mutates — confirm against the Networker implementation.
		t.Network.Propagate(rate, sample.Output)
		total += coster.Cost(sample.Output, got)
	}
	return total, nil
}