diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index cf2b00c..856e574 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,7 +17,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.19 + go-version: 1.21 - name: Build run: go build -v ./... diff --git a/README.md b/README.md index 6362720..9147eb7 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ Forked from [looterz/grimd](https://github.com/looterz/grimd) - [x] DNS over TCP - [x] DNS over HTTP(S) (DoH as per [RFC-8484](https://datatracker.ietf.org/doc/html/rfc8484)) - [x] Prometheus metrics API -- [x] Custom DNS records supports +- [x] Custom DNS records support - [x] Blocklist fetching - [x] Hardcoded blocklist config - [x] Hardcoded whitelist config @@ -56,7 +56,7 @@ Usage of grimd: ``` # Building -Requires golang 1.7 or higher, you build grimd like any other golang application, for example to build for linux x64 +Requires golang 1.21 or higher, you build grimd like any other golang application, for example to build for linux x64 ```shell env GOOS=linux GOARCH=amd64 go build -v github.com/cottand/grimd ``` @@ -76,19 +76,19 @@ curl -H "Accept: application/json" http://127.0.0.1:55006/application/active ``` # Speed -Incoming requests spawn a goroutine and are served concurrently, and the block cache resides in-memory to allow for rapid lookups, while answered queries are cached allowing grimd to serve thousands of queries at once while maintaining a memory footprint of under 15mb for 100,000 blocked domains! +Incoming requests spawn a goroutine and are served concurrently, and the block cache resides in-memory to allow for rapid lookups, while answered queries are cached allowing grimd to serve thousands of queries at once while maintaining a memory footprint of under 30mb for 100,000 blocked domains! 
# Daemonize You can find examples of different daemon scripts for grimd on the [wiki](https://github.com/looterz/grimd/wiki/Daemon-Scripts). # Objectives -These are some of the things I would like to contribute in this fork: - [x] ~~ARM64 Docker builds~~ - [ ] Better custom DNS support - [x] ~~Dynamic config reload for custom DNS issue#16~~ - [x] ~~Fix multi-record responses issue#5~~ - - [ ] DNS record flattening issue#1 + - [x] ~~DNS record CNAME following issue#1~~ + - [ ] DNS record CNAME flattening a la cloudflare issue#27 - [ ] Service discovery integrations? issue#4 - [x] Prometheus metrics exporter issue#3 - [x] DNS over HTTPS #2 diff --git a/config.go b/config.go index 4917243..9fb4533 100644 --- a/config.go +++ b/config.go @@ -46,6 +46,7 @@ type Config struct { DoH string Metrics Metrics `toml:"metrics"` DnsOverHttpServer DnsOverHttpServer + FollowCnameDepth uint32 } type Metrics struct { @@ -162,6 +163,12 @@ reactivationdelay = 300 # Dns over HTTPS upstream provider to use DoH = "https://cloudflare-dns.com/dns-query" +# How deep to follow chains of CNAME records +# set to 0 to disable CNAME-following entirely +# (anything more than 10 should be more than plenty) +# see https://github.com/Cottand/grimd/wiki/CNAME%E2%80%90following-DNS +followCnameDepth = 12 + # Prometheus metrics - disabled by default [Metrics] enabled = false diff --git a/dashboard/reaper b/dashboard/reaper deleted file mode 160000 index 79c9068..0000000 --- a/dashboard/reaper +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 79c906812668315e238319f9d5229f09c1a54ebc diff --git a/doh.go b/doh.go index 9d66d5e..0980294 100644 --- a/doh.go +++ b/doh.go @@ -91,7 +91,7 @@ func (s *ServerHTTPS) Stop() error { return nil } -// ServeHTTP is the handler that gets the HTTP request and converts to the dns format, calls the resolver, +// ServeHTTP is the eventLoop that gets the HTTP request and converts to the dns format, calls the resolver, // converts it back and write it to the client. 
func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) { if !(r.URL.Path == pathDOH) { @@ -195,7 +195,7 @@ type DohResponseWriter struct { // See section 4.2.1 of RFC 8484. // We are using code 500 to indicate an unexpected situation when the chain -// handler has not provided any response message. +// eventLoop has not provided any response message. func (w *DohResponseWriter) handleErr(err error) { logger.Warningf("error when replying to DoH: %v", err) http.Error(w.delegate, "No response", http.StatusInternalServerError) diff --git a/doh_test.go b/doh_test.go index 468135b..ff44a88 100644 --- a/doh_test.go +++ b/doh_test.go @@ -17,7 +17,7 @@ func dnsAQuestion(question string) (msg *dns.Msg) { func TestDohHappyPath(t *testing.T) { handler := dns.NewServeMux() custom := NewCustomDNSRecordsFromText([]string{"example.com. IN A 10.0.0.0 "}) - handler.HandleFunc("example.com", custom[0].serve(nil)) + handler.HandleFunc("example.com", custom[0].asHandler()) dohTest(t, handler, func(r Resolver, bind string) { response, err := r.DoHLookup("http://"+bind+"/dns-query", 1, dnsAQuestion("example.com.")) @@ -36,7 +36,7 @@ func TestDohHappyPath(t *testing.T) { func TestDoh404(t *testing.T) { handler := dns.NewServeMux() custom := NewCustomDNSRecordsFromText([]string{"example.com A 10.0.0.0"}) - handler.HandleFunc("example.com", custom[0].serve(nil)) + handler.HandleFunc("example.com", custom[0].asHandler()) dohTest(t, handler, func(r Resolver, bind string) { resp, err := http.Get("http://" + bind + "/unknown-path") diff --git a/grimd_test.go b/grimd_test.go index bfa1042..0c2c317 100644 --- a/grimd_test.go +++ b/grimd_test.go @@ -5,6 +5,7 @@ import ( "github.com/pelletier/go-toml/v2" "io" "net/http" + "slices" "strings" "testing" "time" @@ -40,6 +41,7 @@ func integrationTest(changeConfig func(c *Config), test func(client *dns.Client, go startActivation(actChannel, quitActivation, config.ReactivationDelay) grimdActivation = <-actChannel + grimdActive = true 
close(actChannel) server := &Server{ @@ -51,6 +53,9 @@ func integrationTest(changeConfig func(c *Config), test func(client *dns.Client, // BlockCache contains all blocked domains blockCache := &MemoryBlockCache{Backend: make(map[string]bool)} + for _, blocked := range config.Blocklist { + _ = blockCache.Set(blocked, true) + } // QuestionCache contains all queries to the dns server questionCache := makeQuestionCache(config.QuestionCacheCap) @@ -150,6 +155,72 @@ func Test2in3DifferentARecords(t *testing.T) { ) } +func contains(str string) func(rr dns.RR) bool { + return func(rr dns.RR) bool { + return strings.Contains(rr.String(), str) + } +} + +func TestCnameFollowHappyPath(t *testing.T) { + integrationTest( + func(c *Config) { + c.CustomDNSRecords = []string{ + "first.com IN CNAME second.com ", + "second.com IN CNAME third.com ", + "third.com IN A 10.10.0.42 ", + } + c.Timeout = 10000 + }, + func(client *dns.Client, target string) { + c := new(dns.Client) + + m := new(dns.Msg) + + m.SetQuestion(dns.Fqdn("first.com"), dns.TypeA) + reply, _, err := c.Exchange(m, target) + if err != nil { + t.Fatalf("failed to exchange %v", err) + } + if l := len(reply.Answer); l != 3 { + t.Fatalf("Expected 3 returned records but had %v: %v", l, reply.Answer) + } + + if !slices.ContainsFunc(reply.Answer, contains("10.10.0.42")) || + !slices.ContainsFunc(reply.Answer, contains("A")) { + t.Fatalf("Expected the right A address to be returned, but got %v", reply.Answer[0]) + } + }, + ) +} + +func TestCnameFollowWithBlocked(t *testing.T) { + integrationTest( + func(c *Config) { + c.CustomDNSRecords = []string{ + "first.com IN CNAME second.com ", + "second.com IN CNAME example.com ", + } + c.Blocklist = []string{"example.com"} + + }, + func(client *dns.Client, target string) { + c := new(dns.Client) + + m := new(dns.Msg) + + m.SetQuestion(dns.Fqdn("first.com"), dns.TypeA) + reply, _, err := c.Exchange(m, target) + if err != nil { + t.Error(err) + t.FailNow() + } + if 
!slices.ContainsFunc(reply.Answer, contains("0.0.0.0")) { + t.Fatalf("Expected right A address to be blocked, but got \n%v", reply.String()) + } + }, + ) +} + func TestDohIntegration(t *testing.T) { dohBind := "localhost:8181" integrationTest(func(c *Config) { diff --git a/handler.go b/handler.go index 424eb4d..cbf1415 100644 --- a/handler.go +++ b/handler.go @@ -3,6 +3,7 @@ package main import ( "github.com/cottand/grimd/internal/metric" "net" + "slices" "strings" "sync" "time" @@ -36,14 +37,19 @@ func (q *Question) String() string { return q.Qname + " " + q.Qclass + " " + q.Qtype } -// DNSHandler type -type DNSHandler struct { +// EventLoop type +type EventLoop struct { requestChannel chan DNSOperationData resolver *Resolver cache Cache - negCache Cache - active bool - muActive sync.RWMutex + // negCache caches failures + negCache Cache + active bool + muActive sync.RWMutex + config *Config + blockCache *MemoryBlockCache + questionCache *MemoryQuestionCache + customDns *CustomRecordsResolver } // DNSOperationData type @@ -53,8 +59,8 @@ type DNSOperationData struct { req *dns.Msg } -// NewHandler returns a new DNSHandler -func NewHandler(config *Config, blockCache *MemoryBlockCache, questionCache *MemoryQuestionCache) *DNSHandler { +// NewEventLoop returns a new eventLoop +func NewEventLoop(config *Config, blockCache *MemoryBlockCache, questionCache *MemoryQuestionCache) *EventLoop { var ( clientConfig *dns.ClientConfig resolver *Resolver @@ -73,209 +79,282 @@ func NewHandler(config *Config, blockCache *MemoryBlockCache, questionCache *Mem Maxcount: config.Maxcount, } - handler := &DNSHandler{ + handler := &EventLoop{ requestChannel: make(chan DNSOperationData), resolver: resolver, cache: cache, negCache: negCache, + blockCache: blockCache, + questionCache: questionCache, active: true, + config: config, + customDns: NewCustomRecordsResolver(NewCustomDNSRecordsFromText(config.CustomDNSRecords)), } - go handler.do(config, blockCache, questionCache) + go handler.do() 
return handler } -func (h *DNSHandler) do(config *Config, blockCache *MemoryBlockCache, questionCache *MemoryQuestionCache) { +func (h *EventLoop) do() { for { data, ok := <-h.requestChannel if !ok { break } - func(Net string, w dns.ResponseWriter, req *dns.Msg) { - defer func(w dns.ResponseWriter) { - err := w.Close() - if err != nil { - } - }(w) - q := req.Question[0] - Q := Question{UnFqdn(q.Name), dns.TypeToString[q.Qtype], dns.ClassToString[q.Qclass]} - - var remote net.IP - if Net == "tcp" { - remote = w.RemoteAddr().(*net.TCPAddr).IP - } else if Net == "http" { - remote = w.RemoteAddr().(*net.TCPAddr).IP + h.doRequest(data.Net, data.w, data.req) + } +} + +// responseFor has side-effects, like writing to h's caches, so avoid calling it concurrently +func (h *EventLoop) responseFor(Net string, req *dns.Msg, _local net.Addr, _remote net.Addr) (_ *dns.Msg, success bool) { + + var remote net.IP + if Net == "tcp" || Net == "http" { + remote = _remote.(*net.TCPAddr).IP + } else { + remote = _remote.(*net.UDPAddr).IP + } + + // first of all, check custom DNS. No need to cache it because it is already in-mem and precedes the blocking + if custom := h.customDns.Resolve(req, _local, _remote); custom != nil { + return custom, true + } + + q := req.Question[0] + Q := Question{UnFqdn(q.Name), dns.TypeToString[q.Qtype], dns.ClassToString[q.Qclass]} + logger.Infof("%s lookup %s\n", remote, Q.String()) + var grimdActive = grimdActivation.query() + if len(h.config.ToggleName) > 0 && strings.Contains(Q.Qname, h.config.ToggleName) { + logger.Noticef("Found ToggleName! 
(%s)\n", Q.Qname) + grimdActive = grimdActivation.toggle(h.config.ReactivationDelay) + + if grimdActive { + logger.Notice("Grimd Activated") + } else { + logger.Notice("Grimd Deactivated") + } + } + + IPQuery := h.isIPQuery(q) + + // Only query cache when qtype == 'A'|'AAAA' , qclass == 'IN' + key := KeyGen(Q) + if IPQuery > 0 { + mesg, blocked, err := h.cache.Get(key) + if err != nil { + if mesg, blocked, err = h.negCache.Get(key); err != nil { + logger.Debugf("%s didn't hit cache\n", Q.String()) } else { - remote = w.RemoteAddr().(*net.UDPAddr).IP + logger.Debugf("%s hit negative cache\n", Q.String()) + return nil, false } + } else { + if blocked && !grimdActive { + logger.Debugf("%s hit cache and was blocked: forwarding request\n", Q.String()) + } else { + logger.Debugf("%s hit cache\n", Q.String()) - logger.Infof("%s lookup %s\n", remote, Q.String()) - - var grimdActive = grimdActivation.query() - if len(config.ToggleName) > 0 && strings.Contains(Q.Qname, config.ToggleName) { - logger.Noticef("Found ToggleName! 
(%s)\n", Q.Qname) - grimdActive = grimdActivation.toggle(config.ReactivationDelay) + // we need this copy against concurrent modification of ID + msg := *mesg + msg.Id = req.Id - if grimdActive { - logger.Notice("Grimd Activated") - } else { - logger.Notice("Grimd Deactivated") - } + defer metric.ReportDNSRespond(remote, &msg, blocked) + return &msg, true } + } + } + // Check blocklist + var blacklisted = false + + if IPQuery > 0 { + blacklisted = h.blockCache.Exists(Q.Qname) + + if grimdActive && blacklisted { + m := new(dns.Msg) + m.SetReply(req) - IPQuery := h.isIPQuery(q) - - // Only query cache when qtype == 'A'|'AAAA' , qclass == 'IN' - key := KeyGen(Q) - if IPQuery > 0 { - mesg, blocked, err := h.cache.Get(key) - if err != nil { - if mesg, blocked, err = h.negCache.Get(key); err != nil { - logger.Debugf("%s didn't hit cache\n", Q.String()) - } else { - logger.Debugf("%s hit negative cache\n", Q.String()) - h.HandleFailed(w, req) - return + if h.config.NXDomain { + m.SetRcode(req, dns.RcodeNameError) + } else { + nullroute := net.ParseIP(h.config.Nullroute) + nullroutev6 := net.ParseIP(h.config.Nullroutev6) + + switch IPQuery { + case _IP4Query: + rrHeader := dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: h.config.TTL, } - } else { - if blocked && !grimdActive { - logger.Debugf("%s hit cache and was blocked: forwarding request\n", Q.String()) - } else { - logger.Debugf("%s hit cache\n", Q.String()) - - // we need this copy against concurrent modification of ID - msg := *mesg - msg.Id = req.Id - h.WriteReplyMsg(w, &msg) - metric.ReportDNSResponse(w, &msg, blocked) - return + a := &dns.A{Hdr: rrHeader, A: nullroute} + m.Answer = append(m.Answer, a) + case _IP6Query: + rrHeader := dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: h.config.TTL, } + a := &dns.AAAA{Hdr: rrHeader, AAAA: nullroutev6} + m.Answer = append(m.Answer, a) } } - // Check blocklist - var blacklisted = false - - if 
IPQuery > 0 { - blacklisted = blockCache.Exists(Q.Qname) - - if grimdActive && blacklisted { - m := new(dns.Msg) - m.SetReply(req) - - if config.NXDomain { - m.SetRcode(req, dns.RcodeNameError) - } else { - nullroute := net.ParseIP(config.Nullroute) - nullroutev6 := net.ParseIP(config.Nullroutev6) - - switch IPQuery { - case _IP4Query: - rrHeader := dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: config.TTL, - } - a := &dns.A{Hdr: rrHeader, A: nullroute} - m.Answer = append(m.Answer, a) - case _IP6Query: - rrHeader := dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: config.TTL, - } - a := &dns.AAAA{Hdr: rrHeader, AAAA: nullroutev6} - m.Answer = append(m.Answer, a) - } - } - - h.WriteReplyMsg(w, m) - metric.ReportDNSResponse(w, m, true) - logger.Noticef("%s found in blocklist\n", Q.Qname) + defer metric.ReportDNSRespond(remote, m, true) - // log query - NewEntry := QuestionCacheEntry{Date: time.Now().Unix(), Remote: remote.String(), Query: Q, Blocked: true} - go questionCache.Add(NewEntry) + logger.Noticef("%s found in blocklist\n", Q.Qname) - // cache the block; we don't know the true TTL for blocked entries: we just enforce our config - err := h.cache.Set(key, m, true) - if err != nil { - logger.Errorf("Set %s block cache failed: %s\n", Q.String(), err.Error()) - } + // log query + NewEntry := QuestionCacheEntry{Date: time.Now().Unix(), Remote: remote.String(), Query: Q, Blocked: true} + go h.questionCache.Add(NewEntry) - return - } - logger.Debugf("%s not found in blocklist\n", Q.Qname) + // cache the block; we don't know the true TTL for blocked entries: we just enforce our config + err := h.cache.Set(key, m, true) + if err != nil { + logger.Errorf("Set %s block cache failed: %s\n", Q.String(), err.Error()) } - // log query - NewEntry := QuestionCacheEntry{Date: time.Now().Unix(), Remote: remote.String(), Query: Q, Blocked: false} - go questionCache.Add(NewEntry) + return m, true + } + 
logger.Debugf("%s not found in blocklist\n", Q.Qname) + } - mesg, err := h.resolver.Lookup(Net, req, config.Timeout, config.Interval, config.Nameservers, config.DoH) + // log query + NewEntry := QuestionCacheEntry{Date: time.Now().Unix(), Remote: remote.String(), Query: Q, Blocked: false} + go h.questionCache.Add(NewEntry) - if err != nil { - logger.Errorf("resolve query error %s\n", err) - h.HandleFailed(w, req) + mesg, err := h.resolver.Lookup(Net, req, h.config.Timeout, h.config.Interval, h.config.Nameservers, h.config.DoH) - // cache the failure, too! - if err = h.negCache.Set(key, nil, false); err != nil { - logger.Errorf("set %s negative cache failed: %v\n", Q.String(), err) - } - return - } + if err != nil { + logger.Errorf("resolve query error %s\n", err) - if mesg.Truncated && Net == "udp" { - mesg, err = h.resolver.Lookup("tcp", req, config.Timeout, config.Interval, config.Nameservers, config.DoH) - if err != nil { - logger.Errorf("resolve tcp query error %s\n", err) - h.HandleFailed(w, req) + // cache the failure, too! + if err = h.negCache.Set(key, nil, false); err != nil { + logger.Errorf("set %s negative cache failed: %v\n", Q.String(), err) + } + return nil, false + } - // cache the failure, too! - if err = h.negCache.Set(key, nil, false); err != nil { - logger.Errorf("set %s negative cache failed: %v\n", Q.String(), err) - } - return - } + if mesg.Truncated && Net == "udp" { + mesg, err = h.resolver.Lookup("tcp", req, h.config.Timeout, h.config.Interval, h.config.Nameservers, h.config.DoH) + if err != nil { + logger.Errorf("resolve tcp query error %s\n", err) + + // cache the failure, too! 
+ if err = h.negCache.Set(key, nil, false); err != nil { + logger.Errorf("set %s negative cache failed: %v\n", Q.String(), err) } + return nil, false + } + } - //find the smallest ttl - ttl := config.Expire - var candidateTTL uint32 + //find the smallest ttl + ttl := h.config.Expire + var candidateTTL uint32 - for index, answer := range mesg.Answer { - logger.Debugf("Answer %d - %s\n", index, answer.String()) + for index, answer := range mesg.Answer { + logger.Debugf("Answer %d - %s\n", index, answer.String()) - candidateTTL = answer.Header().Ttl + candidateTTL = answer.Header().Ttl - if candidateTTL > 0 && candidateTTL < ttl { - ttl = candidateTTL - } + if candidateTTL > 0 && candidateTTL < ttl { + ttl = candidateTTL + } + } + + defer metric.ReportDNSRespond(remote, mesg, false) + + if IPQuery > 0 && len(mesg.Answer) > 0 { + if !grimdActive && blacklisted { + logger.Debugf("%s is blacklisted and grimd not active: not caching\n", Q.String()) + } else { + err = h.cache.Set(key, mesg, false) + if err != nil { + logger.Errorf("set %s cache failed: %s\n", Q.String(), err.Error()) } + logger.Debugf("insert %s into cache with ttl %d\n", Q.String(), ttl) + } + } + return mesg, true +} - h.WriteReplyMsg(w, mesg) - metric.ReportDNSResponse(w, mesg, false) +// doRequest answers a single query: it resolves req through responseFor, replies with SERVFAIL when that fails, +// otherwise follows CNAME answers for up to config.FollowCnameDepth rounds, appending any new records, +// and finally writes the reply to w. w is always closed on return. 
+func (h *EventLoop) doRequest(Net string, w dns.ResponseWriter, req *dns.Msg) { + defer func(w dns.ResponseWriter) { + err := w.Close() + if err != nil { + } + }(w) - if IPQuery > 0 && len(mesg.Answer) > 0 { - if !grimdActive && blacklisted { - logger.Debugf("%s is blacklisted and grimd not active: not caching\n", Q.String()) - } else { - err = h.cache.Set(key, mesg, false) - if err != nil { - logger.Errorf("set %s cache failed: %s\n", Q.String(), err.Error()) - } - logger.Debugf("insert %s into cache with ttl %d\n", Q.String(), ttl) + resp, ok := h.responseFor(Net, req, w.LocalAddr(), w.RemoteAddr()) + + if !ok { + m := new(dns.Msg) + m.SetRcode(req, dns.RcodeServerFailure) + WriteReplyMsg(w, m) + 
metric.ReportDNSResponse(w, m, false) + return + } + + depthSoFar := uint32(0) + for h.config.FollowCnameDepth > depthSoFar { + cnames, ok := canFollow(req, resp) + depthSoFar++ + if !ok { + break + } + for _, cname := range cnames { + r := dns.Msg{} + r.SetQuestion(cname.Target, req.Question[0].Qtype) + followed, ok := h.responseFor(Net, &r, w.LocalAddr(), w.RemoteAddr()) + // a failed follow-up lookup yields a nil message: skip it rather than dereference it + if !ok || followed == nil { + continue + } + for _, fAnswer := range followed.Answer { + containsNewAnswer := func(rr dns.RR) bool { + return rr.String() == fAnswer.String() + } + if !slices.ContainsFunc(resp.Answer, containsNewAnswer) { + resp.Answer = append(resp.Answer, fAnswer) } } - }(data.Net, data.w, data.req) + } } + + WriteReplyMsg(w, resp) + } +// determines if resp contains no A records but some CNAME record +func canFollow(req *dns.Msg, resp *dns.Msg) (cnames []*dns.CNAME, ok bool) { + // RFC-1034: only follow non-CNAME queries + if req.Question[0].Qtype == dns.TypeCNAME { + return []*dns.CNAME{}, false + } + + isA := func(rr dns.RR) bool { + return rr.Header().Rrtype == dns.TypeA || rr.Header().Rrtype == dns.TypeAAAA + } + + isCname := func(rr dns.RR) bool { + return rr.Header().Rrtype == dns.TypeCNAME + } + + ok = !slices.ContainsFunc(resp.Answer, isA) && slices.ContainsFunc(resp.Answer, isCname) + for _, rr := range resp.Answer { + if asCname, ok := rr.(*dns.CNAME); isCname(rr) && ok { + cnames = append(cnames, asCname) + } + } + + return cnames, ok && len(cnames) != 0 +} + +// msg: +// Q: A fst.com +// A: CN snd.com, thrd.com +// + // DoTCP begins a tcp query -func (h *DNSHandler) DoTCP(w dns.ResponseWriter, req *dns.Msg) { +func (h *EventLoop) DoTCP(w dns.ResponseWriter, req *dns.Msg) { h.muActive.RLock() defer h.muActive.RUnlock() if h.active { @@ -284,7 +363,7 @@ func (h *DNSHandler) DoTCP(w dns.ResponseWriter, req *dns.Msg) { } // DoUDP begins a udp query -func (h *DNSHandler) DoUDP(w 
dns.ResponseWriter, req *dns.Msg) { +func (h *EventLoop) DoUDP(w dns.ResponseWriter, req *dns.Msg) { h.muActive.RLock() defer 
h.muActive.RUnlock() if h.active { @@ -292,7 +371,7 @@ func (h *DNSHandler) DoUDP(w dns.ResponseWriter, req *dns.Msg) { } } -func (h *DNSHandler) DoHTTP(w dns.ResponseWriter, req *dns.Msg) { +func (h *EventLoop) DoHTTP(w dns.ResponseWriter, req *dns.Msg) { h.muActive.RLock() defer h.muActive.RUnlock() if h.active { @@ -300,16 +379,8 @@ func (h *DNSHandler) DoHTTP(w dns.ResponseWriter, req *dns.Msg) { } } -// HandleFailed handles dns failures -func (h *DNSHandler) HandleFailed(w dns.ResponseWriter, message *dns.Msg) { - m := new(dns.Msg) - m.SetRcode(message, dns.RcodeServerFailure) - h.WriteReplyMsg(w, m) - metric.ReportDNSResponse(w, m, false) -} - // WriteReplyMsg writes the dns reply -func (h *DNSHandler) WriteReplyMsg(w dns.ResponseWriter, message *dns.Msg) { +func WriteReplyMsg(w dns.ResponseWriter, message *dns.Msg) { defer func() { if r := recover(); r != nil { logger.Noticef("Recovered in WriteReplyMsg: %s\n", r) @@ -320,10 +391,9 @@ func (h *DNSHandler) WriteReplyMsg(w dns.ResponseWriter, message *dns.Msg) { if err != nil { logger.Error(err) } - } -func (h *DNSHandler) isIPQuery(q dns.Question) int { +func (h *EventLoop) isIPQuery(q dns.Question) int { if q.Qclass != dns.ClassINET { return notIPQuery } diff --git a/internal/metric/metric.go b/internal/metric/metric.go index 8b533e6..a85a25c 100644 --- a/internal/metric/metric.go +++ b/internal/metric/metric.go @@ -81,3 +81,13 @@ func ReportDNSResponse(w dns.ResponseWriter, message *dns.Msg, blocked bool) { "blocked": strconv.FormatBool(blocked), }).Inc() } +func ReportDNSRespond(remote net.IP, message *dns.Msg, blocked bool) { + question := message.Question[0] + responseCounter.With(prometheus.Labels{ + "remote_ip": remote.String(), + "q_type": dns.Type(question.Qtype).String(), + "q_name": question.Name, + "rcode": dns.RcodeToString[message.Rcode], + "blocked": strconv.FormatBool(blocked), + }).Inc() +} diff --git a/records.go b/records.go index 2fb0580..f52918d 100644 --- a/records.go +++ b/records.go @@ 
-3,6 +3,7 @@ package main import ( "github.com/cottand/grimd/internal/metric" "github.com/miekg/dns" + "net" ) type CustomDNSRecords struct { @@ -46,14 +47,74 @@ func NewCustomDNSRecords(from map[string][]dns.RR) []CustomDNSRecords { return records } -func (records CustomDNSRecords) serve(serverHandler *DNSHandler) func(dns.ResponseWriter, *dns.Msg) { +func (records CustomDNSRecords) asHandler() func(dns.ResponseWriter, *dns.Msg) { return func(writer dns.ResponseWriter, req *dns.Msg) { m := new(dns.Msg) m.SetReply(req) m.Answer = append(m.Answer, records.answer...) - serverHandler.WriteReplyMsg(writer, m) + WriteReplyMsg(writer, m) metric.RequestCustomCounter.Inc() metric.ReportDNSResponse(writer, m, false) } } + +// CustomRecordsResolver allows faking an in-mem DNS server just for custom records +type CustomRecordsResolver struct { + mux *dns.ServeMux +} + +func NewCustomRecordsResolver(records []CustomDNSRecords) *CustomRecordsResolver { + mux := dns.NewServeMux() + for _, r := range records { + mux.HandleFunc(r.name, r.asHandler()) + } + return &CustomRecordsResolver{mux} +} + +// Resolve returns nil when there was no result found +func (r *CustomRecordsResolver) Resolve(req *dns.Msg, local net.Addr, remote net.Addr) *dns.Msg { + writer := roResponseWriter{local: local, remote: remote} + r.mux.ServeDNS(&writer, req) + if writer.result.Rcode == dns.RcodeRefused { + return nil + } else { + return writer.result + } +} + +// roResponseWriter implements dns.ResponseWriter, +// but does not allow calling any method with +// side effects. 
+// It allows wrapping a dns.ResponseWriter in order +// to recover the final written dns.Msg +type roResponseWriter struct { + local net.Addr + remote net.Addr + result *dns.Msg +} + +func (w *roResponseWriter) LocalAddr() net.Addr { + return w.local +} + +func (w *roResponseWriter) RemoteAddr() net.Addr { + return w.remote +} + +func (w *roResponseWriter) WriteMsg(msg *dns.Msg) error { + w.result = msg + return nil +} +func (w *roResponseWriter) Write([]byte) (int, error) { + return 0, nil +} +func (w *roResponseWriter) Close() error { + return nil +} +func (w *roResponseWriter) TsigStatus() error { + return nil +} +func (w *roResponseWriter) TsigTimersOnly(_ bool) {} +func (w *roResponseWriter) Hijack() { +} diff --git a/resolver.go b/resolver.go index 8ca25f4..1022883 100644 --- a/resolver.go +++ b/resolver.go @@ -69,8 +69,12 @@ func (r *Resolver) Lookup(net string, req *dns.Msg, timeout int, interval int, n } + clientNet := net + if net == "http" { + clientNet = "tcp" + } c := &dns.Client{ - Net: net, + Net: clientNet, ReadTimeout: time.Duration(timeout) * time.Second, WriteTimeout: time.Duration(timeout) * time.Second, } diff --git a/server.go b/server.go index 8283dd7..3485d43 100644 --- a/server.go +++ b/server.go @@ -9,17 +9,16 @@ import ( // Server type type Server struct { - host string - rTimeout time.Duration - wTimeout time.Duration - handler *DNSHandler - udpServer *dns.Server - tcpServer *dns.Server - httpServer *ServerHTTPS - udpHandler *dns.ServeMux - tcpHandler *dns.ServeMux - httpHandler *dns.ServeMux - activeHandlerPatterns []string + host string + rTimeout time.Duration + wTimeout time.Duration + eventLoop *EventLoop + udpServer *dns.Server + tcpServer *dns.Server + httpServer *ServerHTTPS + udpHandler *dns.ServeMux + tcpHandler *dns.ServeMux + httpHandler *dns.ServeMux } // Run starts the server @@ -29,27 +28,16 @@ func (s *Server) Run( questionCache *MemoryQuestionCache, ) { - s.handler = NewHandler(config, blockCache, questionCache) + 
s.eventLoop = NewEventLoop(config, blockCache, questionCache) tcpHandler := dns.NewServeMux() - tcpHandler.HandleFunc(".", s.handler.DoTCP) + tcpHandler.HandleFunc(".", s.eventLoop.DoTCP) udpHandler := dns.NewServeMux() - udpHandler.HandleFunc(".", s.handler.DoUDP) + udpHandler.HandleFunc(".", s.eventLoop.DoUDP) httpHandler := dns.NewServeMux() - httpHandler.HandleFunc(".", s.handler.DoHTTP) - - handlerPatterns := make([]string, len(config.CustomDNSRecords)) - - for _, record := range NewCustomDNSRecordsFromText(config.CustomDNSRecords) { - dnsHandler := record.serve(s.handler) - tcpHandler.HandleFunc(record.name, dnsHandler) - udpHandler.HandleFunc(record.name, dnsHandler) - httpHandler.HandleFunc(record.name, dnsHandler) - handlerPatterns = append(handlerPatterns, record.name) - } - s.activeHandlerPatterns = handlerPatterns + httpHandler.HandleFunc(".", s.eventLoop.DoHTTP) s.tcpHandler = tcpHandler s.udpHandler = udpHandler @@ -104,11 +92,11 @@ func (s *Server) startHttp(addr string) { // Stop stops the server func (s *Server) Stop() { - if s.handler != nil { - s.handler.muActive.Lock() - s.handler.active = false - close(s.handler.requestChannel) - s.handler.muActive.Unlock() + if s.eventLoop != nil { + s.eventLoop.muActive.Lock() + s.eventLoop.active = false + close(s.eventLoop.requestChannel) + s.eventLoop.muActive.Unlock() } if s.udpServer != nil { err := s.udpServer.Shutdown() @@ -133,31 +121,7 @@ func (s *Server) Stop() { // ReloadConfig only supports reloading the customDnsRecords section of the config for now func (s *Server) ReloadConfig(config *Config) { - oldRecords := s.activeHandlerPatterns newRecords := NewCustomDNSRecordsFromText(config.CustomDNSRecords) - newRecordsPatterns := make([]string, len(newRecords)) - for _, r := range newRecords { - newRecordsPatterns = append(newRecordsPatterns, r.name) - } - if testEq(oldRecords, newRecordsPatterns) { - // no changes - nothing to reload - return - } + s.eventLoop.customDns = 
NewCustomRecordsResolver(newRecords) defer metric.CustomDNSConfigReload.Inc() - - deletedRecords := difference(oldRecords, newRecordsPatterns) - - for _, deleted := range deletedRecords { - s.tcpHandler.HandleRemove(deleted) - s.udpHandler.HandleRemove(deleted) - s.httpHandler.HandleRemove(deleted) - } - - for _, record := range newRecords { - dnsHandler := record.serve(s.handler) - s.tcpHandler.HandleFunc(record.name, dnsHandler) - s.udpHandler.HandleFunc(record.name, dnsHandler) - s.httpHandler.HandleFunc(record.name, dnsHandler) - } - s.activeHandlerPatterns = newRecordsPatterns }