From 9c9bba0794e4a92f77f912359a40bc6fc3a25439 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Tue, 8 Aug 2023 17:16:08 -0500 Subject: [PATCH] Only return error, add helpers to avoid boilerplate --- cmd/nebula-service/main.go | 10 ++------- cmd/nebula/main.go | 10 ++------- main.go | 27 +++++++++--------------- pki.go | 8 ++++---- util/error.go | 20 ++++++++++++++++++ util/error_test.go | 42 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 80 insertions(+), 37 deletions(-) diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index 5616cd4b9..8d0eaa1db 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -60,14 +60,8 @@ func main() { ctrl, err := nebula.Main(c, *configTest, Build, l, nil) if err != nil { - switch v := err.(type) { - case *util.ContextualError: - v.Log(l) - os.Exit(1) - case error: - l.WithError(err).Error("Failed to start") - os.Exit(1) - } + util.LogWithContextIfNeeded("Failed to start", err, l) + os.Exit(1) } if !*configTest { diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index d59ccd347..5cf0a028a 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -54,14 +54,8 @@ func main() { ctrl, err := nebula.Main(c, *configTest, Build, l, nil) if err != nil { - switch v := err.(type) { - case *util.ContextualError: - v.Log(l) - os.Exit(1) - case error: - l.WithError(err).Error("Failed to start") - os.Exit(1) - } + util.LogWithContextIfNeeded("Failed to start", err, l) + os.Exit(1) } if !*configTest { diff --git a/main.go b/main.go index 22a5edab8..4e8448b84 100644 --- a/main.go +++ b/main.go @@ -45,7 +45,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg err := configLogger(l, c) if err != nil { - return nil, util.NewContextualError("Failed to configure the logger", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err) } c.RegisterReloadCallback(func(c *config.C) { @@ -57,14 +57,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg pki, err := NewPKIFromConfig(l, c) if err != nil { - //The errors coming out of NewPKIFromConfig are already nicely formatted - return nil, err + return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) } certificate := pki.GetCertState().Certificate fw, err := NewFirewallFromConfig(l, certificate, c) if err != nil { - return nil, util.NewContextualError("Error while loading firewall rules", nil, err) + return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) } l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started") @@ -77,7 +76,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if c.GetBool("sshd.enabled", false) { sshStart, err = configSSH(l, ssh, c) if err != nil { - return nil, util.NewContextualError("Error while configuring the sshd", nil, err) + return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err) } } @@ -128,7 +127,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg tun, err = overlay.NewDeviceFromConfig(c, l, tunCidr, tunFd, routines) if err != nil { - return nil, util.NewContextualError("Failed to get a tun/tap device", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err) } defer func() { @@ -152,7 +151,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } else { listenHost, err = net.ResolveIPAddr("ip", rawListenHost) if err != nil { - return nil, util.NewContextualError("Failed to resolve listen.host", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err) } } @@ -174,7 +173,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg for _, rawPreferredRange := range rawPreferredRanges { _, preferredRange, err := net.ParseCIDR(rawPreferredRange) if err != nil { - return nil, util.NewContextualError("Failed to parse preferred ranges", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to parse preferred ranges", err) } preferredRanges = append(preferredRanges, preferredRange) } @@ -187,7 +186,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if rawLocalRange != "" { _, localRange, err := net.ParseCIDR(rawLocalRange) if err != nil { - return nil, util.NewContextualError("Failed to parse local_range", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to parse local_range", err) } // Check if the entry for local_range was already specified in @@ -215,12 +214,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg punchy := NewPunchyFromConfig(l, c) lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy) if err != nil { - switch v := err.(type) { - case *util.ContextualError: - return nil, err - case error: - return nil, util.NewContextualError("Failed to initialize lighthouse handler", nil, v) - } + return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err) } var messageMetrics *MessageMetrics @@ -314,9 +308,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg // TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept // a context so that they can exit when the context is Done. statsStart, err := startStats(l, c, buildVersion, configTest) - if err != nil { - return nil, util.NewContextualError("Failed to start stats emitter", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err) } if configTest { diff --git a/pki.go b/pki.go index 81530de06..91478ce51 100644 --- a/pki.go +++ b/pki.go @@ -36,9 +36,9 @@ func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { } c.RegisterReloadCallback(func(c *config.C) { - cErr := pki.reload(c, false) - if cErr != nil { - cErr.Log(l) + rErr := pki.reload(c, false) + if rErr != nil { + util.LogWithContextIfNeeded("Failed to reload PKI from config", rErr, l) } }) @@ -53,7 +53,7 @@ func (p *PKI) GetCAPool() *cert.NebulaCAPool { return p.caPool.Load() } -func (p *PKI) reload(c *config.C, initial bool) *util.ContextualError { +func (p *PKI) reload(c *config.C, initial bool) error { err := p.reloadCert(c, initial) if err != nil { if initial { diff --git a/util/error.go b/util/error.go index 53322d02b..a11c9c471 100644 --- a/util/error.go +++ b/util/error.go @@ -16,6 +16,26 @@ func NewContextualError(msg string, fields map[string]interface{}, realError err return &ContextualError{Context: msg, Fields: fields, RealError: realError} } +// ContextualizeIfNeeded is a helper function to turn an error into a ContextualError if it is not already one +func ContextualizeIfNeeded(msg string, err error) error { + switch err.(type) { + case *ContextualError: + return err + default: + return NewContextualError(msg, nil, err) + } +} + +// LogWithContextIfNeeded is a helper function to log an error line for an error or ContextualError +func LogWithContextIfNeeded(msg string, err error, l *logrus.Logger) { + switch v := err.(type) { + case *ContextualError: + v.Log(l) + default: + l.WithError(err).Error(msg) + } +} + func (ce *ContextualError) Error() string { if ce.RealError == nil { return ce.Context diff --git a/util/error_test.go b/util/error_test.go index 747d04e0c..5041f82ce 100644 --- a/util/error_test.go +++ b/util/error_test.go @@ -2,6 +2,7 @@ package util import ( "errors" + "fmt" "testing" "github.com/sirupsen/logrus" @@ -67,3 +68,44 @@ func TestContextualError_Log(t *testing.T) { e.Log(l) assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs) } + +func TestLogWithContextIfNeeded(t *testing.T) { + l := logrus.New() + l.Formatter = &logrus.TextFormatter{ + DisableTimestamp: true, + DisableColors: true, + } + + tl := NewTestLogWriter() + l.Out = tl + + // Test ignoring fallback context + tl.Reset() + e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) + LogWithContextIfNeeded("This should get thrown away", e, l) + assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs) + + // Test using fallback context + tl.Reset() + err := fmt.Errorf("this is a normal error") + LogWithContextIfNeeded("Fallback context woo", err, l) + assert.Equal(t, []string{"level=error msg=\"Fallback context woo\" error=\"this is a normal error\"\n"}, tl.Logs) +} + +func TestContextualizeIfNeeded(t *testing.T) { + // Test ignoring fallback context + e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) + assert.Same(t, e, ContextualizeIfNeeded("should be ignored", e)) + + // Test using fallback context + err := fmt.Errorf("this is a normal error") + cErr := ContextualizeIfNeeded("Fallback context woo", err) + + switch v := cErr.(type) { + case *ContextualError: + assert.Equal(t, err, v.RealError) + default: + t.Error("Error was not wrapped") + t.Fail() + } +}