diff --git a/lib/dns/caches.go b/lib/dns/caches.go index 5a49e762..34a0b56e 100644 --- a/lib/dns/caches.go +++ b/lib/dns/caches.go @@ -232,39 +232,61 @@ func (c *Caches) ExternalSearch(re *regexp.Regexp) (listMsg []*Message) { return listMsg } -// get an answer based on domain-name, query type, and query class. -// -// If query name exist but the query type or class does not exist, -// it will return list of answer and nil answer. -// -// If answer exist on cache and its from external, their accessed time will be -// updated to current time and moved to back of LRU to prevent being pruned -// later. -func (c *Caches) get(qname string, rtype RecordType, rclass RecordClass) (ans *answers, an *Answer) { +func (c *Caches) query(msg *Message) (an *Answer) { + var ans *answers + c.Lock() - defer c.Unlock() - ans = c.internal[qname] + ans = c.internal[msg.Question.Name] if ans == nil { - ans = c.external[qname] + ans = c.external[msg.Question.Name] if ans == nil { - return nil, nil + goto out } } - an, _ = ans.get(rtype, rclass) + an, _ = ans.get(msg.Question.Type, msg.Question.Class) if an == nil { - return ans, nil + goto out } - // Move the answer to the back of LRU if its external - // answer and update its accessed time. + // Move the answer to the back of LRU if its external answer and + // update its accessed time. if an.ReceivedAt > 0 { c.lru.MoveToBack(an.el) - an.AccessedAt = time.Now().Unix() + an.AccessedAt = timeNow().Unix() + } + +out: + c.Unlock() + + if an == nil { + // No answers found in internal and external caches. + // If the requested domain is subset of our internal + // zone, return answer with error and Authority. + var zone = c.internalZone(msg.Question.Name) + if zone == nil { + return nil + } + an = &Answer{ + msg: msg, + } + _ = an.msg.AddAuthority(zone.soaRecord()) + an.msg.SetResponseCode(RCodeErrName) } + return an +} - return ans, an +// internalZone will return the zone if the query name is suffix of one of +// the Zone Origin. +func (c *Caches) internalZone(qname string) (zone *Zone) { + qname = toDomainAbsolute(qname) + for _, zone = range c.zone { + if strings.HasSuffix(qname, `.`+zone.Origin) { + return zone + } + } + return nil } // InternalPopulate add list of message to internal caches. diff --git a/lib/dns/caches_test.go b/lib/dns/caches_test.go index 2684d788..1560cd62 100644 --- a/lib/dns/caches_test.go +++ b/lib/dns/caches_test.go @@ -12,14 +12,12 @@ import ( "github.com/shuLhan/share/lib/test" ) -func TestCachesGet(t *testing.T) { +func TestCachesQuery(t *testing.T) { type testCase struct { exp *Answer + msg Message desc string - QName string expList []*Answer - RType RecordType - RClass RecordClass } var ( @@ -75,11 +73,15 @@ func TestCachesGet(t *testing.T) { an1, an2, an3, }, }, { - desc: "With query found", - QName: "test", - RType: 1, - RClass: 1, - exp: an1, + desc: "With query found", + msg: Message{ + Question: MessageQuestion{ + Name: "test", + Type: 1, + Class: 1, + }, + }, + exp: an1, expList: []*Answer{ an2, an3, an1, }, @@ -88,10 +90,10 @@ func TestCachesGet(t *testing.T) { for _, c = range cases { t.Log(c.desc) - _, got = ca.get(c.QName, c.RType, c.RClass) + got = ca.query(&c.msg) gotList = ca.ExternalLRU() - test.Assert(t, "caches.get", c.exp, got) + test.Assert(t, "caches.query", c.exp, got) test.Assert(t, "caches.list", c.expList, gotList) } } @@ -327,3 +329,36 @@ func TestCachesUpsert(t *testing.T) { } } } + +func TestCaches_internalZone(t *testing.T) { + type testCase struct { + qname string + exp bool + } + + var caches = &Caches{ + zone: map[string]*Zone{ + `my.internal.`: NewZone(``, `my.internal.`), + }, + } + + var listCase = []testCase{{ + qname: `notmy.internal`, + exp: false, + }, { + qname: `sub.my.internal`, + exp: true, + }, { + qname: `sub.my.internal.`, + exp: true, + }} + + var ( + c testCase + got *Zone + ) + for _, c = range listCase { + got = caches.internalZone(c.qname) + test.Assert(t, c.qname, c.exp, got != nil) + } +} diff --git a/lib/dns/server.go b/lib/dns/server.go index f40b0f81..bc865d56 100644 --- a/lib/dns/server.go +++ b/lib/dns/server.go @@ -565,7 +565,6 @@ func (srv *Server) processRequest() { an *Answer res *Message req *request - ans *answers err error ) @@ -582,11 +581,8 @@ func (srv *Server) processRequest() { req.message.Question.String()) } - ans, an = srv.Caches.get(req.message.Question.Name, - req.message.Question.Type, - req.message.Question.Class) - - if ans == nil || an == nil { + an = srv.Caches.query(req.message) + if an == nil { switch { case srv.hasForwarders(): if req.kind == connTypeTCP { diff --git a/lib/dns/zone.go b/lib/dns/zone.go index 052bc132..0d52e30f 100644 --- a/lib/dns/zone.go +++ b/lib/dns/zone.go @@ -23,7 +23,8 @@ type Zone struct { // records. Records map[string][]*ResourceRecord `json:"-"` - SOA *RDataSOA + SOA *RDataSOA + rrSOA *ResourceRecord Path string `json:"-"` @@ -541,3 +542,16 @@ func (zone *Zone) recordRemove(rr *ResourceRecord) bool { } return false } + +func (zone *Zone) soaRecord() (rrsoa *ResourceRecord) { + if zone.rrSOA == nil { + zone.rrSOA = &ResourceRecord{ + Value: zone.SOA, + Name: zone.Origin, + Type: RecordTypeSOA, + Class: RecordClassIN, + TTL: zone.SOA.Minimum, + } + } + return zone.rrSOA +}