From 152c6f0b91067c8334e5fba65ef1c1ce619c266e Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Wed, 14 Jun 2023 15:09:50 +0000 Subject: [PATCH] chore(sources/env): use gosettings/sources/env functions --- go.mod | 2 +- go.sum | 2 + internal/config/sources/env/block.go | 77 +++------------- internal/config/sources/env/cache.go | 12 ++- internal/config/sources/env/doh.go | 14 +-- internal/config/sources/env/dot.go | 10 +-- internal/config/sources/env/helpers.go | 93 -------------------- internal/config/sources/env/helpers_test.go | 73 --------------- internal/config/sources/env/log.go | 9 +- internal/config/sources/env/metrics.go | 17 +--- internal/config/sources/env/middlewarelog.go | 11 ++- internal/config/sources/env/prometheus.go | 8 +- internal/config/sources/env/reader.go | 31 ++++--- 13 files changed, 63 insertions(+), 296 deletions(-) delete mode 100644 internal/config/sources/env/helpers.go delete mode 100644 internal/config/sources/env/helpers_test.go diff --git a/go.mod b/go.mod index b38d673c..bde4f9b6 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/prometheus/client_model v0.2.0 github.com/qdm12/golibs v0.0.0-20210915134941-19815c6f95fe github.com/qdm12/goservices v0.1.0-rc2 - github.com/qdm12/gosettings v0.3.0 + github.com/qdm12/gosettings v0.4.0-rc1 github.com/qdm12/gosplash v0.1.0 github.com/qdm12/gotree v0.0.0-20211231173558-a8b7dce9989e github.com/qdm12/govalid v0.2.0-rc1 diff --git a/go.sum b/go.sum index be263f0d..07de1bbb 100644 --- a/go.sum +++ b/go.sum @@ -258,6 +258,8 @@ github.com/qdm12/gosettings v0.3.0-rc6 h1:zsGXIIZ5FWsn9LeFBbr7s2/xgeodfvmWliVAFX github.com/qdm12/gosettings v0.3.0-rc6/go.mod h1:+hHzN8lsE63T01t6SruGzc6xkpvfsZFod/ooDs8FWnQ= github.com/qdm12/gosettings v0.3.0 h1:YutcgQzVaOB3LuLj+Smtoy90JOH/5B5p2IH3BvV3ra4= github.com/qdm12/gosettings v0.3.0/go.mod h1:JRV3opOpHvnKlIA29lKQMdYw1WSMVMfHYLLHPHol5ME= +github.com/qdm12/gosettings v0.4.0-rc1 h1:UYA92yyeDPbmZysIuG65yrpZVPtdIoRmtEHft/AyI38= +github.com/qdm12/gosettings v0.4.0-rc1/go.mod h1:JRV3opOpHvnKlIA29lKQMdYw1WSMVMfHYLLHPHol5ME= github.com/qdm12/gosplash v0.1.0 h1:Sfl+zIjFZFP7b0iqf2l5UkmEY97XBnaKkH3FNY6Gf7g= github.com/qdm12/gosplash v0.1.0/go.mod h1:+A3fWW4/rUeDXhY3ieBzwghKdnIPFJgD8K3qQkenJlw= github.com/qdm12/gotree v0.0.0-20211231173558-a8b7dce9989e h1:L1zGR8xaYpbGFO9GjphGHmN91nIzwVp2EO6cauAtwoI= diff --git a/internal/config/sources/env/block.go b/internal/config/sources/env/block.go index 9f227db8..dbce749f 100644 --- a/internal/config/sources/env/block.go +++ b/internal/config/sources/env/block.go @@ -2,109 +2,56 @@ package env import ( "fmt" - "net/netip" "github.com/qdm12/dns/v2/internal/config/settings" ) -func readBlock() (settings settings.Block, err error) { - settings.BlockMalicious, err = envToBoolPtr("BLOCK_MALICIOUS") +func (r *Reader) readBlock() (settings settings.Block, err error) { + settings.BlockMalicious, err = r.env.BoolPtr("BLOCK_MALICIOUS") if err != nil { return settings, fmt.Errorf("environment variable BLOCK_MALICIOUS: %w", err) } - settings.BlockSurveillance, err = envToBoolPtr("BLOCK_SURVEILLANCE") + settings.BlockSurveillance, err = r.env.BoolPtr("BLOCK_SURVEILLANCE") if err != nil { return settings, fmt.Errorf("environment variable BLOCK_SURVEILLANCE: %w", err) } - settings.BlockAds, err = envToBoolPtr("BLOCK_ADS") + settings.BlockAds, err = r.env.BoolPtr("BLOCK_ADS") if err != nil { return settings, fmt.Errorf("environment variable BLOCK_ADS: %w", err) } - settings.RebindingProtection, err = envToBoolPtr("REBINDING_PROTECTION") + settings.RebindingProtection, err = r.env.BoolPtr("REBINDING_PROTECTION") if err != nil { return settings, fmt.Errorf("environment variable REBINDING_PROTECTION: %w", err) } - settings.AllowedHosts = envToCSV("ALLOWED_HOSTNAMES") - settings.AddBlockedHosts = envToCSV("BLOCK_HOSTNAMES") + settings.AllowedHosts = r.env.CSV("ALLOWED_HOSTNAMES") + settings.AddBlockedHosts = r.env.CSV("BLOCK_HOSTNAMES") - settings.AllowedIPs, err = getAllowedIPs() + settings.AllowedIPs, err = r.env.CSVNetipAddresses("ALLOWED_IPS") if err != nil { return settings, err } - settings.AddBlockedIPs, err = getBlockedIPs() + settings.AddBlockedIPs, err = r.env.CSVNetipAddresses("BLOCK_IPS") if err != nil { return settings, err } - settings.AllowedIPPrefixes, err = getAllowedIPPrefixes() + settings.AllowedIPPrefixes, err = r.env.CSVNetipPrefixes("ALLOWED_CIDRS") if err != nil { return settings, err } - settings.AddBlockedIPPrefixes, err = getBlockedIPPrefixes() + settings.AddBlockedIPPrefixes, err = r.env.CSVNetipPrefixes("BLOCK_CIDRS") if err != nil { return settings, err } - settings.RebindingProtection, err = envToBoolPtr("REBINDING_PROTECTION") + settings.RebindingProtection, err = r.env.BoolPtr("REBINDING_PROTECTION") if err != nil { return settings, fmt.Errorf("environment variable REBINDING_PROTECTION: %w", err) } return settings, nil } - -// getAllowedIPs obtains a list of IPs to unblock from block lists -// from the comma separated list for the environment variable ALLOWED_IPS. -func getAllowedIPs() (ips []netip.Addr, err error) { - ipStrings := envToCSV("ALLOWED_IPS") - - ips, err = parseIPStrings(ipStrings) - if err != nil { - return nil, fmt.Errorf("environment variable ALLOWED_IPS: %w", err) - } - - return ips, nil -} - -// getBlockedIPs obtains a list of IP addresses to block from -// the comma separated list for the environment variable BLOCK_IPS. -func getBlockedIPs() (ips []netip.Addr, err error) { - values := envToCSV("BLOCK_IPS") - - ips, err = parseIPStrings(values) - if err != nil { - return nil, fmt.Errorf("environment variable BLOCK_IPS: %w", err) - } - - return ips, nil -} - -// getAllowedIPPrefixes obtains a list of IP Prefixes to unblock from block lists -// from the comma separated list for the environment variable ALLOWED_CIDRS. -func getAllowedIPPrefixes() (ipPrefixes []netip.Prefix, err error) { - ipPrefixStrings := envToCSV("ALLOWED_CIDRS") - - ipPrefixes, err = parseIPPrefixStrings(ipPrefixStrings) - if err != nil { - return nil, fmt.Errorf("environment variable ALLOWED_CIDRS: %w", err) - } - - return ipPrefixes, nil -} - -// getBlockedIPPrefixes obtains a list of IP networks (CIDR notation) to block from -// the comma separated list for the environment variable BLOCK_CIDRS. -func getBlockedIPPrefixes() (ipPrefixes []netip.Prefix, err error) { - values := envToCSV("BLOCK_CIDRS") - - ipPrefixes, err = parseIPPrefixStrings(values) - if err != nil { - return nil, fmt.Errorf("environment variable BLOCK_CIDRS: %w", err) - } - - return ipPrefixes, nil -} diff --git a/internal/config/sources/env/cache.go b/internal/config/sources/env/cache.go index eb872b14..30d311e0 100644 --- a/internal/config/sources/env/cache.go +++ b/internal/config/sources/env/cache.go @@ -4,17 +4,15 @@ import ( "errors" "fmt" "math" - "os" "strconv" - "strings" "github.com/qdm12/dns/v2/internal/config/settings" ) -func readCache() (settings settings.Cache, err error) { - settings.Type = strings.ToLower(os.Getenv("CACHE_TYPE")) +func (r *Reader) readCache() (settings settings.Cache, err error) { + settings.Type = r.env.String("CACHE_TYPE") - settings.LRU.MaxEntries, err = getLRUCacheMaxEntries() + settings.LRU.MaxEntries, err = r.getLRUCacheMaxEntries() if err != nil { return settings, fmt.Errorf("LRU max entries: %w", err) } @@ -24,8 +22,8 @@ func readCache() (settings settings.Cache, err error) { var ErrCacheLRUMaxEntries = errors.New("invalid value for max entries of the LRU cache") -func getLRUCacheMaxEntries() (maxEntries uint, err error) { - s := os.Getenv("CACHE_LRU_MAX_ENTRIES") +func (r *Reader) getLRUCacheMaxEntries() (maxEntries uint, err error) { + s := r.env.String("CACHE_LRU_MAX_ENTRIES") if s == "" { return 0, nil } diff --git a/internal/config/sources/env/doh.go b/internal/config/sources/env/doh.go index c60ca5f9..4c9da20b 100644 --- a/internal/config/sources/env/doh.go +++ b/internal/config/sources/env/doh.go @@ -6,21 +6,21 @@ import ( "github.com/qdm12/dns/v2/internal/config/settings" ) -func readDoH() (settings settings.DoH, err error) { - settings.DoHProviders = envToCSV("DOH_RESOLVERS") - settings.Timeout, err = envToDuration("DOH_TIMEOUT") +func (r *Reader) readDoH() (settings settings.DoH, err error) { + settings.DoHProviders = r.env.CSV("DOH_RESOLVERS") + settings.Timeout, err = r.env.Duration("DOH_TIMEOUT") if err != nil { return settings, fmt.Errorf("environment variable DOH_TIMEOUT: %w", err) } - settings.Self.DoTProviders = envToCSV("DOT_RESOLVERS") - settings.Self.DNSProviders = envToCSV("DNS_FALLBACK_PLAINTEXT_RESOLVERS") - settings.Self.IPv6, err = envToBoolPtr("DOT_CONNECT_IPV6") + settings.Self.DoTProviders = r.env.CSV("DOT_RESOLVERS") + settings.Self.DNSProviders = r.env.CSV("DNS_FALLBACK_PLAINTEXT_RESOLVERS") + settings.Self.IPv6, err = r.env.BoolPtr("DOT_CONNECT_IPV6") if err != nil { return settings, fmt.Errorf("environment variable DOT_CONNECT_IPV6: %w", err) } - settings.Self.Timeout, err = envToDuration("DOT_TIMEOUT") + settings.Self.Timeout, err = r.env.Duration("DOT_TIMEOUT") if err != nil { return settings, fmt.Errorf("environment variable DOT_TIMEOUT: %w", err) } diff --git a/internal/config/sources/env/dot.go b/internal/config/sources/env/dot.go index 991c066f..e5ed6b8a 100644 --- a/internal/config/sources/env/dot.go +++ b/internal/config/sources/env/dot.go @@ -6,15 +6,15 @@ import ( "github.com/qdm12/dns/v2/internal/config/settings" ) -func readDoT() (settings settings.DoT, err error) { - settings.DoTProviders = envToCSV("DOT_RESOLVERS") - settings.DNSProviders = envToCSV("DNS_FALLBACK_PLAINTEXT_RESOLVERS") - settings.Timeout, err = envToDuration("DOT_TIMEOUT") +func (r *Reader) readDoT() (settings settings.DoT, err error) { + settings.DoTProviders = r.env.CSV("DOT_RESOLVERS") + settings.DNSProviders = r.env.CSV("DNS_FALLBACK_PLAINTEXT_RESOLVERS") + settings.Timeout, err = r.env.Duration("DOT_TIMEOUT") if err != nil { return settings, fmt.Errorf("environment variable DOT_TIMEOUT: %w", err) } - settings.IPv6, err = envToBoolPtr("DOT_CONNECT_IPV6") + settings.IPv6, err = r.env.BoolPtr("DOT_CONNECT_IPV6") if err != nil { return settings, fmt.Errorf("environment variable DOT_CONNECT_IPV6: %w", err) } diff --git a/internal/config/sources/env/helpers.go b/internal/config/sources/env/helpers.go deleted file mode 100644 index 1201afe1..00000000 --- a/internal/config/sources/env/helpers.go +++ /dev/null @@ -1,93 +0,0 @@ -package env - -import ( - "fmt" - "net/netip" - "os" - "strings" - "time" - - "github.com/qdm12/govalid/binary" -) - -func envToCSV(envKey string) (values []string) { - csv := os.Getenv(envKey) - if csv == "" { - return nil - } - return lowerAndSplit(csv) -} - -func envToStringPtr(envKey string) (stringPtr *string) { - s := os.Getenv(envKey) - if s == "" { - return nil - } - return &s -} - -func envToBoolPtr(envKey string) (boolPtr *bool, err error) { - s := os.Getenv(envKey) - if s == "" { - return nil, nil //nolint:nilnil - } - return binary.Validate(s) -} - -func envToDuration(envKey string) (d time.Duration, err error) { - s := os.Getenv(envKey) - if s == "" { - return 0, nil - } - - d, err = time.ParseDuration(s) - if err != nil { - return 0, err - } - return d, nil -} - -func envToDurationPtr(envKey string) (d *time.Duration, err error) { - s := os.Getenv(envKey) - if s == "" { - return nil, nil //nolint:nilnil - } - - d = new(time.Duration) - *d, err = time.ParseDuration(s) - if err != nil { - return nil, err - } - return d, nil -} - -func lowerAndSplit(csv string) (values []string) { - csv = strings.ToLower(csv) - return strings.Split(csv, ",") -} - -func parseIPStrings(ipStrings []string) (ips []netip.Addr, err error) { - ips = make([]netip.Addr, len(ipStrings)) - - for i, ipString := range ipStrings { - ips[i], err = netip.ParseAddr(ipString) - if err != nil { - return nil, fmt.Errorf("IP address string is not valid: %w", err) - } - } - - return ips, nil -} - -func parseIPPrefixStrings(ipPrefixStrings []string) (ipPrefixes []netip.Prefix, err error) { - ipPrefixes = make([]netip.Prefix, len(ipPrefixStrings)) - - for i, ipPrefixString := range ipPrefixStrings { - ipPrefixes[i], err = netip.ParsePrefix(ipPrefixString) - if err != nil { - return nil, fmt.Errorf("IP prefix CIDR string is not valid: %w", err) - } - } - - return ipPrefixes, nil -} diff --git a/internal/config/sources/env/helpers_test.go b/internal/config/sources/env/helpers_test.go deleted file mode 100644 index 8f639875..00000000 --- a/internal/config/sources/env/helpers_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package env - -import ( - "os" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// setTestEnv is used to set environment variables in -// parallel tests. -func setTestEnv(t *testing.T, key, value string) { - t.Helper() - existing := os.Getenv(key) - err := os.Setenv(key, value) //nolint:tenv - t.Cleanup(func() { - err = os.Setenv(key, existing) - assert.NoError(t, err) - }) - require.NoError(t, err) -} - -func Test_envToDurationPtr(t *testing.T) { - t.Parallel() - - durationPtr := func(d time.Duration) *time.Duration { return &d } - - testCases := map[string]struct { - envKey string - envValue string - d *time.Duration - errorMessage string - }{ - "empty": { - envKey: "DURATION_EMPTY", - }, - "zero": { - envKey: "DURATION_ZERO", - envValue: "0", - d: durationPtr(0), - }, - "one second": { - envKey: "DURATION_ONE_SECOND", - envValue: "1s", - d: durationPtr(time.Second), - }, - "parse error": { - envKey: "DURATION_MALFORMED", - envValue: "x", - errorMessage: "time: invalid duration \"x\"", - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - setTestEnv(t, testCase.envKey, testCase.envValue) - - d, err := envToDurationPtr(testCase.envKey) - - assert.Equal(t, testCase.d, d) - if testCase.errorMessage != "" { - assert.EqualError(t, err, testCase.errorMessage) - } else { - assert.NoError(t, err) - } - }) - } -} diff --git a/internal/config/sources/env/log.go b/internal/config/sources/env/log.go index e90c075a..dc0fc5ee 100644 --- a/internal/config/sources/env/log.go +++ b/internal/config/sources/env/log.go @@ -2,14 +2,13 @@ package env import ( "fmt" - "os" "github.com/qdm12/dns/v2/internal/config/settings" "github.com/qdm12/log" ) -func readLog() (settings settings.Log, err error) { - settings.Level, err = readLogLevel() +func (r *Reader) readLog() (settings settings.Log, err error) { + settings.Level, err = r.readLogLevel() if err != nil { return settings, fmt.Errorf("environment variable LOG_LEVEL: %w", err) } @@ -17,8 +16,8 @@ func readLog() (settings settings.Log, err error) { return settings, nil } -func readLogLevel() (level *log.Level, err error) { - levelString := os.Getenv("LOG_LEVEL") +func (r *Reader) readLogLevel() (level *log.Level, err error) { + levelString := r.env.String("LOG_LEVEL") if levelString == "" { return nil, nil //nolint:nilnil } diff --git a/internal/config/sources/env/metrics.go b/internal/config/sources/env/metrics.go index d881f292..db3fa0b9 100644 --- a/internal/config/sources/env/metrics.go +++ b/internal/config/sources/env/metrics.go @@ -1,20 +1,11 @@ package env import ( - "fmt" - "os" - "strings" - "github.com/qdm12/dns/v2/internal/config/settings" ) -func readMetrics() (settings settings.Metrics, err error) { - settings.Type = strings.ToLower(os.Getenv("METRICS_TYPE")) - if err != nil { - return settings, fmt.Errorf("environment variable METRICS_TYPE: %w", err) - } - - settings.Prometheus = readPrometheus() - - return settings, nil +func (r *Reader) readMetrics() (settings settings.Metrics) { + settings.Type = r.env.String("METRICS_TYPE") + settings.Prometheus = r.readPrometheus() + return settings } diff --git a/internal/config/sources/env/middlewarelog.go b/internal/config/sources/env/middlewarelog.go index e70cd873..e79030e5 100644 --- a/internal/config/sources/env/middlewarelog.go +++ b/internal/config/sources/env/middlewarelog.go @@ -2,25 +2,24 @@ package env import ( "fmt" - "os" "github.com/qdm12/dns/v2/internal/config/settings" ) -func readMiddlewareLog() (settings settings.MiddlewareLog, err error) { - settings.Enabled, err = envToBoolPtr("MIDDLEWARE_LOG_ENABLED") +func (r *Reader) readMiddlewareLog() (settings settings.MiddlewareLog, err error) { + settings.Enabled, err = r.env.BoolPtr("MIDDLEWARE_LOG_ENABLED") if err != nil { return settings, fmt.Errorf("environment variable MIDDLEWARE_LOG_ENABLED: %w", err) } - settings.DirPath = os.Getenv("MIDDLEWARE_LOG_DIRECTORY") + settings.DirPath = r.env.String("MIDDLEWARE_LOG_DIRECTORY") - settings.LogRequests, err = envToBoolPtr("MIDDLEWARE_LOG_REQUESTS") + settings.LogRequests, err = r.env.BoolPtr("MIDDLEWARE_LOG_REQUESTS") if err != nil { return settings, fmt.Errorf("environment variable MIDDLEWARE_LOG_REQUESTS: %w", err) } - settings.LogResponses, err = envToBoolPtr("MIDDLEWARE_LOG_RESPONSES") + settings.LogResponses, err = r.env.BoolPtr("MIDDLEWARE_LOG_RESPONSES") if err != nil { return settings, fmt.Errorf("environment variable MIDDLEWARE_LOG_RESPONSES: %w", err) } diff --git a/internal/config/sources/env/prometheus.go b/internal/config/sources/env/prometheus.go index 7bd27cea..6dae5edc 100644 --- a/internal/config/sources/env/prometheus.go +++ b/internal/config/sources/env/prometheus.go @@ -1,13 +1,11 @@ package env import ( - "os" - "github.com/qdm12/dns/v2/internal/config/settings" ) -func readPrometheus() (settings settings.Prometheus) { - settings.ListeningAddress = os.Getenv("METRICS_PROMETHEUS_ADDRESS") - settings.Subsystem = envToStringPtr("METRICS_PROMETHEUS_SUBSYSTEM") +func (r *Reader) readPrometheus() (settings settings.Prometheus) { + settings.ListeningAddress = r.env.String("METRICS_PROMETHEUS_ADDRESS") + settings.Subsystem = r.env.Get("METRICS_PROMETHEUS_SUBSYSTEM") return settings } diff --git a/internal/config/sources/env/reader.go b/internal/config/sources/env/reader.go index bd753823..f89691d4 100644 --- a/internal/config/sources/env/reader.go +++ b/internal/config/sources/env/reader.go @@ -3,12 +3,13 @@ package env import ( "fmt" "os" - "strings" "github.com/qdm12/dns/v2/internal/config/settings" + "github.com/qdm12/gosettings/sources/env" ) type Reader struct { + env env.Env warner Warner } @@ -19,59 +20,57 @@ type Warner interface { func New(warner Warner) *Reader { return &Reader{ warner: warner, + env: *env.New(os.Environ(), nil), } } -func (r *Reader) Read() (settings settings.Settings, err error) { //nolint:cyclop +func (r *Reader) Read() (settings settings.Settings, err error) { warnings := checkOutdatedVariables() for _, warning := range warnings { r.warner.Warn(warning) } - settings.Upstream = strings.ToLower(os.Getenv("UPSTREAM_TYPE")) - settings.ListeningAddress = os.Getenv("LISTENING_ADDRESS") + settings.Upstream = r.env.String("UPSTREAM_TYPE") + settings.ListeningAddress = r.env.String("LISTENING_ADDRESS") - settings.Block, err = readBlock() + settings.Block, err = r.readBlock() if err != nil { return settings, fmt.Errorf("block settings: %w", err) } - settings.Cache, err = readCache() + settings.Cache, err = r.readCache() if err != nil { return settings, fmt.Errorf("cache settings: %w", err) } - settings.DoH, err = readDoH() + settings.DoH, err = r.readDoH() if err != nil { return settings, fmt.Errorf("DoH settings: %w", err) } - settings.DoT, err = readDoT() + settings.DoT, err = r.readDoT() if err != nil { return settings, fmt.Errorf("DoT settings: %w", err) } - settings.Log, err = readLog() + settings.Log, err = r.readLog() if err != nil { return settings, fmt.Errorf("log settings: %w", err) } - settings.MiddlewareLog, err = readMiddlewareLog() + settings.MiddlewareLog, err = r.readMiddlewareLog() if err != nil { return settings, fmt.Errorf("middleware log settings: %w", err) } - settings.Metrics, err = readMetrics() - if err != nil { - return settings, fmt.Errorf("metrics settings: %w", err) - } + settings.Metrics = r.readMetrics() - settings.CheckDNS, err = envToBoolPtr("CHECK_DNS") + settings.CheckDNS, err = r.env.BoolPtr("CHECK_DNS") if err != nil { return settings, fmt.Errorf("environment variable CHECK_DNS: %w", err) } - settings.UpdatePeriod, err = envToDurationPtr("UPDATE_PERIOD") + settings.UpdatePeriod, err = r.env.DurationPtr("UPDATE_PERIOD") if err != nil { return settings, fmt.Errorf("environment variable UPDATE_PERIOD: %w", err) }