diff --git a/shdns.go b/shdns.go index f829ea6..ea92a81 100644 --- a/shdns.go +++ b/shdns.go @@ -51,16 +51,16 @@ var showver = flag.Bool("V", false, "Show version") var version = "unknown" var builddate = "unknown" -type servertype int +type serverType int const ( - domestic servertype = iota + domestic serverType = iota foreign ) type nameserver struct { - udpaddr *net.UDPAddr - stype servertype + udpAddr *net.UDPAddr + sType serverType } type byByte []net.IPNet @@ -79,14 +79,14 @@ func (n byByte) Less(i, j int) bool { func (n byByte) Swap(i, j int) { n[i], n[j] = n[j], n[i] } var ( - cnipnet4, cnipnet6 []net.IPNet - blackips4, blackips6 []net.IPNet + cnIPNet4, cnIPNet6 []net.IPNet + blackIPs4, blackIPs6 []net.IPNet servers []nameserver logger = log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lmicroseconds) errlog = log.New(os.Stderr, "", log.Ldate|log.Ltime|log.Lmicroseconds) ) -func parseudpaddr(str string) (*net.UDPAddr, error) { +func parseUDPAddr(str string) (*net.UDPAddr, error) { _, _, err := net.SplitHostPort(str) if err == nil { return net.ResolveUDPAddr("udp", str) @@ -100,10 +100,10 @@ func parseudpaddr(str string) (*net.UDPAddr, error) { return nil, err } -func parseservers(str string, stype servertype) { +func parseServers(str string, sType serverType) { serverstr := strings.Split(str, ",") for _, s := range serverstr { - if addr, err := parseudpaddr(s); err != nil { + if addr, err := parseUDPAddr(s); err != nil { errlog.Fatalf("Invalid nameserver: %s", s) } else { if addr.Zone != "" { @@ -117,8 +117,8 @@ func parseservers(str string, stype servertype) { errlog.Fatalf("IPv6 zone invalid: %s", s) } } - if _, exist := lookupserver(addr); !exist { - servers = append(servers, nameserver{udpaddr: addr, stype: stype}) + if _, exist := lookupServer(addr); !exist { + servers = append(servers, nameserver{udpAddr: addr, sType: sType}) logger.Printf("Using nameserver %s", addr) } else { errlog.Fatalf("Nameserver exists: %s", s) @@ -128,7 +128,7 @@ func parseservers(str string, stype servertype) { return } -func parseiplist(filename string, iplen int) (ipnets []net.IPNet) { +func parseIPList(filename string, iplen int) (ipnets []net.IPNet) { f, err := os.Open(filename) if err != nil { errlog.Fatalln(err) @@ -158,7 +158,7 @@ func parseiplist(filename string, iplen int) (ipnets []net.IPNet) { return } -func cmpipipnet(ip net.IP, ipnet net.IPNet) int { // based on net.Contains() +func cmpIPIPNet(ip net.IP, ipnet net.IPNet) int { // based on net.Contains() for i := 0; i < len(ip); i++ { if a, b := ip[i]&ipnet.Mask[i], ipnet.IP[i]&ipnet.Mask[i]; a < b { return -1 @@ -169,9 +169,9 @@ func cmpipipnet(ip net.IP, ipnet net.IPNet) int { // based on net.Contains() return 0 } -func findipinnet(ip net.IP, ipnets []net.IPNet) bool { // based on sort.Search() +func findIPInNet(ip net.IP, ipnets []net.IPNet) bool { // based on sort.Search() for i, j := 0, len(ipnets); i < j; { - switch k := int(uint(i+j) >> 1); cmpipipnet(ip, ipnets[k]) { // i <= k < j + switch k := int(uint(i+j) >> 1); cmpIPIPNet(ip, ipnets[k]) { // i <= k < j case -1: j = k case 0: @@ -183,13 +183,13 @@ func findipinnet(ip net.IP, ipnets []net.IPNet) bool { // based on sort.Search() return false } -func addtag(bufs []bytes.Buffer, tag string) { +func addTag(bufs []bytes.Buffer, tag string) { for i := range bufs { fmt.Fprint(&bufs[i], tag) } } -func handlequery(addr *net.UDPAddr, payload []byte, inconn *net.UDPConn) { // net.Addr is *net.UDPAddr, net.PacketConn is *net.UDPConn +func handleQuery(addr *net.UDPAddr, payload []byte, inConn *net.UDPConn) { // net.Addr is *net.UDPAddr, net.PacketConn is *net.UDPConn var p dnsmessage.Parser h, err := p.Start(payload) if err != nil { @@ -223,40 +223,40 @@ func handlequery(addr *net.UDPAddr, payload []byte, inconn *net.UDPConn) { // ne } if *verbose { if dnssec { - addtag(bufs, " DNSSEC") + addTag(bufs, " DNSSEC") } for _, buf := range bufs { logger.Println(&buf) } } - ch := make(chan []byte) - chsave := make(chan []byte) + chAnswer := make(chan []byte) + chSave := make(chan []byte) loc, _ := net.ResolveUDPAddr("udp", "") - outconn, err := net.ListenUDP("udp", loc) + outConn, err := net.ListenUDP("udp", loc) if err != nil { errlog.Println(err) return } - defer outconn.Close() - go sendandreceive(payload, outconn, ch, chsave, qs[0].Type, dnssec) + defer outConn.Close() + go forwardQuery(payload, outConn, chAnswer, chSave, qs[0].Type, dnssec) answered := false - var latestanswer []byte + var lastAnswer []byte for { select { - case a, ok := <-ch: + case a, ok := <-chAnswer: if ok { if !answered { - if _, err := inconn.WriteToUDP(a, addr); err != nil { + if _, err := inConn.WriteToUDP(a, addr); err != nil { errlog.Println(err) } answered = true if !*verbose { - outconn.Close() + outConn.Close() } } } else { - if !answered && latestanswer != nil { - if _, err := inconn.WriteToUDP(latestanswer, addr); err != nil { + if !answered && lastAnswer != nil { + if _, err := inConn.WriteToUDP(lastAnswer, addr); err != nil { errlog.Println(err) } } @@ -265,76 +265,76 @@ func handlequery(addr *net.UDPAddr, payload []byte, inconn *net.UDPConn) { // ne } return } - case a := <-chsave: + case a := <-chSave: if !answered { - latestanswer = a + lastAnswer = a } } } } -func sendandreceive(payload []byte, outconn *net.UDPConn, ch, chsave chan<- []byte, qtype dnsmessage.Type, dnssec bool) { - defer close(ch) - recvch := make([]chan []byte, len(servers)) - chdone := make([]chan bool, len(servers)) - senttime := time.Now() +func forwardQuery(payload []byte, outConn *net.UDPConn, chAnswer, chSave chan<- []byte, qType dnsmessage.Type, dnssec bool) { + defer close(chAnswer) + chRecv := make([]chan []byte, len(servers)) + chDone := make([]chan bool, len(servers)) + sentTime := time.Now() for i, ns := range servers { - if _, err := outconn.WriteToUDP(payload, ns.udpaddr); err != nil { + if _, err := outConn.WriteToUDP(payload, ns.udpAddr); err != nil { continue } - recvch[i] = make(chan []byte) - chdone[i] = make(chan bool) - go parseanswer(ns, senttime, recvch[i], ch, chsave, chdone[i], dnssec, qtype) - defer func(recvch chan []byte, chdone chan bool) { - close(recvch) - <-chdone - }(recvch[i], chdone[i]) + chRecv[i] = make(chan []byte) + chDone[i] = make(chan bool) + go parseAnswer(ns, sentTime, chRecv[i], chAnswer, chSave, chDone[i], dnssec, qType) + defer func(chRecv chan []byte, chDone chan bool) { + close(chRecv) + <-chDone + }(chRecv[i], chDone[i]) } - outconn.SetReadDeadline(senttime.Add(time.Duration(*initimeout) * time.Second)) + outConn.SetReadDeadline(sentTime.Add(time.Duration(*initimeout) * time.Second)) received := false for { payload := make([]byte, 1500) - n, addr, err := outconn.ReadFromUDP(payload) + n, addr, err := outConn.ReadFromUDP(payload) if err != nil { return } if !received { received = true - outconn.SetReadDeadline(time.Now().Add(time.Duration(*subtimeout) * time.Millisecond)) + outConn.SetReadDeadline(time.Now().Add(time.Duration(*subtimeout) * time.Millisecond)) } - if i, ok := lookupserver(addr); ok { - recvch[i] <- payload[:n] + if i, ok := lookupServer(addr); ok { + chRecv[i] <- payload[:n] } } } -func lookupserver(addr *net.UDPAddr) (int, bool) { +func lookupServer(addr *net.UDPAddr) (int, bool) { for i, s := range servers { - if s.udpaddr.IP.Equal(addr.IP) && s.udpaddr.Port == addr.Port && s.udpaddr.Zone == addr.Zone { + if s.udpAddr.IP.Equal(addr.IP) && s.udpAddr.Port == addr.Port && s.udpAddr.Zone == addr.Zone { return i, true } } return 0, false } -func parseanswer(ns nameserver, senttime time.Time, recvch <-chan []byte, ch, chsave chan<- []byte, chdone chan<- bool, dnssec bool, qtype dnsmessage.Type) { - pcount := 0 - var firstrtt time.Duration +func parseAnswer(ns nameserver, sentTime time.Time, chRecv <-chan []byte, chAnswer, chSave chan<- []byte, chDone chan<- bool, dnssec bool, qType dnsmessage.Type) { + pktCount := 0 + var firstRTT time.Duration answered := false for { - a, ok := <-recvch + a, ok := <-chRecv if !ok { - chdone <- true + chDone <- true return // receive channel closed } - rtt := time.Since(senttime) - toofast := false - if ns.stype == foreign && rtt < time.Duration(*minrtt)*time.Millisecond { - toofast = true + rtt := time.Since(sentTime) + tooFast := false + if ns.sType == foreign && rtt < time.Duration(*minrtt)*time.Millisecond { + tooFast = true } - pcount++ - if pcount == 1 { - firstrtt = rtt + pktCount++ + if pktCount == 1 { + firstRTT = rtt } var p dnsmessage.Parser h, err := p.Start(a) @@ -343,8 +343,8 @@ func parseanswer(ns nameserver, senttime time.Time, recvch <-chan []byte, ch, ch continue } p.SkipAllQuestions() - var geoerr, typeerr, hascname, hasa, hasaaaa, blacklisted, dnssecerr, addterr bool - dnssecerr = dnssec && ns.stype == foreign + var geoErr, typeErr, hasCNAME, hasA, hasAAAA, inBlacklist, dnssecErr, addtErr bool + dnssecErr = dnssec && ns.sType == foreign ansCount := 0 var bufs []bytes.Buffer for { @@ -356,29 +356,29 @@ func parseanswer(ns nameserver, senttime time.Time, recvch <-chan []byte, ch, ch ansCount++ var buf bytes.Buffer if *verbose { - fmt.Fprintf(&buf, "%d %s Answer[%s]", h.ID, ns.udpaddr, ah.Type.String()[4:]) + fmt.Fprintf(&buf, "%d %s Answer[%s]", h.ID, ns.udpAddr, ah.Type.String()[4:]) } switch ah.Type { case dnsmessage.TypeA: - hasa = true - if cnipnet4 != nil || blackips4 != nil || *verbose { + hasA = true + if cnIPNet4 != nil || blackIPs4 != nil || *verbose { r, _ := p.AResource() ip := net.IP(r.A[:]) //r.A is 4-byte if *verbose { fmt.Fprintf(&buf, " %s %s len %d %dms", ah.Name.String(), ip.String(), len(a), rtt.Nanoseconds()/1000000) } - if cnipnet4 != nil { - iscn := findipinnet(ip, cnipnet4) - if ns.stype == domestic && !iscn { - geoerr = true + if cnIPNet4 != nil { + isCN := findIPInNet(ip, cnIPNet4) + if ns.sType == domestic && !isCN { + geoErr = true if *verbose { fmt.Fprint(&buf, " GEOERR") } } } - if blackips4 != nil { - if findipinnet(ip, blackips4) { - blacklisted = true + if blackIPs4 != nil { + if findIPInNet(ip, blackIPs4) { + inBlacklist = true if *verbose { fmt.Fprint(&buf, " BLACKLIST") } @@ -387,32 +387,32 @@ func parseanswer(ns nameserver, senttime time.Time, recvch <-chan []byte, ch, ch } else { p.SkipAnswer() } - if qtype == dnsmessage.TypeAAAA { - typeerr = true + if qType == dnsmessage.TypeAAAA { + typeErr = true if *verbose { fmt.Fprint(&buf, " TYPEERR") } } case dnsmessage.TypeAAAA: - hasaaaa = true - if cnipnet6 != nil || blackips6 != nil || *verbose { + hasAAAA = true + if cnIPNet6 != nil || blackIPs6 != nil || *verbose { r, _ := p.AAAAResource() ip := net.IP(r.AAAA[:]) if *verbose { fmt.Fprintf(&buf, " %s %s len %d %dms", ah.Name.String(), ip.String(), len(a), rtt.Nanoseconds()/1000000) } - if cnipnet6 != nil { - iscn := findipinnet(ip, cnipnet6) - if ns.stype == domestic && !iscn { - geoerr = true + if cnIPNet6 != nil { + isCN := findIPInNet(ip, cnIPNet6) + if ns.sType == domestic && !isCN { + geoErr = true if *verbose { fmt.Fprint(&buf, " GEOERR") } } } - if blackips6 != nil { - if findipinnet(ip, blackips6) { - blacklisted = true + if blackIPs6 != nil { + if findIPInNet(ip, blackIPs6) { + inBlacklist = true if *verbose { fmt.Fprint(&buf, " BLACKLIST") } @@ -428,7 +428,7 @@ func parseanswer(ns nameserver, senttime time.Time, recvch <-chan []byte, ch, ch } else { p.SkipAnswer() } - hascname = true + hasCNAME = true case dnsmessage.TypePTR: if *verbose { r, _ := p.PTRResource() @@ -443,7 +443,7 @@ func parseanswer(ns nameserver, senttime time.Time, recvch <-chan []byte, ch, ch p.SkipAnswer() } if *verbose { - if toofast { + if tooFast { fmt.Fprint(&buf, " TOOFAST") } bufs = append(bufs, buf) @@ -451,7 +451,7 @@ func parseanswer(ns nameserver, senttime time.Time, recvch <-chan []byte, ch, ch } // answer section parsed if *verbose && ansCount == 0 { var buf bytes.Buffer - fmt.Fprintf(&buf, "%d %s Answer[Empty] len %d %dms", h.ID, ns.udpaddr, len(a), rtt.Nanoseconds()/1000000) + fmt.Fprintf(&buf, "%d %s Answer[Empty] len %d %dms", h.ID, ns.udpAddr, len(a), rtt.Nanoseconds()/1000000) bufs = append(bufs, buf) } authCount := 0 @@ -467,94 +467,94 @@ func parseanswer(ns nameserver, senttime time.Time, recvch <-chan []byte, ch, ch addtCount++ switch rh.Type { case dnsmessage.TypeA, dnsmessage.TypeAAAA: - addterr = true + addtErr = true case dnsmessage.TypeOPT: - if ns.stype == foreign { + if ns.sType == foreign { if dnssec == rh.DNSSECAllowed() { - dnssecerr = false + dnssecErr = false } else { - dnssecerr = true + dnssecErr = true } } } p.SkipAdditional() } if *verbose { - if addterr { - addtag(bufs, " ADDTERR") + if addtErr { + addTag(bufs, " ADDTERR") } - if dnssecerr { - addtag(bufs, " DNSSECERR") + if dnssecErr { + addTag(bufs, " DNSSECERR") } if h.RCode != dnsmessage.RCodeSuccess { - addtag(bufs, " "+h.RCode.String()) + addTag(bufs, " "+h.RCode.String()) } - addtag(bufs, " "+strconv.Itoa(ansCount)+"/"+strconv.Itoa(authCount)+"/"+strconv.Itoa(addtCount)) + addTag(bufs, " "+strconv.Itoa(ansCount)+"/"+strconv.Itoa(authCount)+"/"+strconv.Itoa(addtCount)) } - if !dnssecerr && !geoerr && !typeerr && !addterr && !toofast && !blacklisted && - (h.RCode == dnsmessage.RCodeSuccess && qtype == dnsmessage.TypeA && hasa || - h.RCode == dnsmessage.RCodeSuccess && qtype == dnsmessage.TypeAAAA && (hasaaaa || hascname || authCount > 0) || - h.RCode == dnsmessage.RCodeSuccess && qtype != dnsmessage.TypeA && qtype != dnsmessage.TypeAAAA || + if !dnssecErr && !geoErr && !typeErr && !addtErr && !tooFast && !inBlacklist && + (h.RCode == dnsmessage.RCodeSuccess && qType == dnsmessage.TypeA && hasA || + h.RCode == dnsmessage.RCodeSuccess && qType == dnsmessage.TypeAAAA && (hasAAAA || hasCNAME || authCount > 0) || + h.RCode == dnsmessage.RCodeSuccess && qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA || h.RCode == dnsmessage.RCodeNameError) { - switch ns.stype { + switch ns.sType { case domestic: - if pcount == 1 { - ch <- a + if pktCount == 1 { + chAnswer <- a answered = true if *verbose { - addtag(bufs, " [ACCEPT]") + addTag(bufs, " [ACCEPT]") } } else { if *verbose { - addtag(bufs, " [IGNORE]") + addTag(bufs, " [IGNORE]") } } case foreign: - if pcount == 1 { - if qtype != dnsmessage.TypeA && qtype != dnsmessage.TypeAAAA || - rtt > time.Duration(*minsafe)*time.Millisecond || hascname || ansCount > 1 { + if pktCount == 1 { + if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA || + rtt > time.Duration(*minsafe)*time.Millisecond || hasCNAME || ansCount > 1 { if *trusted && rtt < time.Duration(*minwait)*time.Millisecond { time.Sleep(time.Duration(*minwait)*time.Millisecond - rtt) if *verbose { - addtag(bufs, " [DELAYED]") + addTag(bufs, " [DELAYED]") } } else { if *verbose { - addtag(bufs, " [ACCEPT]") + addTag(bufs, " [ACCEPT]") } } - ch <- a + chAnswer <- a answered = true } else { if *verbose { - addtag(bufs, " [SAVE]") + addTag(bufs, " [SAVE]") } - chsave <- a + chSave <- a } } else { - if rtt-firstrtt > time.Duration(*maxdur)*time.Millisecond || hascname || ansCount > 1 { + if rtt-firstRTT > time.Duration(*maxdur)*time.Millisecond || hasCNAME || ansCount > 1 { if !answered { - ch <- a + chAnswer <- a answered = true if *verbose { - addtag(bufs, " [ACCEPT]") + addTag(bufs, " [ACCEPT]") } } else { if *verbose { - addtag(bufs, " [IGNORE]") + addTag(bufs, " [IGNORE]") } } } else { if *verbose { - addtag(bufs, " [SAVE]") + addTag(bufs, " [SAVE]") } - chsave <- a + chSave <- a } } } } else { if *verbose { - addtag(bufs, " [DROP]") + addTag(bufs, " [DROP]") } } if *verbose { @@ -577,58 +577,58 @@ func main() { return } if *ipnet4file != "" { - cnipnet4 = parseiplist(*ipnet4file, net.IPv4len) - if cnipnet4 != nil { - logger.Printf("Loaded %d domestic IPv4 entries", len(cnipnet4)) - sort.Sort(byByte(cnipnet4)) + cnIPNet4 = parseIPList(*ipnet4file, net.IPv4len) + if cnIPNet4 != nil { + logger.Printf("Loaded %d domestic IPv4 entries", len(cnIPNet4)) + sort.Sort(byByte(cnIPNet4)) } } - if cnipnet4 == nil { + if cnIPNet4 == nil { errlog.Fatalln("Domestic IPv4 list must be provided") } if *ipnet6file != "" { - cnipnet6 = parseiplist(*ipnet6file, net.IPv6len) - if cnipnet6 != nil { - logger.Printf("Loaded %d domestic IPv6 entries", len(cnipnet6)) - sort.Sort(byByte(cnipnet6)) + cnIPNet6 = parseIPList(*ipnet6file, net.IPv6len) + if cnIPNet6 != nil { + logger.Printf("Loaded %d domestic IPv6 entries", len(cnIPNet6)) + sort.Sort(byByte(cnIPNet6)) } } if *blacklist4file != "" { - blackips4 = parseiplist(*blacklist4file, net.IPv4len) - if blackips4 != nil { - logger.Printf("Loaded %d blacklisted IPv4 entries", len(blackips4)) + blackIPs4 = parseIPList(*blacklist4file, net.IPv4len) + if blackIPs4 != nil { + logger.Printf("Loaded %d blacklisted IPv4 entries", len(blackIPs4)) } } if *blacklist6file != "" { - blackips6 = parseiplist(*blacklist6file, net.IPv6len) - if blackips6 != nil { - logger.Printf("Loaded %d blacklisted IPv6 entries", len(blackips6)) + blackIPs6 = parseIPList(*blacklist6file, net.IPv6len) + if blackIPs6 != nil { + logger.Printf("Loaded %d blacklisted IPv6 entries", len(blackIPs6)) } } - parseservers(*dservers, domestic) - parseservers(*fservers, foreign) + parseServers(*dservers, domestic) + parseServers(*fservers, foreign) if *trusted { logger.Print("Foreign servers in trustworthy mode") *minsafe = 0 *minrtt = 0 } - addr, err := parseudpaddr(*localnet) + addr, err := parseUDPAddr(*localnet) if err != nil { errlog.Fatalf("Invalid binding address: %s", *localnet) } - inconn, err := net.ListenUDP("udp", addr) + inConn, err := net.ListenUDP("udp", addr) if err != nil { errlog.Fatalln(err) } - defer inconn.Close() + defer inConn.Close() logger.Printf("Listening on UDP %s", addr) for { payload := make([]byte, 1500) - if n, addr, err := inconn.ReadFromUDP(payload); err != nil { + if n, addr, err := inConn.ReadFromUDP(payload); err != nil { errlog.Println(err) continue } else { - go handlequery(addr, payload[:n], inconn) + go handleQuery(addr, payload[:n], inConn) } } }