diff --git a/command.go b/command.go index 8ce5c3c88..b4857b423 100644 --- a/command.go +++ b/command.go @@ -161,6 +161,9 @@ type Command struct { // versionTemplate is the version template defined by user. versionTemplate string + // errorHandlerFunc allows setting a custom error handler by the user. + errorHandlerFunc func(error) + // inReader is a reader defined by the user that replaces stdin inReader io.Reader // outWriter is a writer defined by the user that replaces stdout @@ -323,6 +326,12 @@ func (c *Command) SetGlobalNormalizationFunc(n func(f *flag.FlagSet, name string } } +// SetErrorHandlerFunc is the function that will be called, if set, when there is any kind of error in the +// execution of the command. +func (c *Command) SetErrorHandlerFunc(f func(error)) { + c.errorHandlerFunc = f +} + // OutOrStdout returns output to stdout. func (c *Command) OutOrStdout() io.Writer { return c.getOut(os.Stdout) @@ -979,7 +988,11 @@ func (c *Command) ExecuteC() (cmd *Command, err error) { c = cmd } if !c.SilenceErrors { - c.PrintErrln("Error:", err.Error()) + if c.errorHandlerFunc != nil { + c.errorHandlerFunc(err) + } else { + c.PrintErrln("Error:", err.Error()) + } c.PrintErrf("Run '%v --help' for usage.\n", c.CommandPath()) } return c, err @@ -1008,7 +1021,11 @@ func (c *Command) ExecuteC() (cmd *Command, err error) { // If root command has SilenceErrors flagged, // all subcommands should respect it if !cmd.SilenceErrors && !c.SilenceErrors { - c.PrintErrln("Error:", err.Error()) + if c.errorHandlerFunc != nil { + c.errorHandlerFunc(err) + } else { + c.PrintErrln("Error:", err.Error()) + } } // If root command has SilenceUsage flagged, diff --git a/command_test.go b/command_test.go index b3dd03040..69880a7f3 100644 --- a/command_test.go +++ b/command_test.go @@ -17,6 +17,7 @@ package cobra import ( "bytes" "context" + "errors" "fmt" "io/ioutil" "os" @@ -2430,3 +2431,23 @@ func TestHelpflagCommandExecutedWithoutVersionSet(t *testing.T) { checkStringContains(t, output, HelpFlag) checkStringOmits(t, output, VersionFlag) } + +func TestSetCustomErrorHandler(t *testing.T) { + var writer bytes.Buffer + handler := func(err error) { + writer.Write([]byte(err.Error())) + } + + root := &Command{ + Use: "root", + RunE: func(cmd *Command, args []string) error { + return errors.New("test error handler function") + }, + SilenceUsage: true, + } + root.SetErrorHandlerFunc(handler) + _ = root.Execute() + if writer.String() != "test error handler function" { + t.Errorf("Expected error handler to contain [%s] instead it contains [%s]", "test error handler function", writer.String()) + } +}