diff --git a/example/Corefile b/example/Corefile index 34cdcea..728f95e 100644 --- a/example/Corefile +++ b/example/Corefile @@ -3,7 +3,11 @@ debug prometheus - blocklist https://raw.githubusercontent.com/StevenBlack/hosts/master/hosts + blocklist https://raw.githubusercontent.com/StevenBlack/hosts/master/hosts { + # if CoreDNS listens at 53, you need another DNS to bootstrap the download + bootstrap_dns 1.1.1.1:53 + } + blocklist blocklist.txt { allowlist allowlist.txt domain_metrics diff --git a/list_loader.go b/list_loader.go index 04e85f4..e7eb724 100644 --- a/list_loader.go +++ b/list_loader.go @@ -2,26 +2,33 @@ package blocklist import ( "bufio" + "context" "io" + "net" "net/http" "os" "path/filepath" "regexp" "strings" + "time" "github.com/coredns/caddy" ) -func loadList(c *caddy.Controller, location string) ([]string, error) { +func loadList(c *caddy.Controller, location string, bootStrapDNS string) ([]string, error) { log.Infof("Loading from %s", location) if strings.HasPrefix(location, "http://") || strings.HasPrefix(location, "https://") { - return loadListFromUrl(c, location) + return loadListFromUrl(c, location, bootStrapDNS) } return loadListFromFile(c, location) } -func loadListFromUrl(c *caddy.Controller, name string) ([]string, error) { - response, err := http.Get(name) +func loadListFromUrl(c *caddy.Controller, name string, bootStrapDNS string) ([]string, error) { + client := &http.Client{} + if bootStrapDNS != "" { + client = customDNS(bootStrapDNS) + } + response, err := client.Get(name) if err != nil { return nil, err } @@ -29,6 +36,38 @@ func loadListFromUrl(c *caddy.Controller, name string) ([]string, error) { return collectDomains(response.Body, name) } +func customDNS(bootStrapDNS string) *http.Client { + var ( + dnsResolverIP = bootStrapDNS // Google DNS resolver. + dnsResolverProto = "udp" // Protocol to use for the DNS resolver + dnsResolverTimeoutMs = 5000 // Timeout (ms) for the DNS resolver (optional) + ) + + dialer := &net.Dialer{ + Resolver: &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{ + Timeout: time.Duration(dnsResolverTimeoutMs) * time.Millisecond, + } + return d.DialContext(ctx, dnsResolverProto, dnsResolverIP) + }, + }, + } + dialContext := func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.DialContext(ctx, network, addr) + } + tr := &http.Transport{ + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Second, + DisableCompression: true, + DialContext: dialContext, + } + client := &http.Client{Transport: tr} + + return client +} + func loadListFromFile(c *caddy.Controller, name string) ([]string, error) { if !filepath.IsAbs(name) { name = filepath.Join( diff --git a/readme.md b/readme.md index c8a9647..f4c9f5e 100644 --- a/readme.md +++ b/readme.md @@ -15,7 +15,11 @@ domain on each line. There is an example file in the example folder. prometheus # load from url - blocklist https://mirror1.malwaredomains.com/files/justdomains + blocklist https://mirror1.malwaredomains.com/files/justdomains { + # if CoreDNS listens at 53, you need another DNS to bootstrap the download + bootstrap_dns 1.1.1.1:53 + } + # load from file, if the path is not absolute it will be relative to the Corefile blocklist blocklist.txt diff --git a/setup.go b/setup.go index af081fc..be9cb00 100644 --- a/setup.go +++ b/setup.go @@ -20,6 +20,7 @@ func setup(c *caddy.Controller) error { var allowlistLocation string var allowlist []string var blockResponse string + var bootStrapDNS string c.Args(&blocklistLocation) if blocklistLocation == "" { @@ -39,6 +40,8 @@ func setup(c *caddy.Controller) error { log.Debugf("Setting allowlist location to %s", allowlistLocation) case "domain_metrics": domainMetrics = true + case "bootstrap_dns": + bootStrapDNS = c.RemainingArgs()[0] case "block_response": remaining := c.RemainingArgs() if len(remaining) != 1 { @@ -56,13 +59,13 @@ func setup(c *caddy.Controller) error { return plugin.Error("blocklist", errors.New("To many arguments for blocklist.")) } - blocklist, err := loadList(c, blocklistLocation) + blocklist, err := loadList(c, blocklistLocation, bootStrapDNS) if err != nil { return plugin.Error("blocklist", err) } if allowlistLocation != "" { - allowlist, err = loadList(c, allowlistLocation) + allowlist, err = loadList(c, allowlistLocation, bootStrapDNS) if err != nil { return plugin.Error("blocklist", err) }