diff --git a/command.go b/command.go index 77b399e02e..8c47d93273 100644 --- a/command.go +++ b/command.go @@ -105,6 +105,19 @@ type Command struct { // * PersistentPostRun() // All functions get the same args, the arguments after the command name. // + // PersistentPreRun and PersistentPostRun are chained together so that + // each child PersistentPreRun is ran, and the PersistentPostRun are ran + // in reverse order. For example: + // + // Commands: root -> subcommand-a -> subcommand-b + // + // root - PersistentPreRun + // subcommand-a - PersistentPreRun + // subcommand-b - PersistentPreRun + // subcommand-b - Run + // subcommand-b - PersistentPostRun + // subcommand-a - PersistentPostRun + // // PersistentPreRun: children of this command will inherit and execute. PersistentPreRun func(cmd *Command, args []string) // PersistentPreRunE: PersistentPreRun but returns an error. @@ -824,55 +837,114 @@ func (c *Command) execute(a []string) (err error) { return err } - for p := c; p != nil; p = p.Parent() { - if p.PersistentPreRunE != nil { - if err := p.PersistentPreRunE(c, argWoFlags); err != nil { - return err - } - break - } else if p.PersistentPreRun != nil { - p.PersistentPreRun(c, argWoFlags) - break - } + if _, err := c.runTree(c, argWoFlags); err != nil { + return err } - if c.PreRunE != nil { - if err := c.PreRunE(c, argWoFlags); err != nil { - return err + + return nil +} + +func (c *Command) runTree( + cmd *Command, + args []string, +) ( + persistentPostRunEs []func(cmd *Command, args []string) error, + err error, +) { + if c == nil { + return nil, nil + } + + // Traverse command tree and save the PersistentPostRun{,E} functions. + persistentPostRunEs, err = c.Parent().runTree(cmd, args) + if err != nil { + return nil, err + } + + switch { + case c.PersistentPreRun != nil: + c.PersistentPreRun(cmd, args) + case c.PersistentPreRunE != nil: + if err := c.PersistentPreRunE(cmd, args); err != nil { + return nil, err } - } else if c.PreRun != nil { - c.PreRun(c, argWoFlags) + default: + // Doesn't have a registered PersistentPreRun{,E}. Move on... + } + + // PersistentPostRun/PersistentPostRunE + switch { + case c.PersistentPostRun != nil: + persistentPostRunEs = append( + persistentPostRunEs, + func(cmd *Command, args []string) error { + c.PersistentPostRun(cmd, args) + return nil + }, + ) + case c.PersistentPostRunE != nil: + persistentPostRunEs = append( + persistentPostRunEs, + c.PersistentPostRunE, + ) + default: + // Doesn't have a registered PersistentPostRun{,E}. Move on... + } + + if c != cmd { + // Don't run a parent command. + return persistentPostRunEs, nil + } + + // PreRun/PreRunE + switch { + case c.PreRun != nil: + c.PreRun(cmd, args) + case c.PreRunE != nil: + if err := c.PreRunE(cmd, args); err != nil { + return nil, err + } + default: + // Doesn't have a registered PreRun{,E}. Move on... } if err := c.validateRequiredFlags(); err != nil { - return err + return nil, err } - if c.RunE != nil { - if err := c.RunE(c, argWoFlags); err != nil { - return err + + // Run/RunE + switch { + case c.RunE != nil: + if err := c.RunE(cmd, args); err != nil { + return nil, err } - } else { - c.Run(c, argWoFlags) - } - if c.PostRunE != nil { - if err := c.PostRunE(c, argWoFlags); err != nil { - return err + case c.Run != nil: + c.Run(cmd, args) + default: + // Both RunE and Run are nil... + panic(fmt.Sprintf("command %q does not have a non-nil RunE or Run function", c.Use)) + } + + // PostRun/PostRunE + switch { + case c.PostRun != nil: + c.PostRun(cmd, args) + case c.PostRunE != nil: + if err := c.PostRunE(cmd, args); err != nil { + return nil, err } - } else if c.PostRun != nil { - c.PostRun(c, argWoFlags) + default: + // Doesn't have a registered PostRun{,E}. Move on... } - for p := c; p != nil; p = p.Parent() { - if p.PersistentPostRunE != nil { - if err := p.PersistentPostRunE(c, argWoFlags); err != nil { - return err - } - break - } else if p.PersistentPostRun != nil { - p.PersistentPostRun(c, argWoFlags) - break + + // PersistentPostRun/PersistentPostRunE + for _, r := range persistentPostRunEs { + if err := r(cmd, args); err != nil { + return nil, err } } - return nil + return nil, nil } func (c *Command) preRun() { diff --git a/command_test.go b/command_test.go index 3a47a81b35..b290b6b248 100644 --- a/command_test.go +++ b/command_test.go @@ -3,6 +3,7 @@ package cobra import ( "bytes" "context" + "errors" "fmt" "io/ioutil" "os" @@ -1380,12 +1381,8 @@ func TestPersistentHooks(t *testing.T) { t.Errorf("Unexpected error: %v", err) } - // TODO: currently PersistenPreRun* defined in parent does not - // run if the matchin child subcommand has PersistenPreRun. - // If the behavior changes (https://github.com/spf13/cobra/issues/252) - // this test must be fixed. - if parentPersPreArgs != "" { - t.Errorf("Expected blank parentPersPreArgs, got %q", parentPersPreArgs) + if parentPersPreArgs != "one two" { + t.Errorf("Expected parentPersPreArgs %q, got %q", "one two", parentPersPreArgs) } if parentPreArgs != "" { t.Errorf("Expected blank parentPreArgs, got %q", parentPreArgs) @@ -1396,12 +1393,8 @@ func TestPersistentHooks(t *testing.T) { if parentPostArgs != "" { t.Errorf("Expected blank parentPostArgs, got %q", parentPostArgs) } - // TODO: currently PersistenPostRun* defined in parent does not - // run if the matchin child subcommand has PersistenPostRun. - // If the behavior changes (https://github.com/spf13/cobra/issues/252) - // this test must be fixed. - if parentPersPostArgs != "" { - t.Errorf("Expected blank parentPersPostArgs, got %q", parentPersPostArgs) + if parentPersPostArgs != "one two" { + t.Errorf("Expected parentPersPostArgs %q, got %q", "one two", parentPersPostArgs) } if childPersPreArgs != "one two" { @@ -1421,6 +1414,138 @@ func TestPersistentHooks(t *testing.T) { } } +func Example_persistentPostRun_ordering(t *testing.T) { + nopRun := func(*Command, []string) {} + printRun := func(cmd *Command, args []string) { + println(cmd.Use) + } + + rootCmd := &Command{Use: "root", Run: nopRun, PreRun: printRun} + childCmd := &Command{Use: "child", Run: nopRun, PreRun: printRun} + granchildCmd := &Command{Use: "grandchild", Run: nopRun, PreRun: printRun} + + childCmd.AddCommand(granchildCmd) + rootCmd.AddCommand(childCmd) + executeCommand(rootCmd) + + // Output: grandchild + // child + // root +} + +func TestPersistentHooks_errs(t *testing.T) { + nopRun := func(*Command, []string) {} + + testCases := []struct { + name string + setup func() *Command + args []string + expectedErr error + }{ + { + name: "PersistentPreRunE", + expectedErr: errors.New("some-error"), + args: []string{"child"}, + setup: func() *Command { + parentCmd := &Command{ + Use: "parent", + PersistentPreRunE: func(_ *Command, args []string) error { + return errors.New("some-error") + }, + Run: nopRun, + } + childCmd := &Command{ + Use: "child", + Run: nopRun, + PersistentPreRunE: func(_ *Command, args []string) error { + t.Fatal("should not be invoked") + return nil + }, + } + parentCmd.AddCommand(childCmd) + + return parentCmd + }, + }, + { + name: "PersistentPostRunE", + expectedErr: errors.New("some-error"), + args: []string{"child"}, + setup: func() *Command { + parentCmd := &Command{ + Use: "parent", + PersistentPostRunE: func(_ *Command, args []string) error { + return errors.New("some-error") + }, + Run: nopRun, + } + childCmd := &Command{ + Use: "child", + Run: nopRun, + PersistentPostRunE: func(_ *Command, args []string) error { + t.Fatal("should not be invoked") + return nil + }, + } + parentCmd.AddCommand(childCmd) + + return parentCmd + }, + }, + { + name: "PreRunE", + expectedErr: errors.New("some-error"), + args: []string{"parent"}, + setup: func() *Command { + return &Command{ + Use: "parent", + PreRunE: func(_ *Command, args []string) error { + return errors.New("some-error") + }, + Run: nopRun, + } + }, + }, + { + name: "RunE", + expectedErr: errors.New("some-error"), + args: []string{"parent"}, + setup: func() *Command { + return &Command{ + Use: "parent", + RunE: func(_ *Command, args []string) error { + return errors.New("some-error") + }, + } + }, + }, + { + name: "PostRunE", + expectedErr: errors.New("some-error"), + args: []string{"parent"}, + setup: func() *Command { + return &Command{ + Use: "parent", + PostRunE: func(_ *Command, args []string) error { + return errors.New("some-error") + }, + Run: nopRun, + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := executeCommand(tc.setup(), tc.args...) + + if actual, expected := fmt.Sprint(err), fmt.Sprint(tc.expectedErr); expected != actual { + t.Fatalf("expected err %v, got %v", expected, actual) + } + }) + } +} + // Related to https://github.com/spf13/cobra/issues/521. func TestGlobalNormFuncPropagation(t *testing.T) { normFunc := func(f *pflag.FlagSet, name string) pflag.NormalizedName {