diff --git a/.golangci.yml b/.golangci.yml index 5f7573adb..8fc20d1d6 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -55,9 +55,10 @@ linters: - whitespace - wsl - ginkgolinter - disable: - noctx + - containedctx - contextcheck + disable: - scopelint - structcheck - deadcode diff --git a/e2e/basic_test.go b/e2e/basic_test.go index 9b6c27c5e..1ec7876c3 100644 --- a/e2e/basic_test.go +++ b/e2e/basic_test.go @@ -18,15 +18,15 @@ var _ = Describe("Basic functional tests", func() { var err error Describe("Container start", func() { - BeforeEach(func() { - moka, err = createDNSMokkaContainer("moka1", `A google/NOERROR("A 1.2.3.4 123")`) + BeforeEach(func(ctx context.Context) { + moka, err = createDNSMokkaContainer(ctx, "moka1", `A google/NOERROR("A 1.2.3.4 123")`) Expect(err).Should(Succeed()) DeferCleanup(moka.Terminate) }) When("wrong port configuration is provided", func() { - BeforeEach(func() { - blocky, err = createBlockyContainer(tmpDir, + BeforeEach(func(ctx context.Context) { + blocky, err = createBlockyContainer(ctx, tmpDir, "upstreams:", " groups:", " default:", @@ -38,22 +38,22 @@ var _ = Describe("Basic functional tests", func() { Expect(err).Should(HaveOccurred()) // check container exit status - state, err := blocky.State(context.Background()) + state, err := blocky.State(ctx) Expect(err).Should(Succeed()) Expect(state.ExitCode).Should(Equal(1)) DeferCleanup(blocky.Terminate) }) - It("should fail to start", func() { + It("should fail to start", func(ctx context.Context) { Eventually(blocky.IsRunning, "5s", "2ms").Should(BeFalse()) - Expect(getContainerLogs(blocky)). + Expect(getContainerLogs(ctx, blocky)). Should(ContainElement(ContainSubstring("address already in use"))) }) }) When("Minimal configuration is provided", func() { - BeforeEach(func() { - blocky, err = createBlockyContainer(tmpDir, + BeforeEach(func(ctx context.Context) { + blocky, err = createBlockyContainer(ctx, tmpDir, "upstreams:", " groups:", " default:", @@ -63,19 +63,19 @@ var _ = Describe("Basic functional tests", func() { Expect(err).Should(Succeed()) DeferCleanup(blocky.Terminate) }) - It("Should start and answer DNS queries", func() { + It("Should start and answer DNS queries", func(ctx context.Context) { msg := util.NewMsgWithQuestion("google.de.", A) - Expect(doDNSRequest(blocky, msg)). + Expect(doDNSRequest(ctx, blocky, msg)). Should( SatisfyAll( BeDNSRecord("google.de.", A, "1.2.3.4"), HaveTTL(BeNumerically("==", 123)), )) }) - It("should return 'healthy' container status (healthcheck)", func() { + It("should return 'healthy' container status (healthcheck)", func(ctx context.Context) { Eventually(func(g Gomega) string { - state, err := blocky.State(context.Background()) + state, err := blocky.State(ctx) g.Expect(err).NotTo(HaveOccurred()) return state.Health.Status @@ -84,8 +84,8 @@ var _ = Describe("Basic functional tests", func() { }) Context("http port configuration", func() { When("'httpPort' is not defined", func() { - BeforeEach(func() { - blocky, err = createBlockyContainer(tmpDir, + BeforeEach(func(ctx context.Context) { + blocky, err = createBlockyContainer(ctx, tmpDir, "upstreams:", " groups:", " default:", @@ -96,8 +96,8 @@ var _ = Describe("Basic functional tests", func() { DeferCleanup(blocky.Terminate) }) - It("should not open http port", func() { - host, port, err := getContainerHostPort(blocky, "4000/tcp") + It("should not open http port", func(ctx context.Context) { + host, port, err := getContainerHostPort(ctx, blocky, "4000/tcp") Expect(err).Should(Succeed()) _, err = http.Get(fmt.Sprintf("http://%s", net.JoinHostPort(host, port))) @@ -105,8 +105,8 @@ var _ = Describe("Basic functional tests", func() { }) }) When("'httpPort' is defined", func() { - BeforeEach(func() { - blocky, err = createBlockyContainer(tmpDir, + BeforeEach(func(ctx context.Context) { + blocky, err = createBlockyContainer(ctx, tmpDir, "upstreams:", " groups:", " default:", @@ -118,8 +118,8 @@ var _ = Describe("Basic functional tests", func() { Expect(err).Should(Succeed()) DeferCleanup(blocky.Terminate) }) - It("should serve http content", func() { - host, port, err := getContainerHostPort(blocky, "4000/tcp") + It("should serve http content", func(ctx context.Context) { + host, port, err := getContainerHostPort(ctx, blocky, "4000/tcp") Expect(err).Should(Succeed()) url := fmt.Sprintf("http://%s", net.JoinHostPort(host, port)) @@ -142,15 +142,15 @@ var _ = Describe("Basic functional tests", func() { }) Describe("Logging", func() { - BeforeEach(func() { - moka, err = createDNSMokkaContainer("moka1", `A google/NOERROR("A 1.2.3.4 123")`) + BeforeEach(func(ctx context.Context) { + moka, err = createDNSMokkaContainer(ctx, "moka1", `A google/NOERROR("A 1.2.3.4 123")`) Expect(err).Should(Succeed()) DeferCleanup(moka.Terminate) }) When("log privacy is enabled", func() { - BeforeEach(func() { - blocky, err = createBlockyContainer(tmpDir, + BeforeEach(func(ctx context.Context) { + blocky, err = createBlockyContainer(ctx, tmpDir, "upstreams:", " groups:", " default:", @@ -162,27 +162,27 @@ var _ = Describe("Basic functional tests", func() { Expect(err).Should(Succeed()) DeferCleanup(blocky.Terminate) }) - It("should not log answers and questions", func() { + It("should not log answers and questions", func(ctx context.Context) { msg := util.NewMsgWithQuestion("google.com.", A) // do 2 requests - Expect(doDNSRequest(blocky, msg)). + Expect(doDNSRequest(ctx, blocky, msg)). Should( SatisfyAll( BeDNSRecord("google.com.", A, "1.2.3.4"), HaveTTL(BeNumerically("==", 123)), )) - Expect(doDNSRequest(blocky, msg)). + Expect(doDNSRequest(ctx, blocky, msg)). Should( SatisfyAll( BeDNSRecord("google.com.", A, "1.2.3.4"), HaveTTL(BeNumerically("<=", 123)), )) - Expect(getContainerLogs(blocky)).Should(Not(ContainElement(ContainSubstring("google.com")))) - Expect(getContainerLogs(blocky)).Should(Not(ContainElement(ContainSubstring("1.2.3.4")))) + Expect(getContainerLogs(ctx, blocky)).Should(Not(ContainElement(ContainSubstring("google.com")))) + Expect(getContainerLogs(ctx, blocky)).Should(Not(ContainElement(ContainSubstring("1.2.3.4")))) }) }) }) diff --git a/e2e/blocking_test.go b/e2e/blocking_test.go index a30afce42..90745c3fe 100644 --- a/e2e/blocking_test.go +++ b/e2e/blocking_test.go @@ -13,8 +13,8 @@ import ( var _ = Describe("External lists and query blocking", func() { var blocky, httpServer, moka testcontainers.Container var err error - BeforeEach(func() { - moka, err = createDNSMokkaContainer("moka", `A google/NOERROR("A 1.2.3.4 123")`) + BeforeEach(func(ctx context.Context) { + moka, err = createDNSMokkaContainer(ctx, "moka", `A google/NOERROR("A 1.2.3.4 123")`) Expect(err).Should(Succeed()) DeferCleanup(moka.Terminate) @@ -22,8 +22,8 @@ var _ = Describe("External lists and query blocking", func() { Describe("List download on startup", func() { When("external blacklist ist not available", func() { Context("loading.strategy = blocking", func() { - BeforeEach(func() { - blocky, err = createBlockyContainer(tmpDir, + BeforeEach(func(ctx context.Context) { + blocky, err = createBlockyContainer(ctx, tmpDir, "log:", " level: warn", "upstreams:", @@ -45,22 +45,22 @@ var _ = Describe("External lists and query blocking", func() { DeferCleanup(blocky.Terminate) }) - It("should start with warning in log work without errors", func() { + It("should start with warning in log work without errors", func(ctx context.Context) { msg := util.NewMsgWithQuestion("google.com.", A) - Expect(doDNSRequest(blocky, msg)). + Expect(doDNSRequest(ctx, blocky, msg)). Should( SatisfyAll( BeDNSRecord("google.com.", A, "1.2.3.4"), HaveTTL(BeNumerically("==", 123)), )) - Expect(getContainerLogs(blocky)).Should(ContainElement(ContainSubstring("cannot open source: "))) + Expect(getContainerLogs(ctx, blocky)).Should(ContainElement(ContainSubstring("cannot open source: "))) }) }) Context("loading.strategy = failOnError", func() { - BeforeEach(func() { - blocky, err = createBlockyContainer(tmpDir, + BeforeEach(func(ctx context.Context) { + blocky, err = createBlockyContainer(ctx, tmpDir, "log:", " level: warn", "upstreams:", @@ -81,17 +81,17 @@ var _ = Describe("External lists and query blocking", func() { Expect(err).Should(HaveOccurred()) // check container exit status - state, err := blocky.State(context.Background()) + state, err := blocky.State(ctx) Expect(err).Should(Succeed()) Expect(state.ExitCode).Should(Equal(1)) DeferCleanup(blocky.Terminate) }) - It("should fail to start", func() { + It("should fail to start", func(ctx context.Context) { Eventually(blocky.IsRunning, "5s", "2ms").Should(BeFalse()) - Expect(getContainerLogs(blocky)). + Expect(getContainerLogs(ctx, blocky)). Should(ContainElement(ContainSubstring("Error: can't start server: 1 error occurred"))) }) }) @@ -99,13 +99,13 @@ var _ = Describe("External lists and query blocking", func() { }) Describe("Query blocking against external blacklists", func() { When("external blacklists are defined and available", func() { - BeforeEach(func() { - httpServer, err = createHTTPServerContainer("httpserver", tmpDir, "list.txt", "blockeddomain.com") + BeforeEach(func(ctx context.Context) { + httpServer, err = createHTTPServerContainer(ctx, "httpserver", tmpDir, "list.txt", "blockeddomain.com") Expect(err).Should(Succeed()) DeferCleanup(httpServer.Terminate) - blocky, err = createBlockyContainer(tmpDir, + blocky, err = createBlockyContainer(ctx, tmpDir, "log:", " level: warn", "upstreams:", @@ -124,17 +124,17 @@ var _ = Describe("External lists and query blocking", func() { Expect(err).Should(Succeed()) DeferCleanup(blocky.Terminate) }) - It("should download external list on startup and block queries", func() { + It("should download external list on startup and block queries", func(ctx context.Context) { msg := util.NewMsgWithQuestion("blockeddomain.com.", A) - Expect(doDNSRequest(blocky, msg)). + Expect(doDNSRequest(ctx, blocky, msg)). Should( SatisfyAll( BeDNSRecord("blockeddomain.com.", A, "0.0.0.0"), HaveTTL(BeNumerically("==", 6*60*60)), )) - Expect(getContainerLogs(blocky)).Should(BeEmpty()) + Expect(getContainerLogs(ctx, blocky)).Should(BeEmpty()) }) }) }) diff --git a/e2e/containers.go b/e2e/containers.go index 19f741c11..73cfaea7b 100644 --- a/e2e/containers.go +++ b/e2e/containers.go @@ -39,9 +39,7 @@ const ( blockyImage = "blocky-e2e" ) -func createDNSMokkaContainer(alias string, rules ...string) (testcontainers.Container, error) { - ctx := context.Background() - +func createDNSMokkaContainer(ctx context.Context, alias string, rules ...string) (testcontainers.Container, error) { mokaRules := make(map[string]string) for i, rule := range rules { @@ -63,7 +61,7 @@ func createDNSMokkaContainer(alias string, rules ...string) (testcontainers.Cont }) } -func createHTTPServerContainer(alias string, tmpDir *helpertest.TmpFolder, +func createHTTPServerContainer(ctx context.Context, alias string, tmpDir *helpertest.TmpFolder, filename string, lines ...string, ) (testcontainers.Container, error) { f1 := tmpDir.CreateStringFile(filename, @@ -75,7 +73,6 @@ func createHTTPServerContainer(alias string, tmpDir *helpertest.TmpFolder, const modeOwner = 700 - ctx := context.Background() req := testcontainers.ContainerRequest{ Image: staticServerImage, Networks: []string{NetworkName}, @@ -105,9 +102,7 @@ func WithNetwork(network string) testcontainers.CustomizeRequestOption { } } -func createRedisContainer() (*redis.RedisContainer, error) { - ctx := context.Background() - +func createRedisContainer(ctx context.Context) (*redis.RedisContainer, error) { return redis.RunContainer(ctx, testcontainers.WithImage(redisImage), redis.WithLogLevel(redis.LogLevelVerbose), @@ -115,9 +110,7 @@ func createRedisContainer() (*redis.RedisContainer, error) { ) } -func createPostgresContainer() (*postgres.PostgresContainer, error) { - ctx := context.Background() - +func createPostgresContainer(ctx context.Context) (*postgres.PostgresContainer, error) { const waitLogOccurrence = 2 return postgres.RunContainer(ctx, @@ -134,9 +127,7 @@ func createPostgresContainer() (*postgres.PostgresContainer, error) { ) } -func createMariaDBContainer() (*mariadb.MariaDBContainer, error) { - ctx := context.Background() - +func createMariaDBContainer(ctx context.Context) (*mariadb.MariaDBContainer, error) { return mariadb.RunContainer(ctx, testcontainers.WithImage(mariaDBImage), mariadb.WithDatabase("user"), @@ -151,7 +142,9 @@ const ( startupTimeout = 30 * time.Second ) -func createBlockyContainer(tmpDir *helpertest.TmpFolder, lines ...string) (testcontainers.Container, error) { +func createBlockyContainer(ctx context.Context, tmpDir *helpertest.TmpFolder, + lines ...string, +) (testcontainers.Container, error) { f1 := tmpDir.CreateStringFile("config1.yaml", lines..., ) @@ -164,7 +157,6 @@ func createBlockyContainer(tmpDir *helpertest.TmpFolder, lines ...string) (testc return nil, fmt.Errorf("can't create config struct %w", err) } - ctx := context.Background() req := testcontainers.ContainerRequest{ Image: blockyImage, Networks: []string{NetworkName}, @@ -192,7 +184,7 @@ func createBlockyContainer(tmpDir *helpertest.TmpFolder, lines ...string) (testc }) if err != nil { // attach container log if error occurs - if r, err := container.Logs(context.Background()); err == nil { + if r, err := container.Logs(ctx); err == nil { if b, err := io.ReadAll(r); err == nil { ginkgo.AddReportEntry("blocky container log", string(b)) } @@ -203,7 +195,7 @@ func createBlockyContainer(tmpDir *helpertest.TmpFolder, lines ...string) (testc // check if DNS/HTTP interface is working. // Sometimes the internal health check returns OK, but the container port is not mapped yet - err = checkBlockyReadiness(cfg, container) + err = checkBlockyReadiness(ctx, cfg, container) if err != nil { return container, fmt.Errorf("container not ready: %w", err) @@ -212,14 +204,14 @@ func createBlockyContainer(tmpDir *helpertest.TmpFolder, lines ...string) (testc return container, nil } -func checkBlockyReadiness(cfg *config.Config, container testcontainers.Container) error { +func checkBlockyReadiness(ctx context.Context, cfg *config.Config, container testcontainers.Container) error { var err error const retryAttempts = 3 err = retry.Do( func() error { - _, err = doDNSRequest(container, util.NewMsgWithQuestion("healthcheck.blocky.", dns.Type(dns.TypeA))) + _, err = doDNSRequest(ctx, container, util.NewMsgWithQuestion("healthcheck.blocky.", dns.Type(dns.TypeA))) return err }, @@ -239,7 +231,7 @@ func checkBlockyReadiness(cfg *config.Config, container testcontainers.Container port := parts[len(parts)-1] err = retry.Do( func() error { - return doHTTPRequest(container, port) + return doHTTPRequest(ctx, container, port) }, retry.OnRetry(func(n uint, err error) { log.Infof("Performing retry HTTP request #%d: %s\n", n, err) @@ -256,13 +248,19 @@ func checkBlockyReadiness(cfg *config.Config, container testcontainers.Container return nil } -func doHTTPRequest(container testcontainers.Container, containerPort string) error { - host, port, err := getContainerHostPort(container, nat.Port(fmt.Sprintf("%s/tcp", containerPort))) +func doHTTPRequest(ctx context.Context, container testcontainers.Container, containerPort string) error { + host, port, err := getContainerHostPort(ctx, container, nat.Port(fmt.Sprintf("%s/tcp", containerPort))) + if err != nil { + return err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, + fmt.Sprintf("http://%s", net.JoinHostPort(host, port)), nil) if err != nil { return err } - resp, err := http.Get(fmt.Sprintf("http://%s", net.JoinHostPort(host, port))) + resp, err := http.DefaultClient.Do(req) if err != nil { return err } @@ -276,7 +274,7 @@ func doHTTPRequest(container testcontainers.Container, containerPort string) err return err } -func doDNSRequest(container testcontainers.Container, message *dns.Msg) (*dns.Msg, error) { +func doDNSRequest(ctx context.Context, container testcontainers.Container, message *dns.Msg) (*dns.Msg, error) { const timeout = 5 * time.Second c := &dns.Client{ @@ -284,7 +282,7 @@ func doDNSRequest(container testcontainers.Container, message *dns.Msg) (*dns.Ms Timeout: timeout, } - host, port, err := getContainerHostPort(container, "53/tcp") + host, port, err := getContainerHostPort(ctx, container, "53/tcp") if err != nil { return nil, err } @@ -294,13 +292,13 @@ func doDNSRequest(container testcontainers.Container, message *dns.Msg) (*dns.Ms return msg, err } -func getContainerHostPort(c testcontainers.Container, p nat.Port) (host, port string, err error) { - res, err := c.MappedPort(context.Background(), p) +func getContainerHostPort(ctx context.Context, c testcontainers.Container, p nat.Port) (host, port string, err error) { + res, err := c.MappedPort(ctx, p) if err != nil { return "", "", err } - host, err = c.Host(context.Background()) + host, err = c.Host(ctx) if err != nil { return "", "", err @@ -309,8 +307,8 @@ func getContainerHostPort(c testcontainers.Container, p nat.Port) (host, port st return host, res.Port(), err } -func getContainerLogs(c testcontainers.Container) (lines []string, err error) { - if r, err := c.Logs(context.Background()); err == nil { +func getContainerLogs(ctx context.Context, c testcontainers.Container) (lines []string, err error) { + if r, err := c.Logs(ctx); err == nil { scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() diff --git a/e2e/e2e_suite_test.go b/e2e/e2e_suite_test.go index a4a6feb74..dca3d22f5 100644 --- a/e2e/e2e_suite_test.go +++ b/e2e/e2e_suite_test.go @@ -28,10 +28,10 @@ var ( tmpDir *helpertest.TmpFolder ) -var _ = BeforeSuite(func() { +var _ = BeforeSuite(func(ctx context.Context) { var err error - network, err = testcontainers.GenericNetwork(context.Background(), testcontainers.GenericNetworkRequest{ + network, err = testcontainers.GenericNetwork(ctx, testcontainers.GenericNetworkRequest{ NetworkRequest: testcontainers.NetworkRequest{ Name: NetworkName, CheckDuplicate: false, @@ -41,10 +41,10 @@ var _ = BeforeSuite(func() { Expect(err).Should(Succeed()) - DeferCleanup(func() { + DeferCleanup(func(ctx context.Context) { err := retry.Do( func() error { - return network.Remove(context.Background()) + return network.Remove(ctx) }, retry.Attempts(3), retry.DelayType(retry.BackOffDelay), diff --git a/e2e/metrics_test.go b/e2e/metrics_test.go index 7194697be..7e9f3cd71 100644 --- a/e2e/metrics_test.go +++ b/e2e/metrics_test.go @@ -2,6 +2,7 @@ package e2e import ( "bufio" + "context" "fmt" "net" "net/http" @@ -20,23 +21,24 @@ var _ = Describe("Metrics functional tests", func() { var metricsURL string Describe("Metrics", func() { - BeforeEach(func() { - moka, err = createDNSMokkaContainer("moka1", `A google/NOERROR("A 1.2.3.4 123")`) + BeforeEach(func(ctx context.Context) { + moka, err = createDNSMokkaContainer(ctx, "moka1", `A google/NOERROR("A 1.2.3.4 123")`) Expect(err).Should(Succeed()) DeferCleanup(moka.Terminate) - httpServer1, err = createHTTPServerContainer("httpserver1", tmpDir, "list1.txt", "domain1.com") + httpServer1, err = createHTTPServerContainer(ctx, "httpserver1", tmpDir, "list1.txt", "domain1.com") Expect(err).Should(Succeed()) DeferCleanup(httpServer1.Terminate) - httpServer2, err = createHTTPServerContainer("httpserver2", tmpDir, "list2.txt", "domain1.com", "domain2", "domain3") + httpServer2, err = createHTTPServerContainer(ctx, "httpserver2", tmpDir, "list2.txt", + "domain1.com", "domain2", "domain3") Expect(err).Should(Succeed()) DeferCleanup(httpServer2.Terminate) - blocky, err = createBlockyContainer(tmpDir, + blocky, err = createBlockyContainer(ctx, tmpDir, "upstreams:", " groups:", " default:", @@ -56,26 +58,26 @@ var _ = Describe("Metrics functional tests", func() { Expect(err).Should(Succeed()) DeferCleanup(blocky.Terminate) - host, port, err := getContainerHostPort(blocky, "4000/tcp") + host, port, err := getContainerHostPort(ctx, blocky, "4000/tcp") Expect(err).Should(Succeed()) metricsURL = fmt.Sprintf("http://%s/metrics", net.JoinHostPort(host, port)) }) When("Blocky is started", func() { - It("Should provide 'blocky_build_info' prometheus metrics", func() { - Eventually(fetchBlockyMetrics).WithArguments(metricsURL). + It("Should provide 'blocky_build_info' prometheus metrics", func(ctx context.Context) { + Eventually(fetchBlockyMetrics).WithArguments(ctx, metricsURL). Should(ContainElement(ContainSubstring("blocky_build_info"))) }) - It("Should provide 'blocky_blocking_enabled' prometheus metrics", func() { - Eventually(fetchBlockyMetrics, "30s", "2ms").WithArguments(metricsURL). + It("Should provide 'blocky_blocking_enabled' prometheus metrics", func(ctx context.Context) { + Eventually(fetchBlockyMetrics, "30s", "2ms").WithArguments(ctx, metricsURL). Should(ContainElement("blocky_blocking_enabled 1")) }) }) When("Some query results are cached", func() { - BeforeEach(func() { - Eventually(fetchBlockyMetrics).WithArguments(metricsURL). + BeforeEach(func(ctx context.Context) { + Eventually(fetchBlockyMetrics).WithArguments(ctx, metricsURL). Should( SatisfyAll( ContainElement("blocky_cache_entry_count 0"), @@ -84,18 +86,18 @@ var _ = Describe("Metrics functional tests", func() { )) }) - It("Should increment cache counts", func() { + It("Should increment cache counts", func(ctx context.Context) { msg := util.NewMsgWithQuestion("google.de.", A) By("first query, should increment the cache miss count and the total count", func() { - Expect(doDNSRequest(blocky, msg)). + Expect(doDNSRequest(ctx, blocky, msg)). Should( SatisfyAll( BeDNSRecord("google.de.", A, "1.2.3.4"), HaveTTL(BeNumerically("==", 123)), )) - Eventually(fetchBlockyMetrics).WithArguments(metricsURL). + Eventually(fetchBlockyMetrics).WithArguments(ctx, metricsURL). Should( SatisfyAll( ContainElement("blocky_cache_entry_count 1"), @@ -105,14 +107,14 @@ var _ = Describe("Metrics functional tests", func() { }) By("Same query again, should increment the cache hit count", func() { - Expect(doDNSRequest(blocky, msg)). + Expect(doDNSRequest(ctx, blocky, msg)). Should( SatisfyAll( BeDNSRecord("google.de.", A, "1.2.3.4"), HaveTTL(BeNumerically("<=", 123)), )) - Eventually(fetchBlockyMetrics).WithArguments(metricsURL). + Eventually(fetchBlockyMetrics).WithArguments(ctx, metricsURL). Should( SatisfyAll( ContainElement("blocky_cache_entry_count 1"), @@ -124,8 +126,8 @@ var _ = Describe("Metrics functional tests", func() { }) When("Lists are loaded", func() { - It("Should expose list cache sizes per group as metrics", func() { - Eventually(fetchBlockyMetrics).WithArguments(metricsURL). + It("Should expose list cache sizes per group as metrics", func(ctx context.Context) { + Eventually(fetchBlockyMetrics).WithArguments(ctx, metricsURL). Should( SatisfyAll( ContainElement("blocky_blacklist_cache{group=\"group1\"} 1"), @@ -136,10 +138,15 @@ var _ = Describe("Metrics functional tests", func() { }) }) -func fetchBlockyMetrics(url string) ([]string, error) { +func fetchBlockyMetrics(ctx context.Context, url string) ([]string, error) { var metrics []string - r, err := http.Get(url) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + + r, err := http.DefaultClient.Do(req) if err != nil { return nil, err } diff --git a/e2e/querylog_test.go b/e2e/querylog_test.go index 041d6aaf9..af4c5756e 100644 --- a/e2e/querylog_test.go +++ b/e2e/querylog_test.go @@ -22,20 +22,20 @@ var _ = Describe("Query logs functional tests", func() { var db *gorm.DB var err error - BeforeEach(func() { - moka, err = createDNSMokkaContainer("moka1", `A google/NOERROR("A 1.2.3.4 123")`, `A unknown/NXDOMAIN()`) + BeforeEach(func(ctx context.Context) { + moka, err = createDNSMokkaContainer(ctx, "moka1", `A google/NOERROR("A 1.2.3.4 123")`, `A unknown/NXDOMAIN()`) Expect(err).Should(Succeed()) DeferCleanup(moka.Terminate) }) Describe("Query logging into the mariaDB database", func() { - BeforeEach(func() { - mariaDB, err = createMariaDBContainer() + BeforeEach(func(ctx context.Context) { + mariaDB, err = createMariaDBContainer(ctx) Expect(err).Should(Succeed()) DeferCleanup(mariaDB.Terminate) - blocky, err = createBlockyContainer(tmpDir, + blocky, err = createBlockyContainer(ctx, tmpDir, "log:", " level: warn", "upstreams:", @@ -53,7 +53,7 @@ var _ = Describe("Query logs functional tests", func() { Expect(err).Should(Succeed()) - connectionString, err := mariaDB.ConnectionString(context.Background(), + connectionString, err := mariaDB.ConnectionString(ctx, "tls=false", "charset=utf8mb4", "parseTime=True", "loc=Local") Expect(err).Should(Succeed()) @@ -67,10 +67,12 @@ var _ = Describe("Query logs functional tests", func() { Eventually(countEntries).WithArguments(db).Should(BeNumerically("==", 0)) }) When("Some queries were performed", func() { - It("Should store query log in the mariaDB database", func() { + It("Should store query log in the mariaDB database", func(ctx context.Context) { By("Performing 2 queries", func() { - Expect(doDNSRequest(blocky, util.NewMsgWithQuestion("google.de.", dns.Type(dns.TypeA)))).Should(Not(BeNil())) - Expect(doDNSRequest(blocky, util.NewMsgWithQuestion("unknown.domain.", dns.Type(dns.TypeA)))).Should(Not(BeNil())) + Expect(doDNSRequest(ctx, blocky, + util.NewMsgWithQuestion("google.de.", dns.Type(dns.TypeA)))).Should(Not(BeNil())) + Expect(doDNSRequest(ctx, blocky, + util.NewMsgWithQuestion("unknown.domain.", dns.Type(dns.TypeA)))).Should(Not(BeNil())) }) By("check entries count asynchronously, since blocky flushes log entries in bulk", func() { @@ -108,12 +110,12 @@ var _ = Describe("Query logs functional tests", func() { }) Describe("Query logging into the postgres database", func() { - BeforeEach(func() { - postgresDB, err = createPostgresContainer() + BeforeEach(func(ctx context.Context) { + postgresDB, err = createPostgresContainer(ctx) Expect(err).Should(Succeed()) DeferCleanup(postgresDB.Terminate) - blocky, err = createBlockyContainer(tmpDir, + blocky, err = createBlockyContainer(ctx, tmpDir, "log:", " level: warn", "upstreams:", @@ -129,7 +131,7 @@ var _ = Describe("Query logs functional tests", func() { Expect(err).Should(Succeed()) DeferCleanup(blocky.Terminate) - connectionString, err := postgresDB.ConnectionString(context.Background(), "sslmode=disable") + connectionString, err := postgresDB.ConnectionString(ctx, "sslmode=disable") Expect(err).Should(Succeed()) // database might be slow on first start, retry here if necessary @@ -143,10 +145,10 @@ var _ = Describe("Query logs functional tests", func() { }) When("Some queries were performed", func() { msg := util.NewMsgWithQuestion("google.de.", dns.Type(dns.TypeA)) - It("Should store query log in the postgres database", func() { + It("Should store query log in the postgres database", func(ctx context.Context) { By("Performing 2 queries", func() { - Expect(doDNSRequest(blocky, msg)).Should(Not(BeNil())) - Expect(doDNSRequest(blocky, msg)).Should(Not(BeNil())) + Expect(doDNSRequest(ctx, blocky, msg)).Should(Not(BeNil())) + Expect(doDNSRequest(ctx, blocky, msg)).Should(Not(BeNil())) }) By("check entries count asynchronously, since blocky flushes log entries in bulk", func() { diff --git a/e2e/redis_test.go b/e2e/redis_test.go index 75ddf1e64..218587e4c 100644 --- a/e2e/redis_test.go +++ b/e2e/redis_test.go @@ -19,13 +19,13 @@ var _ = Describe("Redis configuration tests", func() { var redisClient *redis.Client var err error - BeforeEach(func() { - redisDB, err = createRedisContainer() + BeforeEach(func(ctx context.Context) { + redisDB, err = createRedisContainer(ctx) Expect(err).Should(Succeed()) DeferCleanup(redisDB.Terminate) - redisConnectionString, err := redisDB.ConnectionString(context.Background()) + redisConnectionString, err := redisDB.ConnectionString(ctx) Expect(err).Should(Succeed()) redisConnectionString = strings.ReplaceAll(redisConnectionString, "redis://", "") @@ -34,20 +34,20 @@ var _ = Describe("Redis configuration tests", func() { Addr: redisConnectionString, }) - Expect(dbSize(redisClient)).Should(BeNumerically("==", 0)) + Expect(dbSize(ctx, redisClient)).Should(BeNumerically("==", 0)) - moka, err = createDNSMokkaContainer("moka1", `A google/NOERROR("A 1.2.3.4 123")`) + moka, err = createDNSMokkaContainer(ctx, "moka1", `A google/NOERROR("A 1.2.3.4 123")`) Expect(err).Should(Succeed()) - DeferCleanup(func() { - _ = moka.Terminate(context.Background()) + DeferCleanup(func(ctx context.Context) { + _ = moka.Terminate(ctx) }) }) Describe("Cache sharing between blocky instances", func() { When("Redis and 2 blocky instances are configured", func() { - BeforeEach(func() { - blocky1, err = createBlockyContainer(tmpDir, + BeforeEach(func(ctx context.Context) { + blocky1, err = createBlockyContainer(ctx, tmpDir, "log:", " level: warn", "upstreams:", @@ -61,7 +61,7 @@ var _ = Describe("Redis configuration tests", func() { Expect(err).Should(Succeed()) DeferCleanup(blocky1.Terminate) - blocky2, err = createBlockyContainer(tmpDir, + blocky2, err = createBlockyContainer(ctx, tmpDir, "log:", " level: warn", "upstreams:", @@ -75,10 +75,10 @@ var _ = Describe("Redis configuration tests", func() { Expect(err).Should(Succeed()) DeferCleanup(blocky2.Terminate) }) - It("2nd instance of blocky should use cache from redis", func() { + It("2nd instance of blocky should use cache from redis", func(ctx context.Context) { msg := util.NewMsgWithQuestion("google.de.", A) By("Query first blocky instance, should store cache in redis", func() { - Eventually(doDNSRequest, "5s", "2ms").WithArguments(blocky1, msg). + Eventually(doDNSRequest, "5s", "2ms").WithArguments(ctx, blocky1, msg). Should( SatisfyAll( BeDNSRecord("google.de.", A, "1.2.3.4"), @@ -87,15 +87,15 @@ var _ = Describe("Redis configuration tests", func() { }) By("Check redis, must contain one cache entry", func() { - Eventually(dbSize, "5s", "2ms").WithArguments(redisClient).Should(BeNumerically("==", 1)) + Eventually(dbSize, "5s", "2ms").WithArguments(ctx, redisClient).Should(BeNumerically("==", 1)) }) By("Shutdown the upstream DNS server", func() { - Expect(moka.Terminate(context.Background())).Should(Succeed()) + Expect(moka.Terminate(ctx)).Should(Succeed()) }) By("Query second blocky instance, should use cache from redis", func() { - Eventually(doDNSRequest, "5s", "2ms").WithArguments(blocky2, msg). + Eventually(doDNSRequest, "5s", "2ms").WithArguments(ctx, blocky2, msg). Should( SatisfyAll( BeDNSRecord("google.de.", A, "1.2.3.4"), @@ -104,8 +104,8 @@ var _ = Describe("Redis configuration tests", func() { }) By("No warnings/errors in log", func() { - Expect(getContainerLogs(blocky1)).Should(BeEmpty()) - Expect(getContainerLogs(blocky2)).Should(BeEmpty()) + Expect(getContainerLogs(ctx, blocky1)).Should(BeEmpty()) + Expect(getContainerLogs(ctx, blocky2)).Should(BeEmpty()) }) }) }) @@ -113,8 +113,8 @@ var _ = Describe("Redis configuration tests", func() { Describe("Cache loading on startup", func() { When("Redis and 1 blocky instance are configured", func() { - BeforeEach(func() { - blocky1, err = createBlockyContainer(tmpDir, + BeforeEach(func(ctx context.Context) { + blocky1, err = createBlockyContainer(ctx, tmpDir, "log:", " level: warn", "upstreams:", @@ -128,10 +128,10 @@ var _ = Describe("Redis configuration tests", func() { Expect(err).Should(Succeed()) DeferCleanup(blocky1.Terminate) }) - It("should load cache from redis after start", func() { + It("should load cache from redis after start", func(ctx context.Context) { msg := util.NewMsgWithQuestion("google.de.", A) By("Query first blocky instance, should store cache in redis\"", func() { - Eventually(doDNSRequest, "5s", "2ms").WithArguments(blocky1, msg). + Eventually(doDNSRequest, "5s", "2ms").WithArguments(ctx, blocky1, msg). Should( SatisfyAll( BeDNSRecord("google.de.", A, "1.2.3.4"), @@ -140,11 +140,11 @@ var _ = Describe("Redis configuration tests", func() { }) By("Check redis, must contain one cache entry", func() { - Eventually(dbSize).WithArguments(redisClient).Should(BeNumerically("==", 1)) + Eventually(dbSize).WithArguments(ctx, redisClient).Should(BeNumerically("==", 1)) }) By("start other instance of blocky now -> it should load the cache from redis", func() { - blocky2, err = createBlockyContainer(tmpDir, + blocky2, err = createBlockyContainer(ctx, tmpDir, "log:", " level: warn", "upstreams:", @@ -160,11 +160,11 @@ var _ = Describe("Redis configuration tests", func() { }) By("Shutdown the upstream DNS server", func() { - Expect(moka.Terminate(context.Background())).Should(Succeed()) + Expect(moka.Terminate(ctx)).Should(Succeed()) }) By("Query second blocky instance", func() { - Eventually(doDNSRequest, "5s", "2ms").WithArguments(blocky2, msg). + Eventually(doDNSRequest, "5s", "2ms").WithArguments(ctx, blocky2, msg). Should( SatisfyAll( BeDNSRecord("google.de.", A, "1.2.3.4"), @@ -173,14 +173,14 @@ var _ = Describe("Redis configuration tests", func() { }) By("No warnings/errors in log", func() { - Expect(getContainerLogs(blocky1)).Should(BeEmpty()) - Expect(getContainerLogs(blocky2)).Should(BeEmpty()) + Expect(getContainerLogs(ctx, blocky1)).Should(BeEmpty()) + Expect(getContainerLogs(ctx, blocky2)).Should(BeEmpty()) }) }) }) }) }) -func dbSize(redisClient *redis.Client) (int64, error) { - return redisClient.DBSize(context.Background()).Result() +func dbSize(ctx context.Context, redisClient *redis.Client) (int64, error) { + return redisClient.DBSize(ctx).Result() } diff --git a/e2e/upstream_test.go b/e2e/upstream_test.go index c670056d8..43ea3a03a 100644 --- a/e2e/upstream_test.go +++ b/e2e/upstream_test.go @@ -1,6 +1,8 @@ package e2e import ( + "context" + . "github.com/0xERR0R/blocky/helpertest" "github.com/0xERR0R/blocky/util" "github.com/miekg/dns" @@ -15,8 +17,8 @@ var _ = Describe("Upstream resolver configuration tests", func() { Describe("'upstreams.startVerify' parameter handling", func() { When("'upstreams.startVerify' is false and upstream server as IP is not reachable", func() { - BeforeEach(func() { - blocky, err = createBlockyContainer(tmpDir, + BeforeEach(func(ctx context.Context) { + blocky, err = createBlockyContainer(ctx, tmpDir, "log:", " level: warn", "upstreams:", @@ -29,14 +31,14 @@ var _ = Describe("Upstream resolver configuration tests", func() { Expect(err).Should(Succeed()) DeferCleanup(blocky.Terminate) }) - It("should start even if upstream server is not reachable", func() { + It("should start even if upstream server is not reachable", func(ctx context.Context) { Expect(blocky.IsRunning()).Should(BeTrue()) - Expect(getContainerLogs(blocky)).Should(BeEmpty()) + Expect(getContainerLogs(ctx, blocky)).Should(BeEmpty()) }) }) When("'upstreams.startVerify' is false and upstream server as host name is not reachable", func() { - BeforeEach(func() { - blocky, err = createBlockyContainer(tmpDir, + BeforeEach(func(ctx context.Context) { + blocky, err = createBlockyContainer(ctx, tmpDir, "log:", " level: warn", "upstreams:", @@ -49,14 +51,14 @@ var _ = Describe("Upstream resolver configuration tests", func() { Expect(err).Should(Succeed()) DeferCleanup(blocky.Terminate) }) - It("should start even if upstream server is not reachable", func() { + It("should start even if upstream server is not reachable", func(ctx context.Context) { Expect(blocky.IsRunning()).Should(BeTrue()) - Expect(getContainerLogs(blocky)).Should(BeEmpty()) + Expect(getContainerLogs(ctx, blocky)).Should(BeEmpty()) }) }) When("'upstreams.startVerify' is true and upstream as IP address server is not reachable", func() { - BeforeEach(func() { - blocky, err = createBlockyContainer(tmpDir, + BeforeEach(func(ctx context.Context) { + blocky, err = createBlockyContainer(ctx, tmpDir, "upstreams:", " groups:", " default:", @@ -67,15 +69,15 @@ var _ = Describe("Upstream resolver configuration tests", func() { Expect(err).Should(HaveOccurred()) DeferCleanup(blocky.Terminate) }) - It("should not start", func() { + It("should not start", func(ctx context.Context) { Expect(blocky.IsRunning()).Should(BeFalse()) - Expect(getContainerLogs(blocky)). + Expect(getContainerLogs(ctx, blocky)). Should(ContainElement(ContainSubstring("no valid upstream for group default"))) }) }) When("'upstreams.startVerify' is true and upstream server as host name is not reachable", func() { - BeforeEach(func() { - blocky, err = createBlockyContainer(tmpDir, + BeforeEach(func(ctx context.Context) { + blocky, err = createBlockyContainer(ctx, tmpDir, "upstreams:", " groups:", " default:", @@ -86,24 +88,24 @@ var _ = Describe("Upstream resolver configuration tests", func() { Expect(err).Should(HaveOccurred()) DeferCleanup(blocky.Terminate) }) - It("should not start", func() { + It("should not start", func(ctx context.Context) { Expect(blocky.IsRunning()).Should(BeFalse()) - Expect(getContainerLogs(blocky)). + Expect(getContainerLogs(ctx, blocky)). Should(ContainElement(ContainSubstring("no valid upstream for group default"))) }) }) }) Describe("'upstreams.timeout' parameter handling", func() { var moka testcontainers.Container - BeforeEach(func() { - moka, err = createDNSMokkaContainer("moka1", + BeforeEach(func(ctx context.Context) { + moka, err = createDNSMokkaContainer(ctx, "moka1", `A example.com/NOERROR("A 1.2.3.4 123")`, `A delay.com/delay(NOERROR("A 1.1.1.1 100"), "300ms")`) Expect(err).Should(Succeed()) DeferCleanup(moka.Terminate) - blocky, err = createBlockyContainer(tmpDir, + blocky, err = createBlockyContainer(ctx, tmpDir, "upstreams:", " groups:", " default:", @@ -114,10 +116,10 @@ var _ = Describe("Upstream resolver configuration tests", func() { Expect(err).Should(Succeed()) DeferCleanup(blocky.Terminate) }) - It("should consider the timeout parameter", func() { + It("should consider the timeout parameter", func(ctx context.Context) { By("query without timeout", func() { msg := util.NewMsgWithQuestion("example.com.", A) - Expect(doDNSRequest(blocky, msg)). + Expect(doDNSRequest(ctx, blocky, msg)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "1.2.3.4"), @@ -128,7 +130,7 @@ var _ = Describe("Upstream resolver configuration tests", func() { By("query with timeout", func() { msg := util.NewMsgWithQuestion("delay.com.", A) - resp, err := doDNSRequest(blocky, msg) + resp, err := doDNSRequest(ctx, blocky, msg) Expect(err).Should(Succeed()) Expect(resp.Rcode).Should(Equal(dns.RcodeServerFailure)) }) diff --git a/lists/downloader.go b/lists/downloader.go index 1536f79ad..ff95c8628 100644 --- a/lists/downloader.go +++ b/lists/downloader.go @@ -1,6 +1,7 @@ package lists import ( + "context" "errors" "fmt" "io" @@ -27,7 +28,7 @@ func (e *TransientError) Unwrap() error { // FileDownloader is able to download some text file type FileDownloader interface { - DownloadFile(link string) (io.ReadCloser, error) + DownloadFile(ctx context.Context, link string) (io.ReadCloser, error) } // httpDownloader downloads files via HTTP protocol @@ -52,12 +53,17 @@ func newDownloader(cfg config.DownloaderConfig, transport http.RoundTripper) *ht } } -func (d *httpDownloader) DownloadFile(link string) (io.ReadCloser, error) { +func (d *httpDownloader) DownloadFile(ctx context.Context, link string) (io.ReadCloser, error) { var body io.ReadCloser err := retry.Do( func() error { - resp, httpErr := d.client.Get(link) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, link, nil) + if err != nil { + return err + } + + resp, httpErr := d.client.Do(req) if httpErr == nil { if resp.StatusCode == http.StatusOK { body = resp.Body diff --git a/lists/downloader_test.go b/lists/downloader_test.go index 10df7e515..25a8f2f22 100644 --- a/lists/downloader_test.go +++ b/lists/downloader_test.go @@ -1,6 +1,7 @@ package lists import ( + "context" "errors" "io" "net" @@ -80,8 +81,8 @@ var _ = Describe("Downloader", func() { sut = newDownloader(sutConfig, nil) }) - It("Should return all lines from the file", func() { - reader, err := sut.DownloadFile(server.URL) + It("Should return all lines from the file", func(ctx context.Context) { + reader, err := sut.DownloadFile(ctx, server.URL) Expect(err).Should(Succeed()) Expect(reader).Should(Not(BeNil())) @@ -101,8 +102,8 @@ var _ = Describe("Downloader", func() { sutConfig.Attempts = 3 }) - It("Should return error", func() { - reader, err := sut.DownloadFile(server.URL) + It("Should return error", func(ctx context.Context) { + reader, err := sut.DownloadFile(ctx, server.URL) Expect(err).Should(HaveOccurred()) Expect(reader).Should(BeNil()) @@ -115,8 +116,8 @@ var _ = Describe("Downloader", func() { BeforeEach(func() { sutConfig.Attempts = 1 }) - It("Should return error", func() { - _, err := sut.DownloadFile("somewrongurl") + It("Should return error", func(ctx context.Context) { + _, err := sut.DownloadFile(ctx, "somewrongurl") Expect(err).Should(HaveOccurred()) Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("Can't download file: ")) @@ -149,8 +150,8 @@ var _ = Describe("Downloader", func() { })) DeferCleanup(server.Close) }) - It("Should perform a retry and return file content", func() { - reader, err := sut.DownloadFile(server.URL) + It("Should perform a retry and return file content", func(ctx context.Context) { + reader, err := sut.DownloadFile(ctx, server.URL) Expect(err).Should(Succeed()) Expect(reader).Should(Not(BeNil())) DeferCleanup(reader.Close) @@ -180,17 +181,18 @@ var _ = Describe("Downloader", func() { })) DeferCleanup(server.Close) }) - It("Should perform a retry until max retry attempt count is reached and return TransientError", func() { - reader, err := sut.DownloadFile(server.URL) - Expect(err).Should(HaveOccurred()) - Expect(errors.As(err, new(*TransientError))).Should(BeTrue()) - Expect(err.Error()).Should(ContainSubstring("Timeout")) - Expect(reader).Should(BeNil()) - - // failed download event was emitted 3 times - Expect(failedDownloadCountEvtChannel).Should(HaveLen(3)) - Expect(failedDownloadCountEvtChannel).Should(Receive(Equal(server.URL))) - }) + It("Should perform a retry until max retry attempt count is reached and return TransientError", + func(ctx context.Context) { + reader, err := sut.DownloadFile(ctx, server.URL) + Expect(err).Should(HaveOccurred()) + Expect(errors.As(err, new(*TransientError))).Should(BeTrue()) + Expect(err.Error()).Should(ContainSubstring("Timeout")) + Expect(reader).Should(BeNil()) + + // failed download event was emitted 3 times + Expect(failedDownloadCountEvtChannel).Should(HaveLen(3)) + Expect(failedDownloadCountEvtChannel).Should(Receive(Equal(server.URL))) + }) }) When("DNS resolution of passed URL fails", func() { BeforeEach(func() { @@ -200,19 +202,20 @@ var _ = Describe("Downloader", func() { Cooldown: 200 * config.Duration(time.Millisecond), } }) - It("Should perform a retry until max retry attempt count is reached and return DNSError", func() { - reader, err := sut.DownloadFile("http://some.domain.which.does.not.exist") - Expect(err).Should(HaveOccurred()) - - var dnsError *net.DNSError - Expect(errors.As(err, &dnsError)).Should(BeTrue(), "received error %w", err) - Expect(reader).Should(BeNil()) - - // failed download event was emitted 3 times - Expect(failedDownloadCountEvtChannel).Should(HaveLen(3)) - Expect(failedDownloadCountEvtChannel).Should(Receive(Equal("http://some.domain.which.does.not.exist"))) - Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("Name resolution err: ")) - }) + It("Should perform a retry until max retry attempt count is reached and return DNSError", + func(ctx context.Context) { + reader, err := sut.DownloadFile(ctx, "http://some.domain.which.does.not.exist") + Expect(err).Should(HaveOccurred()) + + var dnsError *net.DNSError + Expect(errors.As(err, &dnsError)).Should(BeTrue(), "received error %w", err) + Expect(reader).Should(BeNil()) + + // failed download event was emitted 3 times + Expect(failedDownloadCountEvtChannel).Should(HaveLen(3)) + Expect(failedDownloadCountEvtChannel).Should(Receive(Equal("http://some.domain.which.does.not.exist"))) + Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("Name resolution err: ")) + }) }) }) }) diff --git a/lists/list_cache.go b/lists/list_cache.go index 035f460ab..70ae15bd0 100644 --- a/lists/list_cache.go +++ b/lists/list_cache.go @@ -228,7 +228,7 @@ func (b *ListCache) parseFile(ctx context.Context, opener SourceOpener, resultCh logger().Debug("starting processing of source") - r, err := opener.Open() + r, err := opener.Open(ctx) if err != nil { logger().Error("cannot open source: ", err) diff --git a/lists/list_cache_test.go b/lists/list_cache_test.go index b677774e0..d89c4943b 100644 --- a/lists/list_cache_test.go +++ b/lists/list_cache_test.go @@ -455,7 +455,7 @@ func newMockDownloader(driver func(res chan<- string, err chan<- error)) *MockDo return &MockDownloader{NewMockCallSequence(driver)} } -func (m *MockDownloader) DownloadFile(_ string) (io.ReadCloser, error) { +func (m *MockDownloader) DownloadFile(_ context.Context, _ string) (io.ReadCloser, error) { str, err := m.Call() if err != nil { return nil, err diff --git a/lists/sourcereader.go b/lists/sourcereader.go index 8e6541c8d..855207011 100644 --- a/lists/sourcereader.go +++ b/lists/sourcereader.go @@ -1,6 +1,7 @@ package lists import ( + "context" "fmt" "io" "os" @@ -12,7 +13,7 @@ import ( type SourceOpener interface { fmt.Stringer - Open() (io.ReadCloser, error) + Open(ctx context.Context) (io.ReadCloser, error) } func NewSourceOpener(txtLocInfo string, source config.BytesSource, downloader FileDownloader) (SourceOpener, error) { @@ -35,7 +36,7 @@ type textOpener struct { locInfo string } -func (o *textOpener) Open() (io.ReadCloser, error) { +func (o *textOpener) Open(_ context.Context) (io.ReadCloser, error) { return io.NopCloser(strings.NewReader(o.source.From)), nil } @@ -48,8 +49,8 @@ type httpOpener struct { downloader FileDownloader } -func (o *httpOpener) Open() (io.ReadCloser, error) { - return o.downloader.DownloadFile(o.source.From) +func (o *httpOpener) Open(ctx context.Context) (io.ReadCloser, error) { + return o.downloader.DownloadFile(ctx, o.source.From) } func (o *httpOpener) String() string { @@ -60,7 +61,7 @@ type fileOpener struct { source config.BytesSource } -func (o *fileOpener) Open() (io.ReadCloser, error) { +func (o *fileOpener) Open(_ context.Context) (io.ReadCloser, error) { return os.Open(o.source.From) } diff --git a/resolver/hosts_file_resolver.go b/resolver/hosts_file_resolver.go index ea77f23bd..dd50a8162 100644 --- a/resolver/hosts_file_resolver.go +++ b/resolver/hosts_file_resolver.go @@ -212,7 +212,7 @@ func (r *HostsFileResolver) loadSources(ctx context.Context) error { func (r *HostsFileResolver) parseFile( ctx context.Context, opener lists.SourceOpener, hostsChan chan<- *HostsFileEntry, ) error { - reader, err := opener.Open() + reader, err := opener.Open(ctx) if err != nil { return err }