From 955c56c7657cd740b4268cb06c001dbfcde7c9e8 Mon Sep 17 00:00:00 2001 From: poy Date: Tue, 13 Oct 2020 22:34:28 -0600 Subject: [PATCH] Adds Persistent{Pre,Post}Run hook chaining PersistentPreRun and PersistentPostRun are chained together so that each child PersistentPreRun is ran, and the PersistentPostRun are ran in reverse order. For example: Commands: root -> subcommand-a -> subcommand-b root - PersistentPreRun subcommand-a - PersistentPreRun subcommand-b - PersistentPreRun subcommand-b - Run subcommand-b - PersistentPostRun subcommand-a - PersistentPostRun fixes #252 --- command.go | 150 +++++++++++++++++++++++++++++++++++------------ command_test.go | 152 ++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 253 insertions(+), 49 deletions(-) diff --git a/command.go b/command.go index 77b399e02e..afdc167dc7 100644 --- a/command.go +++ b/command.go @@ -105,6 +105,19 @@ type Command struct { // * PersistentPostRun() // All functions get the same args, the arguments after the command name. // + // PersistentPreRun and PersistentPostRun are chained together so that + // each child PersistentPreRun is ran, and the PersistentPostRun are ran + // in reverse order. For example: + // + // Commands: root -> subcommand-a -> subcommand-b + // + // root - PersistentPreRun + // subcommand-a - PersistentPreRun + // subcommand-b - PersistentPreRun + // subcommand-b - Run + // subcommand-b - PersistentPostRun + // subcommand-a - PersistentPostRun + // // PersistentPreRun: children of this command will inherit and execute. PersistentPreRun func(cmd *Command, args []string) // PersistentPreRunE: PersistentPreRun but returns an error. @@ -824,55 +837,118 @@ func (c *Command) execute(a []string) (err error) { return err } - for p := c; p != nil; p = p.Parent() { - if p.PersistentPreRunE != nil { - if err := p.PersistentPreRunE(c, argWoFlags); err != nil { - return err - } - break - } else if p.PersistentPreRun != nil { - p.PersistentPreRun(c, argWoFlags) - break - } + if _, err := c.runTree(c, argWoFlags); err != nil { + return err } - if c.PreRunE != nil { - if err := c.PreRunE(c, argWoFlags); err != nil { - return err + + return nil +} + +func (c *Command) runTree( + cmd *Command, + args []string, +) ( + persistentPostRunEs []func(cmd *Command, args []string) error, + err error, +) { + if c == nil { + return nil, nil + } + + // Traverse command tree and save the PersistentPostRun{,E} functions. + persistentPostRunEs, err = c.Parent().runTree(cmd, args) + if err != nil { + return nil, err + } + + // PersistentPreRun/PersistentPreRunE + switch { + case c.PersistentPreRun != nil: + c.PersistentPreRun(cmd, args) + case c.PersistentPreRunE != nil: + if err := c.PersistentPreRunE(cmd, args); err != nil { + return nil, err } - } else if c.PreRun != nil { - c.PreRun(c, argWoFlags) + default: + // Doesn't have a registered PersistentPreRun{,E}. Move on... + } + + // PersistentPostRun/PersistentPostRunE + switch { + case c.PersistentPostRun != nil: + persistentPostRunEs = append( + persistentPostRunEs, + func(cmd *Command, args []string) error { + c.PersistentPostRun(cmd, args) + return nil + }, + ) + case c.PersistentPostRunE != nil: + persistentPostRunEs = append( + persistentPostRunEs, + c.PersistentPostRunE, + ) + default: + // Doesn't have a registered PersistentPostRun{,E}. Move on... + } + + if c != cmd { + // Don't run a parent command. + return persistentPostRunEs, nil + } + + // PreRun/PreRunE + switch { + case c.PreRun != nil: + c.PreRun(cmd, args) + case c.PreRunE != nil: + if err := c.PreRunE(cmd, args); err != nil { + return nil, err + } + default: + // Doesn't have a registered PreRun{,E}. Move on... } if err := c.validateRequiredFlags(); err != nil { - return err + return nil, err } - if c.RunE != nil { - if err := c.RunE(c, argWoFlags); err != nil { - return err + + // Run/RunE + switch { + case c.RunE != nil: + if err := c.RunE(cmd, args); err != nil { + return nil, err } - } else { - c.Run(c, argWoFlags) - } - if c.PostRunE != nil { - if err := c.PostRunE(c, argWoFlags); err != nil { - return err + case c.Run != nil: + c.Run(cmd, args) + default: + // Both RunE and Run are nil... + panic(fmt.Sprintf("command %q does not have a non-nil RunE or Run function", c.Use)) + } + + // PostRun/PostRunE + switch { + case c.PostRun != nil: + c.PostRun(cmd, args) + case c.PostRunE != nil: + if err := c.PostRunE(cmd, args); err != nil { + return nil, err } - } else if c.PostRun != nil { - c.PostRun(c, argWoFlags) + default: + // Doesn't have a registered PostRun{,E}. Move on... } - for p := c; p != nil; p = p.Parent() { - if p.PersistentPostRunE != nil { - if err := p.PersistentPostRunE(c, argWoFlags); err != nil { - return err - } - break - } else if p.PersistentPostRun != nil { - p.PersistentPostRun(c, argWoFlags) - break + + // PersistentPostRun/PersistentPostRunE + // Iterate through the list in reverse order. Similar to a defer, allow + // the topmost commands to cleanup first. + for i := range persistentPostRunEs { + r := persistentPostRunEs[len(persistentPostRunEs)-1-i] + if err := r(cmd, args); err != nil { + return nil, err } } - return nil + return nil, nil } func (c *Command) preRun() { diff --git a/command_test.go b/command_test.go index 3a47a81b35..6c74199fc3 100644 --- a/command_test.go +++ b/command_test.go @@ -3,6 +3,7 @@ package cobra import ( "bytes" "context" + "errors" "fmt" "io/ioutil" "os" @@ -1380,12 +1381,8 @@ func TestPersistentHooks(t *testing.T) { t.Errorf("Unexpected error: %v", err) } - // TODO: currently PersistenPreRun* defined in parent does not - // run if the matchin child subcommand has PersistenPreRun. - // If the behavior changes (https://github.com/spf13/cobra/issues/252) - // this test must be fixed. - if parentPersPreArgs != "" { - t.Errorf("Expected blank parentPersPreArgs, got %q", parentPersPreArgs) + if parentPersPreArgs != "one two" { + t.Errorf("Expected parentPersPreArgs %q, got %q", "one two", parentPersPreArgs) } if parentPreArgs != "" { t.Errorf("Expected blank parentPreArgs, got %q", parentPreArgs) @@ -1396,12 +1393,8 @@ func TestPersistentHooks(t *testing.T) { if parentPostArgs != "" { t.Errorf("Expected blank parentPostArgs, got %q", parentPostArgs) } - // TODO: currently PersistenPostRun* defined in parent does not - // run if the matchin child subcommand has PersistenPostRun. - // If the behavior changes (https://github.com/spf13/cobra/issues/252) - // this test must be fixed. - if parentPersPostArgs != "" { - t.Errorf("Expected blank parentPersPostArgs, got %q", parentPersPostArgs) + if parentPersPostArgs != "one two" { + t.Errorf("Expected parentPersPostArgs %q, got %q", "one two", parentPersPostArgs) } if childPersPreArgs != "one two" { @@ -1421,6 +1414,141 @@ func TestPersistentHooks(t *testing.T) { } } +func TestPersistentHooks_persistentPostRun_ordering(t *testing.T) { + var uses []string + nopRun := func(*Command, []string) {} + printRun := func(name string) func(*Command, []string) { + return func(cmd *Command, args []string) { + uses = append(uses, name) + } + } + + rootCmd := &Command{Use: "root", Run: nopRun, PersistentPostRun: printRun("root")} + childCmd := &Command{Use: "child", Run: nopRun, PersistentPostRun: printRun("child")} + granchildCmd := &Command{Use: "grandchild", Run: nopRun, PersistentPostRun: printRun("grandchild")} + + childCmd.AddCommand(granchildCmd) + rootCmd.AddCommand(childCmd) + executeCommand(rootCmd, "child", "grandchild") + + if !reflect.DeepEqual(uses, []string{"grandchild", "child", "root"}) { + t.Fatalf("incorrect ordering: %v", uses) + } +} + +func TestPersistentHooks_errs(t *testing.T) { + nopRun := func(*Command, []string) {} + + testCases := []struct { + name string + setup func() *Command + args []string + expectedErr error + }{ + { + name: "PersistentPreRunE", + expectedErr: errors.New("some-error"), + args: []string{"child"}, + setup: func() *Command { + parentCmd := &Command{ + Use: "parent", + PersistentPreRunE: func(_ *Command, args []string) error { + return errors.New("some-error") + }, + Run: nopRun, + } + childCmd := &Command{ + Use: "child", + Run: nopRun, + PersistentPreRunE: func(_ *Command, args []string) error { + t.Fatal("should not be invoked") + return nil + }, + } + parentCmd.AddCommand(childCmd) + + return parentCmd + }, + }, + { + name: "PersistentPostRunE", + expectedErr: errors.New("some-error"), + args: []string{"child"}, + setup: func() *Command { + parentCmd := &Command{ + Use: "parent", + PersistentPostRunE: func(_ *Command, args []string) error { + t.Fatal("should not be invoked") + return nil + }, + Run: nopRun, + } + childCmd := &Command{ + Use: "child", + Run: nopRun, + PersistentPostRunE: func(_ *Command, args []string) error { + return errors.New("some-error") + }, + } + parentCmd.AddCommand(childCmd) + + return parentCmd + }, + }, + { + name: "PreRunE", + expectedErr: errors.New("some-error"), + args: []string{"parent"}, + setup: func() *Command { + return &Command{ + Use: "parent", + PreRunE: func(_ *Command, args []string) error { + return errors.New("some-error") + }, + Run: nopRun, + } + }, + }, + { + name: "RunE", + expectedErr: errors.New("some-error"), + args: []string{"parent"}, + setup: func() *Command { + return &Command{ + Use: "parent", + RunE: func(_ *Command, args []string) error { + return errors.New("some-error") + }, + } + }, + }, + { + name: "PostRunE", + expectedErr: errors.New("some-error"), + args: []string{"parent"}, + setup: func() *Command { + return &Command{ + Use: "parent", + PostRunE: func(_ *Command, args []string) error { + return errors.New("some-error") + }, + Run: nopRun, + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := executeCommand(tc.setup(), tc.args...) + + if actual, expected := fmt.Sprint(err), fmt.Sprint(tc.expectedErr); expected != actual { + t.Fatalf("expected err %v, got %v", expected, actual) + } + }) + } +} + // Related to https://github.com/spf13/cobra/issues/521. func TestGlobalNormFuncPropagation(t *testing.T) { normFunc := func(f *pflag.FlagSet, name string) pflag.NormalizedName {