Skip to content

Commit

Permalink
switch to range func
Browse files Browse the repository at this point in the history
  • Loading branch information
phuslu committed Oct 23, 2024
1 parent 387014c commit 2e9745f
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 96 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ jobs:
benchmark:
runs-on: ubuntu-latest
steps:
- uses: actions/setup-go@v2
- uses: actions/setup-go@v4
with:
go-version: '1.22.1'
go-version: '1.23.2'
- name: Benchmark
run: |
set -ex
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/setup-go@v2
- uses: actions/setup-go@v4
with:
go-version: '1.22.1'
go-version: '1.23.2'
- name: Build
run: |
set -ex
Expand All @@ -21,5 +21,5 @@ jobs:
go test -v -cover
go build -v -race
(cd cmd/fastdig && go build -v)
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s v1.56.2
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s v1.61.0
./bin/golangci-lint run
69 changes: 30 additions & 39 deletions client_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,16 @@ func (c *Client) AppendLookupNetIP(dst []netip.Addr, ctx context.Context, networ
}

cname := make([]byte, 0, 64)

_ = resp.Walk(func(name []byte, typ Type, class Class, ttl uint32, data []byte) bool {
switch typ {
for r := range resp.Records {
switch r.Type {
case TypeCNAME:
cname = resp.DecodeName(cname[:0], data)
cname = resp.DecodeName(cname[:0], r.Data)
case TypeA:
dst = append(dst, netip.AddrFrom4(*(*[4]byte)(data)))
dst = append(dst, netip.AddrFrom4(*(*[4]byte)(r.Data)))
case TypeAAAA:
dst = append(dst, netip.AddrFrom16(*(*[16]byte)(data)))
dst = append(dst, netip.AddrFrom16(*(*[16]byte)(r.Data)))
}
return true
})
}

if len(cname) != 0 && len(dst) == 0 {
dst, err = c.AppendLookupNetIP(dst, ctx, network, b2s(cname))
Expand All @@ -78,16 +76,14 @@ func (c *Client) LookupCNAME(ctx context.Context, host string) (cname string, er
return
}

_ = resp.Walk(func(name []byte, typ Type, class Class, ttl uint32, data []byte) bool {
switch typ {
for r := range resp.Records {
switch r.Type {
case TypeCNAME:
cname = string(resp.DecodeName(nil, data))
return false
cname = string(resp.DecodeName(nil, r.Data))
default:
err = ErrInvalidAnswer
}
return true
})
}

return
}
Expand All @@ -107,17 +103,16 @@ func (c *Client) LookupNS(ctx context.Context, name string) (ns []*net.NS, err e

soa := make([]byte, 0, 64)

_ = resp.Walk(func(name []byte, typ Type, class Class, ttl uint32, data []byte) bool {
switch typ {
for r := range resp.Records {
switch r.Type {
case TypeSOA:
soa = resp.DecodeName(soa[:0], name)
soa = resp.DecodeName(soa[:0], r.Name)
case TypeNS:
ns = append(ns, &net.NS{Host: string(resp.DecodeName(nil, data))})
ns = append(ns, &net.NS{Host: string(resp.DecodeName(nil, r.Data))})
default:
err = ErrInvalidAnswer
}
return true
})
}

if len(soa) != 0 {
ns, err = c.LookupNS(ctx, b2s(soa))
Expand All @@ -139,19 +134,18 @@ func (c *Client) LookupTXT(ctx context.Context, host string) (txt []string, err
return
}

_ = resp.Walk(func(name []byte, typ Type, class Class, ttl uint32, data []byte) bool {
switch typ {
for r := range resp.Records {
switch r.Type {
case TypeTXT:
if len(data) > 1 && int(data[0])+1 == len(data) {
txt = append(txt, string(data[1:]))
if len(r.Data) > 1 && int(r.Data[0])+1 == len(r.Data) {
txt = append(txt, string(r.Data[1:]))
} else {
err = ErrInvalidAnswer
}
default:
err = ErrInvalidAnswer
}
return true
})
}

return
}
Expand All @@ -169,16 +163,14 @@ func (c *Client) LookupMX(ctx context.Context, host string) (mx []*net.MX, err e
return
}

_ = resp.Walk(func(name []byte, typ Type, class Class, ttl uint32, data []byte) bool {
switch typ {
case TypeMX:
for r := range resp.Records {
if r.Type == TypeMX {
mx = append(mx, &net.MX{
Host: string(resp.DecodeName(nil, data[2:])),
Pref: binary.BigEndian.Uint16(data),
Host: string(resp.DecodeName(nil, r.Data[2:])),
Pref: binary.BigEndian.Uint16(r.Data),
})
}
return true
})
}

return
}
Expand All @@ -196,12 +188,12 @@ func (c *Client) LookupHTTPS(ctx context.Context, host string) (https []NetHTTPS
return
}

_ = resp.Walk(func(name []byte, typ Type, class Class, ttl uint32, data []byte) bool {
switch typ {
case TypeHTTPS:
for r := range resp.Records {
if r.Type == TypeHTTPS {
var h NetHTTPS
data := r.Data
if len(data) < 7 {
return true
return nil, ErrInvalidAnswer
}
data = data[3:]
for len(data) >= 4 {
Expand Down Expand Up @@ -243,8 +235,7 @@ func (c *Client) LookupHTTPS(ctx context.Context, host string) (https []NetHTTPS
}
https = append(https, h)
}
return true
})
}

return
}
13 changes: 6 additions & 7 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,16 @@ func TestClientExchange(t *testing.T) {
t.Errorf("client=%+v exchange(%v) error: %+v\n", client, c.Domain, err)
}
t.Logf("%s: CLASS %s TYPE %s\n", resp.Domain, resp.Question.Class, resp.Question.Type)
_ = resp.Walk(func(name []byte, typ Type, class Class, ttl uint32, data []byte) bool {
switch typ {
for r := range resp.Records {
switch r.Type {
case TypeCNAME:
t.Logf("%s.\t%d\t%s\t%s\t%s.\n", resp.DecodeName(nil, name), ttl, class, typ, resp.DecodeName(nil, data))
t.Logf("%s.\t%d\t%s\t%s\t%s.\n", resp.DecodeName(nil, r.Name), r.TTL, r.Class, r.Type, resp.DecodeName(nil, r.Data))
case TypeA:
t.Logf("%s.\t%d\t%s\t%s\t%s\n", resp.DecodeName(nil, name), ttl, class, typ, netip.AddrFrom4(*(*[4]byte)(data)))
t.Logf("%s.\t%d\t%s\t%s\t%s\n", resp.DecodeName(nil, r.Name), r.TTL, r.Class, r.Type, netip.AddrFrom4(*(*[4]byte)(r.Data)))
case TypeAAAA:
t.Logf("%s.\t%d\t%s\t%s\t%s\n", resp.DecodeName(nil, name), ttl, class, typ, netip.AddrFrom16(*(*[16]byte)(data)))
t.Logf("%s.\t%d\t%s\t%s\t%s\n", resp.DecodeName(nil, r.Name), r.TTL, r.Class, r.Type, netip.AddrFrom16(*(*[16]byte)(r.Data)))
}
return true
})
}
}
}

Expand Down
57 changes: 28 additions & 29 deletions cmd/fastdig/fastdig.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,48 +103,47 @@ func opt(option string, options []string) bool {
}

func short(resp *fastdns.Message) {
_ = resp.Walk(func(name []byte, typ fastdns.Type, class fastdns.Class, ttl uint32, data []byte) bool {
for r := range resp.Records {
var v interface{}
switch typ {
switch r.Type {
case fastdns.TypeA, fastdns.TypeAAAA:
v, _ = netip.AddrFromSlice(data)
v, _ = netip.AddrFromSlice(r.Data)
case fastdns.TypeCNAME, fastdns.TypeNS:
v = fmt.Sprintf("%s.", resp.DecodeName(nil, data))
v = fmt.Sprintf("%s.", resp.DecodeName(nil, r.Data))
case fastdns.TypeMX:
v = fmt.Sprintf("%d %s.", binary.BigEndian.Uint16(data), resp.DecodeName(nil, data[2:]))
v = fmt.Sprintf("%d %s.", binary.BigEndian.Uint16(r.Data), resp.DecodeName(nil, r.Data[2:]))
case fastdns.TypeTXT:
v = fmt.Sprintf("\"%s\"", data[1:])
v = fmt.Sprintf("\"%s\"", r.Data[1:])
case fastdns.TypeSRV:
priority := binary.BigEndian.Uint16(data)
weight := binary.BigEndian.Uint16(data[2:])
port := binary.BigEndian.Uint16(data[4:])
target := resp.DecodeName(nil, data[6:])
priority := binary.BigEndian.Uint16(r.Data)
weight := binary.BigEndian.Uint16(r.Data[2:])
port := binary.BigEndian.Uint16(r.Data[4:])
target := resp.DecodeName(nil, r.Data[6:])
v = fmt.Sprintf("%d %d %d %s.", priority, weight, port, target)
case fastdns.TypeSOA:
var mname []byte
for i, b := range data {
for i, b := range r.Data {
if b == 0 {
mname = data[:i+1]
mname = r.Data[:i+1]
break
} else if b&0b11000000 == 0b11000000 {
mname = data[:i+2]
mname = r.Data[:i+2]
break
}
}
nname := resp.DecodeName(nil, data[len(mname):len(data)-20])
nname := resp.DecodeName(nil, r.Data[len(mname):len(r.Data)-20])
mname = resp.DecodeName(nil, mname)
serial := binary.BigEndian.Uint32(data[len(data)-20:])
refresh := binary.BigEndian.Uint32(data[len(data)-16:])
retry := binary.BigEndian.Uint32(data[len(data)-12:])
expire := binary.BigEndian.Uint32(data[len(data)-8:])
minimum := binary.BigEndian.Uint32(data[len(data)-4:])
serial := binary.BigEndian.Uint32(r.Data[len(r.Data)-20:])
refresh := binary.BigEndian.Uint32(r.Data[len(r.Data)-16:])
retry := binary.BigEndian.Uint32(r.Data[len(r.Data)-12:])
expire := binary.BigEndian.Uint32(r.Data[len(r.Data)-8:])
minimum := binary.BigEndian.Uint32(r.Data[len(r.Data)-4:])
v = fmt.Sprintf("%s. %s. %d %d %d %d %d", mname, nname, serial, refresh, retry, expire, minimum)
default:
v = fmt.Sprintf("%x", data)
v = fmt.Sprintf("%x", r.Data)
}
fmt.Printf("%s\n", v)
return true
})
}
}

func cmd(req, resp *fastdns.Message, server string, start, end time.Time) {
Expand Down Expand Up @@ -187,11 +186,12 @@ func cmd(req, resp *fastdns.Message, server string, start, end time.Time) {
} else {
fmt.Printf(";; AUTHORITY SECTION:\n")
}
var index int
_ = resp.Walk(func(name []byte, typ fastdns.Type, class fastdns.Class, ttl uint32, data []byte) bool {
index := 0
for r := range resp.Records {
index++
data := r.Data
var v interface{}
switch typ {
switch r.Type {
case fastdns.TypeA, fastdns.TypeAAAA:
v, _ = netip.AddrFromSlice(data)
case fastdns.TypeCNAME, fastdns.TypeNS:
Expand Down Expand Up @@ -228,7 +228,7 @@ func cmd(req, resp *fastdns.Message, server string, start, end time.Time) {
case fastdns.TypeHTTPS:
var h fastdns.NetHTTPS
if len(data) < 7 {
return true
return
}
data = data[3:]
for len(data) >= 4 {
Expand Down Expand Up @@ -304,9 +304,8 @@ func cmd(req, resp *fastdns.Message, server string, start, end time.Time) {
default:
v = fmt.Sprintf("%x", data)
}
fmt.Printf("%s. %d %s %s %s\n", resp.DecodeName(nil, name), ttl, class, typ, v)
return true
})
fmt.Printf("%s. %d %s %s %s\n", resp.DecodeName(nil, r.Name), r.TTL, r.Class, r.Type, v)
}

fmt.Printf("\n")
fmt.Printf(";; Query time: %d msec\n", end.Sub(start)/time.Millisecond)
Expand Down
2 changes: 1 addition & 1 deletion cmd/fastdoh/go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module main

go 1.22
go 1.23

require (
github.com/phuslu/fastdns v1.0.0
Expand Down
13 changes: 6 additions & 7 deletions cmd/fastdoh/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,16 @@ func (h *DNSHandler) ServeDNS(rw fastdns.ResponseWriter, req *fastdns.Message) {
}

if h.Debug {
_ = resp.Walk(func(name []byte, typ fastdns.Type, class fastdns.Class, ttl uint32, data []byte) bool {
switch typ {
for r := range resp.Records {
switch r.Type {
case fastdns.TypeCNAME:
slog.Info("dns request CNAME", "name", resp.DecodeName(nil, name), "ttl", ttl, "class", class, "type", typ, "CNAME", resp.DecodeName(nil, data))
slog.Info("dns request CNAME", "name", resp.DecodeName(nil, r.Name), "ttl", r.TTL, "class", r.Class, "type", r.Type, "CNAME", resp.DecodeName(nil, r.Data))
case fastdns.TypeA:
slog.Info("dns request A", "name", resp.DecodeName(nil, name), "ttl", ttl, "class", class, "type", typ, "A", netip.AddrFrom4(*(*[4]byte)(data)))
slog.Info("dns request A", "name", resp.DecodeName(nil, r.Name), "ttl", r.TTL, "class", r.Class, "type", r.Type, "A", netip.AddrFrom4(*(*[4]byte)(r.Data)))
case fastdns.TypeAAAA:
slog.Info("dns request AAAA", "name", resp.DecodeName(nil, name), "ttl", ttl, "class", class, "type", typ, "AAAA", netip.AddrFrom16(*(*[16]byte)(data)))
slog.Info("dns request AAAA", "name", resp.DecodeName(nil, r.Name), "ttl", r.TTL, "class", r.Class, "type", r.Type, "AAAA", netip.AddrFrom16(*(*[16]byte)(r.Data)))
}
return true
})
}
slog.Info("serve dns answers", "remote_addr", rw.RemoteAddr(), "domain", req.Domain, "remote_addr", h.DNSClient.Addr, "answer_count", resp.Header.ANCount)
}

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module github.com/phuslu/fastdns

go 1.22
go 1.23
20 changes: 13 additions & 7 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,19 @@ func (msg *Message) DecodeName(dst []byte, name []byte) []byte {
return dst
}

type AnswerRecord struct {
Name []byte
Type Type
Class Class
TTL uint32
Data []byte
}

// Walk calls f for each item in the msg in the original order of the parsed RR.
func (msg *Message) Walk(f func(name []byte, typ Type, class Class, ttl uint32, data []byte) bool) error {
func (msg *Message) Records(f func(AnswerRecord) bool) {
n := msg.Header.ANCount + msg.Header.NSCount
if n == 0 {
return ErrInvalidAnswer
return
}

payload := msg.Raw[16+len(msg.Question.Name):]
Expand All @@ -238,7 +246,7 @@ func (msg *Message) Walk(f func(name []byte, typ Type, class Class, ttl uint32,
}
}
if name == nil {
return ErrInvalidAnswer
return
}
_ = payload[9] // hint compiler to remove bounds check
typ := Type(payload[0])<<8 | Type(payload[1])
Expand All @@ -247,17 +255,15 @@ func (msg *Message) Walk(f func(name []byte, typ Type, class Class, ttl uint32,
length := uint16(payload[8])<<8 | uint16(payload[9])
data := payload[10 : 10+length]
payload = payload[10+length:]
ok := f(name, typ, class, ttl, data)
ok := f(AnswerRecord{Name: name, Type: typ, Class: class, TTL: ttl, Data: data})
if !ok {
break
}
}

return nil
}

// WalkAdditionalRecords calls f for each item in the msg in the original order of the parsed AR.
func (msg *Message) WalkAdditionalRecords(f func(name []byte, typ Type, class Class, ttl uint32, data []byte) bool) error {
func (msg *Message) AdditionalRecords(f func(AnswerRecord) bool) {
panic("not implemented")
}

Expand Down

0 comments on commit 2e9745f

Please sign in to comment.