-
Notifications
You must be signed in to change notification settings - Fork 8
/
distance_test.go
84 lines (73 loc) · 1.81 KB
/
distance_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
package gannoy
import (
"fmt"
"testing"
)
// TestAngularMargin checks that Angular.margin of the node vector (1, 2, 3)
// against y = (1, 2, 3) returns their dot product, 1*1 + 2*2 + 3*3 = 14.
func TestAngularMargin(t *testing.T) {
	angular := Angular{}
	node := Node{v: []float64{1, 2, 3}}
	y := []float64{1, 2, 3}
	dot := angular.margin(node, y)
	expect := 14.0
	if dot != expect {
		// Both values are float64, so %f is required: go vet's printf check
		// (run automatically by `go test`) rejects %d for floating-point args.
		t.Errorf("Angular margin should return %f, but %f", expect, dot)
	}
}
// TestAngularSide checks the sign-based side selection of Angular.side:
// a positive dot product between the node vector and y yields side 1,
// and a negative one yields side 0.
func TestAngularSide(t *testing.T) {
	a := Angular{}

	// Same direction: the dot product is +14.0.
	sameDir := Node{v: []float64{1, 2, 3}}
	plusY := []float64{1, 2, 3}
	if got := a.side(sameDir, plusY, RandRandom{}); got != 1 {
		t.Errorf("Angular side should return 1, but %d", got)
	}

	// Opposite direction: the dot product is -14.0.
	oppDir := Node{v: []float64{1, 2, 3}}
	minusY := []float64{-1, -2, -3}
	if got := a.side(oppDir, minusY, RandRandom{}); got != 0 {
		t.Errorf("Angular side should return 0, but %d", got)
	}
}
// TestAngularDistance checks that Angular.distance between the vector
// (1, 2, 3) and its negation (-1, -2, -3) is exactly 4.0.
func TestAngularDistance(t *testing.T) {
	var (
		a     = Angular{}
		left  = []float64{1, 2, 3}
		right = []float64{-1, -2, -3}
		want  = 4.0
	)
	got := a.distance(left, right)
	if got != want {
		t.Errorf("Angular distance should return %f, but %f.", want, got)
	}
}
// TestAngularCreateSplit checks the split node produced by Angular.createSplit
// for four 2-D points, using TestLoopRandom so the "random" choices are
// deterministic and the resulting vector is reproducible.
func TestAngularCreateSplit(t *testing.T) {
	angular := Angular{}
	nodes := []Node{
		{v: []float64{0.1, 0.1}},
		{v: []float64{1.1, 1.1}},
		{v: []float64{0.1, 1.1}},
		{v: []float64{1.1, 0.1}},
	}
	n := angular.createSplit(nodes, &TestLoopRandom{max: len(nodes)}, Node{})

	// Components are compared after %f formatting (6 decimal places) rather
	// than with exact float equality.
	expect := []string{"0.822251", "-0.569124"}

	// Guard the length first: otherwise a longer n.v would panic on expect[i],
	// and a shorter one would silently skip the missing components.
	if len(n.v) != len(expect) {
		t.Fatalf("Create split should return node.v of length %d, but %d", len(expect), len(n.v))
	}
	for i, v := range n.v {
		if got := fmt.Sprintf("%f", v); got != expect[i] {
			// Expected value first, actual value second.
			t.Errorf("Create split should return node.v %s, but %s", expect[i], got)
		}
	}
}
// TestLoopRandom is a deterministic source of "random" choices for tests;
// it is passed to createSplit in place of a real random source (presumably
// satisfying the same interface as RandRandom — confirm against distance.go).
type TestLoopRandom struct {
	// max is the modulus index applies to its call counter.
	max int
	// current and flipCurrent are incremented on each index and flip call,
	// respectively.
	current, flipCurrent int
}
// index ignores n and returns r.current modulo r.max, then advances
// r.current, so successive calls cycle 0, 1, ..., r.max-1, 0, ... when
// r.current starts at 0. The modulo is taken before the counter changes,
// so r.max == 0 panics without mutating r.
func (r *TestLoopRandom) index(n int) int {
	pos := r.current % r.max
	r.current = r.current + 1
	return pos
}
// flip increments r.flipCurrent and returns 0 if the new count is even,
// 1 otherwise. Starting from the zero value, calls therefore alternate
// 1, 0, 1, 0, ...
func (r *TestLoopRandom) flip() int {
	r.flipCurrent++
	if r.flipCurrent%2 == 0 {
		return 0
	}
	// No else: the branch above terminates, so the odd case falls through.
	return 1
}