-
Notifications
You must be signed in to change notification settings - Fork 0
/
dnstun_test.go
81 lines (69 loc) · 1.9 KB
/
dnstun_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
package dnstun
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
plugintest "github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns"
)
// TestPredictor is a fake prediction backend used by TestDnstunServeDNS: its
// Handle method is served through httptest, and that server's address is
// handed to NewDnstun as Options.Runtime.
type TestPredictor struct {
// resp is marshaled to JSON and written as the HTTP body by Handle.
resp PredictResponse
// err, when non-nil, makes Handle reply with 500 Internal Server Error
// instead of writing resp.
err error
}
// Handle replays the canned prediction and is meant to be wrapped in
// http.HandlerFunc. When p.err is set it answers 500 Internal Server Error
// with no body. Otherwise it writes p.resp encoded as JSON. A marshaling
// failure is also reported as 500.
func (p *TestPredictor) Handle(w http.ResponseWriter, r *http.Request) {
	// Check the injected failure first; there is no point encoding a payload
	// that will never be sent.
	if p.err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	b, err := json.Marshal(p.resp)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	// Without an explicit type net/http would sniff the body as text/plain.
	w.Header().Set("Content-Type", "application/json")
	// A write failure means the client went away; the client side of the test
	// observes that on its own, so there is nothing useful to do here.
	_, _ = w.Write(b)
}
// TestResponseWriter embeds plugintest.ResponseWriter and additionally keeps
// the message passed to WriteMsg so a test can inspect the reply afterwards.
type TestResponseWriter struct {
plugintest.ResponseWriter
// m is the most recent message passed to WriteMsg; it stays nil if WriteMsg
// was never called.
m *dns.Msg
}
// WriteMsg records msg on the wrapper, then hands it to the embedded
// plugintest.ResponseWriter and returns that writer's result unchanged.
func (rw *TestResponseWriter) WriteMsg(msg *dns.Msg) error {
	inner := &rw.ResponseWriter
	rw.m = msg
	return inner.WriteMsg(msg)
}
// TestDnstunServeDNS runs Dnstun.ServeDNS against a fake predictor served by
// httptest and checks three things: the returned rcode, whether an error is
// returned, and, for REFUSED cases, the rcode of the message actually written
// to the client.
func TestDnstunServeDNS(t *testing.T) {
	tests := []struct {
		name      string
		predictor TestPredictor
		rcode     int
		err       bool // whether ServeDNS is expected to return a non-nil error
	}{
		{"refused", TestPredictor{resp: PredictResponse{Predictions: [][]float64{{0.2, 0.8}}}}, dns.RcodeRefused, false},
		{"allowed", TestPredictor{resp: PredictResponse{Predictions: [][]float64{{0.7, 0.3}}}}, dns.RcodeSuccess, false},
		{"predictor error", TestPredictor{err: errors.New("err")}, dns.RcodeServerFailure, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			s := httptest.NewServer(http.HandlerFunc(tt.predictor.Handle))
			defer s.Close()
			defer s.CloseClientConnections()

			// Runtime is given the server address with the scheme stripped.
			// TrimPrefix removes exactly "http://"; TrimLeft would treat the
			// argument as a character set and could eat into the host.
			d := NewDnstun(Options{
				Mapping: MappingReverse,
				Runtime: strings.TrimPrefix(s.URL, "http://"),
			})

			req := plugintest.Case{Qname: "tunnel.example.org", Qtype: dns.TypeCNAME}
			rw := new(TestResponseWriter)
			rcode, err := d.ServeDNS(context.TODO(), rw, req.Msg())

			if rcode != tt.rcode {
				t.Errorf("rcode is wrong: %v != %v", rcode, tt.rcode)
			}
			// Check the error expectation in both directions: an unexpected
			// error and a missing expected error are both failures.
			if (err != nil) != tt.err {
				t.Errorf("error returned: %v, want error: %v", err, tt.err)
			}

			if tt.rcode == dns.RcodeRefused {
				if rw.m == nil {
					t.Fatal("message is not written")
				}
				if rw.m.Rcode != tt.rcode {
					t.Errorf("wrong rcode in response %v != %v", rw.m.Rcode, tt.rcode)
				}
			}
		})
	}
}