-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathactivation.go
90 lines (75 loc) · 1.93 KB
/
activation.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
package neural
import (
	"fmt"
	"math"
)
// Function signatures shared by every activation defined in this file.
type (
	// ForwardFn maps a neuron's input sum to its activation value
	// (the forward pass, used when thinking).
	ForwardFn func(sum float64) float64

	// BackwardFn is the derivative of a ForwardFn, used when learning.
	// It receives the forward function's output (the activation), not the
	// original input sum, and computes the derivative from that value.
	BackwardFn func(activation float64) float64
)
// LinearForward is the identity activation: it returns its input sum unchanged.
func LinearForward(sum float64) float64 {
	out := sum
	return out
}
// LinearBackward is the derivative of LinearForward. The identity has a
// constant slope of 1, so the activation argument is ignored.
func LinearBackward(activation float64) float64 {
	const slope = 1.0
	return slope
}
// SigmoidForward is the logistic function 1 / (1 + e^(-sum)).
// Very large negative inputs make the exponential overflow to +Inf, so the
// result rounds to exactly 0; very large positive inputs round to exactly 1.
func SigmoidForward(sum float64) float64 {
	decay := math.Exp(-sum)
	return 1.0 / (1.0 + decay)
}
// SigmoidBackward is the derivative of SigmoidForward, expressed in terms of
// the sigmoid's output: a * (1 - a).
func SigmoidBackward(activation float64) float64 {
	complement := 1.0 - activation
	return activation * complement
}
// TanhForward is the hyperbolic tangent activation (via math.Tanh).
func TanhForward(sum float64) float64 {
	squashed := math.Tanh(sum)
	return squashed
}
// TanhBackward is the derivative of TanhForward, expressed in terms of the
// tanh output: 1 - a².
func TanhBackward(activation float64) float64 {
	squared := activation * activation
	return 1 - squared
}
// ReluForward is the rectified linear unit. Negative sums are clamped to 0;
// every other input, including NaN and -0, passes through unchanged.
func ReluForward(sum float64) float64 {
	out := sum
	if sum < 0.0 {
		out = 0.0
	}
	return out
}
// ReluBackward is the derivative of ReluForward, computed from its output.
// It returns 0 when the activation is at or below zero and 1 otherwise.
func ReluBackward(activation float64) float64 {
	grad := 1.0
	if activation <= 0.0 {
		grad = 0.0
	}
	return grad
}
// ActivationSet bundles an activation function, its derivative, and the
// output range reported for it.
type ActivationSet struct {
	// Forward computes the activation from a neuron's input sum.
	Forward ForwardFn
	// Backward computes the derivative of Forward from Forward's output.
	Backward BackwardFn
	// Range of the activation. selectActivation sets it to a two-element
	// slice (lower value first, e.g. {-1.0, 1.0} for tanh) for sigmoid, tanh
	// and relu, and leaves it nil for linear.
	// NOTE(review): presumably Ranges[0] is the lower and Ranges[1] the upper
	// bound for every consumer — confirm with callers.
	Ranges []float64
}
// selectActivation returns the ActivationSet registered under the given name.
// The recognized names are "linear", "sigmoid" (also used when the name is
// empty), "tanh" and "relu". Ranges is left nil for "linear", whose identity
// output is unbounded.
//
// It panics on an unrecognized name. The panic message quotes the offending
// name so that misconfigurations are easy to diagnose.
func selectActivation(activation string) ActivationSet {
	switch activation {
	case "linear":
		return ActivationSet{
			Forward:  LinearForward,
			Backward: LinearBackward,
		}
	case "", "sigmoid":
		return ActivationSet{
			Forward:  SigmoidForward,
			Backward: SigmoidBackward,
			Ranges:   []float64{0.0, 1.0},
		}
	case "tanh":
		return ActivationSet{
			Forward:  TanhForward,
			Backward: TanhBackward,
			Ranges:   []float64{-1.0, 1.0},
		}
	case "relu":
		// NOTE(review): ReLU's output is unbounded above ([0, +Inf)), so
		// {0, 1} is not its true range. The value is kept as-is because
		// callers outside this file may rely on these bounds (e.g. for
		// scaling); confirm before changing.
		return ActivationSet{
			Forward:  ReluForward,
			Backward: ReluBackward,
			Ranges:   []float64{0.0, 1.0},
		}
	default:
		panic(fmt.Sprintf("need a valid activation name, got %q", activation))
	}
}