-
Notifications
You must be signed in to change notification settings - Fork 12
/
message.go
401 lines (339 loc) · 10.5 KB
/
message.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
package fastdns
import (
"errors"
"sync"
)
// Message is a DNS message: a request received by a server or to be sent by
// a client, or the response built in place from such a request.
type Message struct {
// Raw holds the message in wire format. ParseMessage copies the payload into
// Raw only when called with copying=true (otherwise Raw is left unchanged);
// SetRequestQuestion rebuilds Raw and SetResponseHeader truncates it and
// rewrites its header bytes.
Raw []byte
// Domain is the question name in dotted text form (for example
// "www.example.com"), as set by ParseMessage and SetRequestQuestion.
Domain []byte
// Header holds the fixed 12-byte message header described in RFC 1035
// section 4.1.1. Each field is a big-endian 16-bit word on the wire.
Header struct {
// ID is an arbitrary 16-bit request identifier that is copied back in
// the response so that replies can be matched to requests.
//
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// |                      ID                       |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
ID uint16
// Flags holds the QR, Opcode, AA, TC, RD, RA, Z and RCODE bits.
//
//  0  1  2  3  4  5  6  7  8  9  A  B  C  D  E  F
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
Flags Flags
// QDCount is the number of entries in the question section.
//
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// |                    QDCOUNT                    |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
QDCount uint16
// ANCount is the number of resource records (RRs) in the answer section.
//
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// |                    ANCOUNT                    |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
ANCount uint16
// NSCount is the number of name server RRs in the authority section.
//
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// |                    NSCOUNT                    |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
NSCount uint16
// ARCount is the number of RRs in the additional records section.
//
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// |                    ARCOUNT                    |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
ARCount uint16
}
// Question holds the single question of the message, laid out as in
// RFC 1035 section 4.1.2.
Question struct {
// Name is the QNAME in wire format (length-prefixed labels), including
// its terminating zero byte. It is a sub-slice of the parsed payload or
// of Raw, not a copy.
//
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// |                                               |
// /                     QNAME                     /
// /                                               /
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
Name []byte
// Type is the QTYPE: the type of record being queried.
//
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// |                     QTYPE                     |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
Type Type
// Class is the QCLASS: the class of the query.
//
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// |                    QCLASS                     |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
Class Class
}
}
// Sentinel errors for malformed DNS messages. Each is a single value created
// once with errors.New, so callers can compare against it directly.
var (
	// ErrInvalidHeader reports a DNS message that is shorter than the
	// 12-byte header, or whose QDCOUNT is not exactly 1.
	ErrInvalidHeader = errors.New("dns message does not have the expected header size")

	// ErrInvalidQuestion reports a question section that is missing,
	// holds an empty name, or is cut short.
	ErrInvalidQuestion = errors.New("dns message does not have the expected question size")

	// ErrInvalidAnswer reports a DNS message whose answer section does not
	// have the expected size.
	ErrInvalidAnswer = errors.New("dns message does not have the expected answer size")
)
// ParseMessage decodes the DNS message in payload into dst.
//
// When copying is true, payload is first copied into dst.Raw (reusing its
// capacity) and every slice stored in dst refers to that copy. When copying
// is false, dst.Raw is left untouched and dst.Question.Name aliases payload
// directly, so payload must stay unmodified for as long as dst is in use.
//
// Exactly one question is accepted. ErrInvalidHeader is returned when the
// payload is shorter than the 12-byte header or QDCOUNT is not 1;
// ErrInvalidQuestion when the question name is empty (the root) or the
// question does not fit in the payload. On error dst may be partially
// updated: the header fields are stored before the question is checked.
// dst.Domain's buffer is reused for the dotted text form of the name.
func ParseMessage(dst *Message, payload []byte, copying bool) error {
	if copying {
		dst.Raw = append(dst.Raw[:0], payload...)
		payload = dst.Raw
	}
	if len(payload) < 12 {
		return ErrInvalidHeader
	}

	// Header: six big-endian 16-bit words. Working on an exactly 12-byte
	// slice helps the compiler elide the bounds checks below.
	hdr := payload[:12]
	dst.Header.ID = uint16(hdr[0])<<8 | uint16(hdr[1])
	dst.Header.Flags = Flags(hdr[2])<<8 | Flags(hdr[3])
	dst.Header.QDCount = uint16(hdr[4])<<8 | uint16(hdr[5])
	dst.Header.ANCount = uint16(hdr[6])<<8 | uint16(hdr[7])
	dst.Header.NSCount = uint16(hdr[8])<<8 | uint16(hdr[9])
	dst.Header.ARCount = uint16(hdr[10])<<8 | uint16(hdr[11])
	if dst.Header.QDCount != 1 {
		return ErrInvalidHeader
	}

	// Question: QNAME runs up to and including the first zero byte and is
	// followed by the 2-byte QTYPE and the 2-byte QCLASS.
	question := payload[12:]
	end := -1
	for k, c := range question {
		if c == 0 {
			end = k
			break
		}
	}
	if end <= 0 || end+5 > len(question) {
		return ErrInvalidQuestion
	}
	name := question[:end+1]
	dst.Question.Name = name
	dst.Question.Type = Type(uint16(question[end+1])<<8 | uint16(question[end+2]))
	dst.Question.Class = Class(uint16(question[end+3])<<8 | uint16(question[end+4]))

	// Domain: copy QNAME without its first length byte, turn each remaining
	// length byte into '.', and drop the trailing zero,
	// e.g. "\x03www\x01a\x00" -> "www.a".
	domain := append(dst.Domain[:0], name[1:]...)
	for at := int(name[0]); at < len(domain) && domain[at] != 0; {
		step := int(domain[at]) + 1
		domain[at] = '.'
		at += step
	}
	dst.Domain = domain[:len(domain)-1]
	return nil
}
// DecodeName appends the dotted text form of the wire-format domain name in
// name (for example "\x03www\x01a\x03com\x00" -> "www.a.com") to dst and
// returns the extended slice.
//
// name is a sequence of length-prefixed labels that ends either with a zero
// byte or with a two-byte compression pointer into msg.Raw. Pointers are
// followed through msg.Raw until the terminating zero label. A name that is
// just a pointer to offset 12 (the question name right after the header) is
// answered from msg.Domain without reading Raw. The root name decodes to
// nothing.
//
// If name is shorter than two bytes, or cannot be decoded safely, dst is
// returned with its original length. "Cannot be decoded safely" means a
// label or pointer reaches past the end of its buffer, or more than 128
// pointers are chained (as in a pointer loop). This way hostile packets
// cannot make this method panic or spin forever.
func (msg *Message) DecodeName(dst []byte, name []byte) []byte {
	if len(name) < 2 {
		return dst
	}

	// Fast path: a pointer to offset 12 refers to the question name, which
	// follows the 12-byte header and whose text form is kept in msg.Domain.
	if name[1] == 12 && name[0] == 0b11000000 {
		return append(dst, msg.Domain...)
	}

	// Cap on followed compression pointers. It only has to stop pointer
	// loops; real names (at most 255 octets per RFC 1035) need far fewer.
	const maxPointers = 128

	base := len(dst)

	// Copy the wire labels, length bytes included, reading from name first
	// and from msg.Raw after each compression pointer. Walking by length
	// bytes (instead of looking at the last byte of name) never mistakes
	// label contents for a terminator or pointer. It also handles pointers
	// whose low byte is 0x00, such as a pointer to offset 256.
	src, i, jumps := name, 0, 0
	for {
		if i >= len(src) {
			return dst[:base] // no terminating zero label
		}
		b := int(src[i])
		if b == 0 {
			dst = append(dst, 0)
			break
		}
		if b&0b11000000 == 0b11000000 {
			if i+1 >= len(src) || jumps == maxPointers {
				return dst[:base] // truncated pointer or pointer loop
			}
			ptr := (b&0b00111111)<<8 | int(src[i+1])
			src, i = msg.Raw, ptr
			jumps++
			continue
		}
		if i+b+1 > len(src) {
			return dst[:base] // label runs past the end of its buffer
		}
		dst = append(dst, src[i:i+b+1]...)
		i += b + 1
	}

	// Turn every length byte into '.': "\x03www\x01a\x00" -> ".www.a\x00".
	// The walk lands exactly on the trailing zero appended above.
	p := base
	for dst[p] != 0 {
		l := int(dst[p])
		dst[p] = '.'
		p += l + 1
	}

	switch {
	case p == base:
		// Root name: only the zero label was copied.
		return dst[:base]
	case base == 0:
		// Nothing to preserve in front: reslice past the leading '.'.
		return dst[1 : len(dst)-1]
	default:
		// Shift the text left over the leading '.' and drop the zero.
		return append(dst[:base], dst[base+1:len(dst)-1]...)
	}
}
// MessageRecord is one resource record read from a message's Raw bytes by
// Records. Name and Data are sub-slices of Raw, not copies, so they are only
// valid while Raw is unchanged.
type MessageRecord struct {
// Name is the owner name in wire format; it may end in a two-byte
// compression pointer, so expand it with Message.DecodeName.
Name []byte
// Type is the record's TYPE field.
Type Type
// Class is the record's CLASS field.
Class Class
// TTL is the record's 32-bit TTL field.
TTL uint32
// Data is the record's RDATA, exactly RDLENGTH bytes long.
Data []byte
}
// Records calls f for each resource record in the answer and authority
// sections of msg (ANCount+NSCount records), in the order they appear in
// Raw. Iteration stops as soon as f returns false.
//
// Records are located by skipping the 12-byte header, Question.Name and the
// 4-byte QTYPE/QCLASS, so Raw is expected to hold exactly one question (as
// ParseMessage enforces). Name and Data of each MessageRecord are sub-slices
// of Raw. Name is in wire format and may end in a compression pointer (see
// DecodeName).
//
// Raw may come from an untrusted peer. When an owner name or record is
// truncated, iteration stops quietly instead of panicking.
func (msg *Message) Records(f func(MessageRecord) bool) {
	// Add in int: ANCount+NSCount in uint16 could wrap around to 0.
	n := int(msg.Header.ANCount) + int(msg.Header.NSCount)
	start := 16 + len(msg.Question.Name)
	if n == 0 || start > len(msg.Raw) {
		return
	}

	payload := msg.Raw[start:]
	for ; n > 0; n-- {
		// Walk the owner name by its length bytes. Scanning byte by byte
		// would mistake label contents >= 0xC0 for a compression pointer.
		j := 0
		for j < len(payload) && payload[j] != 0 && payload[j]&0b11000000 != 0b11000000 {
			j += int(payload[j]) + 1
		}
		if j >= len(payload) {
			return // owner name runs past the end of Raw
		}
		if payload[j] == 0 {
			j++ // keep the terminating zero label
		} else {
			j += 2 // keep the two-byte compression pointer
		}

		// Fixed RR fields: TYPE(2) CLASS(2) TTL(4) RDLENGTH(2).
		if j+10 > len(payload) {
			return
		}
		name, rr := payload[:j], payload[j:]
		_ = rr[9] // hint compiler to remove bounds checks
		typ := Type(rr[0])<<8 | Type(rr[1])
		class := Class(rr[2])<<8 | Class(rr[3])
		ttl := uint32(rr[4])<<24 | uint32(rr[5])<<16 | uint32(rr[6])<<8 | uint32(rr[7])
		// Use int: 10+RDLENGTH computed in uint16 could wrap for large lengths.
		length := int(rr[8])<<8 | int(rr[9])
		if 10+length > len(rr) {
			return // RDATA runs past the end of Raw
		}
		data := rr[10 : 10+length]
		payload = rr[10+length:]

		if !f(MessageRecord{Name: name, Type: typ, Class: class, TTL: ttl, Data: data}) {
			return
		}
	}
}
// AdditionalRecords calls f for each resource record in the additional
// section of msg (the ARCount records that follow the answer and authority
// sections), in the order they appear in Raw. Iteration stops as soon as f
// returns false.
//
// Like Records, it assumes Raw holds exactly one question. The
// ANCount+NSCount records before the additional section are parsed only to
// skip over them. Name and Data of each MessageRecord are sub-slices of Raw.
// When an owner name or record is truncated, iteration stops quietly
// instead of panicking.
func (msg *Message) AdditionalRecords(f func(MessageRecord) bool) {
	// Use int so the uint16 counts cannot wrap when added.
	skip := int(msg.Header.ANCount) + int(msg.Header.NSCount)
	total := skip + int(msg.Header.ARCount)
	start := 16 + len(msg.Question.Name)
	if total == skip || start > len(msg.Raw) {
		return
	}

	payload := msg.Raw[start:]
	for i := 0; i < total; i++ {
		// Walk the owner name by its length bytes up to its zero label or
		// two-byte compression pointer.
		j := 0
		for j < len(payload) && payload[j] != 0 && payload[j]&0b11000000 != 0b11000000 {
			j += int(payload[j]) + 1
		}
		if j >= len(payload) {
			return // owner name runs past the end of Raw
		}
		if payload[j] == 0 {
			j++
		} else {
			j += 2
		}

		// Fixed RR fields: TYPE(2) CLASS(2) TTL(4) RDLENGTH(2).
		if j+10 > len(payload) {
			return
		}
		name, rr := payload[:j], payload[j:]
		_ = rr[9] // hint compiler to remove bounds checks
		length := int(rr[8])<<8 | int(rr[9])
		if 10+length > len(rr) {
			return // RDATA runs past the end of Raw
		}
		payload = rr[10+length:]
		if i < skip {
			continue // still in the answer or authority section
		}

		rec := MessageRecord{
			Name:  name,
			Type:  Type(rr[0])<<8 | Type(rr[1]),
			Class: Class(rr[2])<<8 | Class(rr[3]),
			TTL:   uint32(rr[4])<<24 | uint32(rr[5])<<16 | uint32(rr[6])<<8 | uint32(rr[7]),
			Data:  rr[10 : 10+length],
		}
		if !f(rec) {
			return
		}
	}
}
// SetRequestQuestion turns msg into a single-question query for domain with
// the given type and class, rebuilding Raw from its first byte.
//
// A random ID is chosen, QR and RCODE are cleared and RD is set. The other
// flag bits (Opcode, AA, TC, RA, Z) keep whatever msg.Header.Flags held
// before the call. QDCOUNT becomes 1 and the other section counts 0.
// Question.Name refers to the encoded QNAME inside Raw, and Domain receives
// a copy of domain exactly as given.
func (msg *Message) SetRequestQuestion(domain string, typ Type, class Class) {
	msg.Header.ID = uint16(cheaprandn(65536))

	// Flags layout:
	//
	//  0  1  2  3  4  5  6  7  8  9  A  B  C  D  E  F
	// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
	// |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
	// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
	//
	// Clear QR and RCODE, then set RD.
	msg.Header.Flags &= 0b0111111111110000
	msg.Header.Flags |= 0b0000000100000000

	msg.Header.QDCount, msg.Header.ANCount = 1, 0
	msg.Header.NSCount, msg.Header.ARCount = 0, 0

	// 12-byte header: ID, flags, QDCOUNT=1 and zero AN/NS/AR counts.
	const headerLen = 12
	msg.Raw = append(msg.Raw[:0],
		byte(msg.Header.ID>>8), byte(msg.Header.ID),
		byte(msg.Header.Flags>>8), byte(msg.Header.Flags),
		0, 1, // QDCOUNT
		0, 0, // ANCOUNT
		0, 0, // NSCOUNT
		0, 0, // ARCOUNT
	)

	// QNAME. NOTE(review): the slice bound assumes EncodeDomain appends
	// exactly len(domain)+2 bytes; confirm for inputs such as a trailing
	// dot or an empty domain.
	msg.Raw = EncodeDomain(msg.Raw, domain)
	msg.Question.Name = msg.Raw[headerLen : headerLen+len(domain)+2]

	// QTYPE and QCLASS, big-endian.
	msg.Raw = append(msg.Raw, byte(typ>>8), byte(typ))
	msg.Raw = append(msg.Raw, byte(class>>8), byte(class))
	msg.Question.Type, msg.Question.Class = typ, class

	msg.Domain = append(msg.Domain[:0], domain...)
}
// SetResponseHeader turns msg into a response. It sets QR=1 and RCODE=rcode
// in the flags, updates the section counts, truncates Raw, and rewrites the
// header bytes (all except the ID) to match.
//
// For rcode == RcodeNoError the single question is kept, ANCOUNT becomes
// ancount, and Raw is cut right after the question (QNAME plus QTYPE/QCLASS),
// dropping anything that followed it. For any other rcode every count is
// zeroed, ancount is ignored, and Raw is cut to the bare 12-byte header.
// NSCOUNT and ARCOUNT are always zeroed.
func (msg *Message) SetResponseHeader(rcode Rcode, ancount uint16) {
	// Flags layout:
	//
	//  0  1  2  3  4  5  6  7  8  9  A  B  C  D  E  F
	// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
	// |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
	// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
	//
	// Clear the old RCODE, then set QR and the new RCODE.
	msg.Header.Flags &= 0b1111111111110000
	msg.Header.Flags |= 0b1000000000000000 | Flags(rcode)

	var qd, an uint16
	size := 12
	if rcode == RcodeNoError {
		qd, an = 1, ancount
		size += len(msg.Question.Name) + 4
	}

	msg.Header.QDCount = qd
	msg.Header.ANCount = an
	msg.Header.NSCount = 0
	msg.Header.ARCount = 0
	msg.Raw = msg.Raw[:size]

	h := msg.Raw[:12]
	h[2], h[3] = byte(msg.Header.Flags>>8), byte(msg.Header.Flags)
	h[4], h[5] = byte(qd>>8), byte(qd)
	h[6], h[7] = byte(an>>8), byte(an)
	h[8], h[9], h[10], h[11] = 0, 0, 0, 0
}
// msgPool recycles *Message values for AcquireMessage and ReleaseMessage.
// Freshly allocated messages start with empty Raw and Domain buffers that
// already have room for 1024 and 256 bytes respectively.
var msgPool = sync.Pool{
	New: func() interface{} {
		return &Message{
			Raw:    make([]byte, 0, 1024),
			Domain: make([]byte, 0, 256),
		}
	},
}
// AcquireMessage takes a *Message from the package-level pool. The pool
// allocates a new one when it has none to hand out. The result may be a
// message previously handed to ReleaseMessage, so populate it with
// ParseMessage or SetRequestQuestion before use.
func AcquireMessage() *Message {
	msg := msgPool.Get().(*Message)
	return msg
}
// ReleaseMessage returns msg to the pool so a later AcquireMessage can reuse
// it. msg must not be used after this call; a nil msg is ignored.
//
// Before pooling, msg is reset. Header and Question are zeroed, which drops
// any reference Question.Name holds into a caller's payload. Raw and Domain
// are truncated to length zero but keep their capacity. Without this, a
// recycled message would carry stale state into its next use, e.g. old
// Opcode/AA/TC/RA bits, which SetRequestQuestion leaves untouched.
func ReleaseMessage(msg *Message) {
	if msg == nil {
		// Pooling a typed nil would make a later AcquireMessage return nil.
		return
	}
	*msg = Message{
		Raw:    msg.Raw[:0],
		Domain: msg.Domain[:0],
	}
	msgPool.Put(msg)
}