Skip to content

Commit

Permalink
Only return error, add helpers to avoid boilerplate
Browse files Browse the repository at this point in the history
  • Loading branch information
nbrownus committed Aug 8, 2023
1 parent 7293838 commit 9c9bba0
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 37 deletions.
10 changes: 2 additions & 8 deletions cmd/nebula-service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,8 @@ func main() {

ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
if err != nil {
switch v := err.(type) {
case *util.ContextualError:
v.Log(l)
os.Exit(1)
case error:
l.WithError(err).Error("Failed to start")
os.Exit(1)
}
util.LogWithContextIfNeeded("Failed to start", err, l)
os.Exit(1)
}

if !*configTest {
Expand Down
10 changes: 2 additions & 8 deletions cmd/nebula/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,8 @@ func main() {

ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
if err != nil {
switch v := err.(type) {
case *util.ContextualError:
v.Log(l)
os.Exit(1)
case error:
l.WithError(err).Error("Failed to start")
os.Exit(1)
}
util.LogWithContextIfNeeded("Failed to start", err, l)
os.Exit(1)
}

if !*configTest {
Expand Down
27 changes: 10 additions & 17 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg

err := configLogger(l, c)
if err != nil {
return nil, util.NewContextualError("Failed to configure the logger", nil, err)
return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err)
}

c.RegisterReloadCallback(func(c *config.C) {
Expand All @@ -57,14 +57,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg

pki, err := NewPKIFromConfig(l, c)
if err != nil {
//The errors coming out of NewPKIFromConfig are already nicely formatted
return nil, err
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
}

certificate := pki.GetCertState().Certificate
fw, err := NewFirewallFromConfig(l, certificate, c)
if err != nil {
return nil, util.NewContextualError("Error while loading firewall rules", nil, err)
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
}
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")

Expand All @@ -77,7 +76,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if c.GetBool("sshd.enabled", false) {
sshStart, err = configSSH(l, ssh, c)
if err != nil {
return nil, util.NewContextualError("Error while configuring the sshd", nil, err)
return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
}
}

Expand Down Expand Up @@ -128,7 +127,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg

tun, err = overlay.NewDeviceFromConfig(c, l, tunCidr, tunFd, routines)
if err != nil {
return nil, util.NewContextualError("Failed to get a tun/tap device", nil, err)
return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
}

defer func() {
Expand All @@ -152,7 +151,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} else {
listenHost, err = net.ResolveIPAddr("ip", rawListenHost)
if err != nil {
return nil, util.NewContextualError("Failed to resolve listen.host", nil, err)
return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err)
}
}

Expand All @@ -174,7 +173,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
for _, rawPreferredRange := range rawPreferredRanges {
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
if err != nil {
return nil, util.NewContextualError("Failed to parse preferred ranges", nil, err)
return nil, util.ContextualizeIfNeeded("Failed to parse preferred ranges", err)
}
preferredRanges = append(preferredRanges, preferredRange)
}
Expand All @@ -187,7 +186,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if rawLocalRange != "" {
_, localRange, err := net.ParseCIDR(rawLocalRange)
if err != nil {
return nil, util.NewContextualError("Failed to parse local_range", nil, err)
return nil, util.ContextualizeIfNeeded("Failed to parse local_range", err)
}

// Check if the entry for local_range was already specified in
Expand Down Expand Up @@ -215,12 +214,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
punchy := NewPunchyFromConfig(l, c)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
if err != nil {
switch v := err.(type) {
case *util.ContextualError:
return nil, err
case error:
return nil, util.NewContextualError("Failed to initialize lighthouse handler", nil, v)
}
return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
}

var messageMetrics *MessageMetrics
Expand Down Expand Up @@ -314,9 +308,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
// a context so that they can exit when the context is Done.
statsStart, err := startStats(l, c, buildVersion, configTest)

if err != nil {
return nil, util.NewContextualError("Failed to start stats emitter", nil, err)
return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
}

if configTest {
Expand Down
8 changes: 4 additions & 4 deletions pki.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
}

c.RegisterReloadCallback(func(c *config.C) {
cErr := pki.reload(c, false)
if cErr != nil {
cErr.Log(l)
rErr := pki.reload(c, false)
if rErr != nil {
util.LogWithContextIfNeeded("Failed to reload PKI from config", rErr, l)
}
})

Expand All @@ -53,7 +53,7 @@ func (p *PKI) GetCAPool() *cert.NebulaCAPool {
return p.caPool.Load()
}

func (p *PKI) reload(c *config.C, initial bool) *util.ContextualError {
func (p *PKI) reload(c *config.C, initial bool) error {
err := p.reloadCert(c, initial)
if err != nil {
if initial {
Expand Down
20 changes: 20 additions & 0 deletions util/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,26 @@ func NewContextualError(msg string, fields map[string]interface{}, realError err
return &ContextualError{Context: msg, Fields: fields, RealError: realError}
}

// ContextualizeIfNeeded is a helper function to turn an error into a ContextualError if it is not already one
func ContextualizeIfNeeded(msg string, err error) error {
switch err.(type) {
case *ContextualError:
return err
default:
return NewContextualError(msg, nil, err)
}
}

// LogWithContextIfNeeded is a helper function to log an error line for an error or ContextualError
func LogWithContextIfNeeded(msg string, err error, l *logrus.Logger) {
switch v := err.(type) {
case *ContextualError:
v.Log(l)
default:
l.WithError(err).Error(msg)
}
}

func (ce *ContextualError) Error() string {
if ce.RealError == nil {
return ce.Context
Expand Down
42 changes: 42 additions & 0 deletions util/error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package util

import (
"errors"
"fmt"
"testing"

"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -67,3 +68,44 @@ func TestContextualError_Log(t *testing.T) {
e.Log(l)
assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs)
}

func TestLogWithContextIfNeeded(t *testing.T) {
l := logrus.New()
l.Formatter = &logrus.TextFormatter{
DisableTimestamp: true,
DisableColors: true,
}

tl := NewTestLogWriter()
l.Out = tl

// Test ignoring fallback context
tl.Reset()
e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
LogWithContextIfNeeded("This should get thrown away", e, l)
assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs)

// Test using fallback context
tl.Reset()
err := fmt.Errorf("this is a normal error")
LogWithContextIfNeeded("Fallback context woo", err, l)
assert.Equal(t, []string{"level=error msg=\"Fallback context woo\" error=\"this is a normal error\"\n"}, tl.Logs)
}

func TestContextualizeIfNeeded(t *testing.T) {
// Test ignoring fallback context
e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
assert.Same(t, e, ContextualizeIfNeeded("should be ignored", e))

// Test using fallback context
err := fmt.Errorf("this is a normal error")
cErr := ContextualizeIfNeeded("Fallback context woo", err)

switch v := cErr.(type) {
case *ContextualError:
assert.Equal(t, err, v.RealError)
default:
t.Error("Error was not wrapped")
t.Fail()
}
}

0 comments on commit 9c9bba0

Please sign in to comment.