From 3f48611f4183b9d0ba6a740038dc0e14ff7268e5 Mon Sep 17 00:00:00 2001 From: poy Date: Thu, 4 Feb 2021 00:11:58 +0530 Subject: [PATCH] Adds Persistent{Pre,Post}Run hook chaining PersistentPreRun and PersistentPostRun are chained together so that each child PersistentPreRun is ran, and the PersistentPostRun are ran in reverse order. For example: Commands: root -> subcommand-a -> subcommand-b root - PersistentPreRun subcommand-a - PersistentPreRun subcommand-b - PersistentPreRun subcommand-b - Run subcommand-b - PersistentPostRun subcommand-a - PersistentPostRun root - PersistentPostRun fixes #252 --- command.go | 165 +++++++++++++++++++++++------- command_test.go | 261 ++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 382 insertions(+), 44 deletions(-) diff --git a/command.go b/command.go index 2b82a553d..4e2796292 100644 --- a/command.go +++ b/command.go @@ -105,6 +105,21 @@ type Command struct { // * PersistentPostRun() // All functions get the same args, the arguments after the command name. // + // When TraverseChildrenHooks is set, PersistentPreRun and + // PersistentPostRun are chained together so that each child + // PersistentPreRun is ran, and the PersistentPostRun are ran in reverse + // order. For example: + // + // Commands: root -> subcommand-a -> subcommand-b + // + // root - PersistentPreRun + // subcommand-a - PersistentPreRun + // subcommand-b - PersistentPreRun + // subcommand-b - Run + // subcommand-b - PersistentPostRun + // subcommand-a - PersistentPostRun + // root - PersistentPostRun + // // PersistentPreRun: children of this command will inherit and execute. PersistentPreRun func(cmd *Command, args []string) // PersistentPreRunE: PersistentPreRun but returns an error. @@ -154,6 +169,11 @@ type Command struct { // TraverseChildren parses flags on all parents before executing child command. TraverseChildren bool + // TraverseChildrenHooks will have each subcommand's PersistentPreRun and + // PersistentPostRun instead of overriding. It should be set on the root + // command. + TraverseChildrenHooks bool + // FParseErrWhitelist flag parse errors to be ignored FParseErrWhitelist FParseErrWhitelist @@ -824,55 +844,130 @@ func (c *Command) execute(a []string) (err error) { return err } - for p := c; p != nil; p = p.Parent() { - if p.PersistentPreRunE != nil { - if err := p.PersistentPreRunE(c, argWoFlags); err != nil { - return err + // Look to see if TraverseChildrenHooks is set on the root command. + if _, err := c.runTree(c, argWoFlags, c.traverseChildrenHooks()); err != nil { + return err + } + + return nil +} + +func (c *Command) traverseChildrenHooks() bool { + if c.HasParent() { + return c.Parent().traverseChildrenHooks() + } + + return c.TraverseChildrenHooks +} + +func (c *Command) runTree( + cmd *Command, + args []string, + traverseChildrenHooks bool, +) ( + persistentPostRunEs []func(cmd *Command, args []string) error, + err error, +) { + if c == nil { + return nil, nil + } + + // Traverse command tree and save the PersistentPostRun{,E} functions. + persistentPostRunEs, err = c.Parent().runTree(cmd, args, traverseChildrenHooks) + if err != nil { + return nil, err + } + + if traverseChildrenHooks || c == cmd { + // PersistentPreRun/PersistentPreRunE + switch { + case c.PersistentPreRun != nil: + c.PersistentPreRun(cmd, args) + case c.PersistentPreRunE != nil: + if err := c.PersistentPreRunE(cmd, args); err != nil { + return nil, err } - break - } else if p.PersistentPreRun != nil { - p.PersistentPreRun(c, argWoFlags) - break + default: + // Doesn't have a registered PersistentPreRun{,E}. Move on... + } + + // PersistentPostRun/PersistentPostRunE + switch { + case c.PersistentPostRun != nil: + persistentPostRunEs = append( + persistentPostRunEs, + func(cmd *Command, args []string) error { + c.PersistentPostRun(cmd, args) + return nil + }, + ) + case c.PersistentPostRunE != nil: + persistentPostRunEs = append( + persistentPostRunEs, + c.PersistentPostRunE, + ) + default: + // Doesn't have a registered PersistentPostRun{,E}. Move on... } } - if c.PreRunE != nil { - if err := c.PreRunE(c, argWoFlags); err != nil { - return err + + if c != cmd { + // Don't run a parent command. + return persistentPostRunEs, nil + } + + // PreRun/PreRunE + switch { + case c.PreRun != nil: + c.PreRun(cmd, args) + case c.PreRunE != nil: + if err := c.PreRunE(cmd, args); err != nil { + return nil, err } - } else if c.PreRun != nil { - c.PreRun(c, argWoFlags) + default: + // Doesn't have a registered PreRun{,E}. Move on... } if err := c.validateRequiredFlags(); err != nil { - return err + return nil, err } - if c.RunE != nil { - if err := c.RunE(c, argWoFlags); err != nil { - return err + + // Run/RunE + switch { + case c.RunE != nil: + if err := c.RunE(cmd, args); err != nil { + return nil, err } - } else { - c.Run(c, argWoFlags) - } - if c.PostRunE != nil { - if err := c.PostRunE(c, argWoFlags); err != nil { - return err + case c.Run != nil: + c.Run(cmd, args) + default: + // Both RunE and Run are nil... + panic(fmt.Sprintf("command %q does not have a non-nil RunE or Run function", c.Use)) + } + + // PostRun/PostRunE + switch { + case c.PostRun != nil: + c.PostRun(cmd, args) + case c.PostRunE != nil: + if err := c.PostRunE(cmd, args); err != nil { + return nil, err } - } else if c.PostRun != nil { - c.PostRun(c, argWoFlags) + default: + // Doesn't have a registered PostRun{,E}. Move on... } - for p := c; p != nil; p = p.Parent() { - if p.PersistentPostRunE != nil { - if err := p.PersistentPostRunE(c, argWoFlags); err != nil { - return err - } - break - } else if p.PersistentPostRun != nil { - p.PersistentPostRun(c, argWoFlags) - break + + // PersistentPostRun/PersistentPostRunE + // Iterate through the list in reverse order. Similar to a defer, allow + // the topmost commands to cleanup first. + for i := range persistentPostRunEs { + r := persistentPostRunEs[len(persistentPostRunEs)-1-i] + if err := r(cmd, args); err != nil { + return nil, err } } - return nil + return nil, nil } func (c *Command) preRun() { diff --git a/command_test.go b/command_test.go index 3a47a81b3..eb46961c1 100644 --- a/command_test.go +++ b/command_test.go @@ -3,6 +3,7 @@ package cobra import ( "bytes" "context" + "errors" "fmt" "io/ioutil" "os" @@ -1334,7 +1335,8 @@ func TestPersistentHooks(t *testing.T) { ) parentCmd := &Command{ - Use: "parent", + Use: "parent", + TraverseChildrenHooks: false, // Set explicitly to highlight setting. PersistentPreRun: func(_ *Command, args []string) { parentPersPreArgs = strings.Join(args, " ") }, @@ -1380,10 +1382,6 @@ func TestPersistentHooks(t *testing.T) { t.Errorf("Unexpected error: %v", err) } - // TODO: currently PersistenPreRun* defined in parent does not - // run if the matchin child subcommand has PersistenPreRun. - // If the behavior changes (https://github.com/spf13/cobra/issues/252) - // this test must be fixed. if parentPersPreArgs != "" { t.Errorf("Expected blank parentPersPreArgs, got %q", parentPersPreArgs) } @@ -1396,10 +1394,7 @@ func TestPersistentHooks(t *testing.T) { if parentPostArgs != "" { t.Errorf("Expected blank parentPostArgs, got %q", parentPostArgs) } - // TODO: currently PersistenPostRun* defined in parent does not - // run if the matchin child subcommand has PersistenPostRun. - // If the behavior changes (https://github.com/spf13/cobra/issues/252) - // this test must be fixed. + if parentPersPostArgs != "" { t.Errorf("Expected blank parentPersPostArgs, got %q", parentPersPostArgs) } @@ -1421,6 +1416,254 @@ func TestPersistentHooks(t *testing.T) { } } +func TestPersistentHooks_TraverseChildrenHooks(t *testing.T) { + var ( + parentPersPreArgs string + parentPreArgs string + parentRunArgs string + parentPostArgs string + parentPersPostArgs string + ) + + var ( + childPersPreArgs string + childPreArgs string + childRunArgs string + childPostArgs string + childPersPostArgs string + ) + + parentCmd := &Command{ + Use: "parent", + TraverseChildrenHooks: true, + PersistentPreRun: func(_ *Command, args []string) { + parentPersPreArgs = strings.Join(args, " ") + }, + PreRun: func(_ *Command, args []string) { + parentPreArgs = strings.Join(args, " ") + }, + Run: func(_ *Command, args []string) { + parentRunArgs = strings.Join(args, " ") + }, + PostRun: func(_ *Command, args []string) { + parentPostArgs = strings.Join(args, " ") + }, + PersistentPostRun: func(_ *Command, args []string) { + parentPersPostArgs = strings.Join(args, " ") + }, + } + + childCmd := &Command{ + Use: "child", + PersistentPreRun: func(_ *Command, args []string) { + childPersPreArgs = strings.Join(args, " ") + }, + PreRun: func(_ *Command, args []string) { + childPreArgs = strings.Join(args, " ") + }, + Run: func(_ *Command, args []string) { + childRunArgs = strings.Join(args, " ") + }, + PostRun: func(_ *Command, args []string) { + childPostArgs = strings.Join(args, " ") + }, + PersistentPostRun: func(_ *Command, args []string) { + childPersPostArgs = strings.Join(args, " ") + }, + } + parentCmd.AddCommand(childCmd) + + output, err := executeCommand(parentCmd, "child", "one", "two") + if output != "" { + t.Errorf("Unexpected output: %v", output) + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if parentPersPreArgs != "one two" { + t.Errorf("Expected parentPersPreArgs %q, got %q", "one two", parentPersPreArgs) + } + if parentPreArgs != "" { + t.Errorf("Expected blank parentPreArgs, got %q", parentPreArgs) + } + if parentRunArgs != "" { + t.Errorf("Expected blank parentRunArgs, got %q", parentRunArgs) + } + if parentPostArgs != "" { + t.Errorf("Expected blank parentPostArgs, got %q", parentPostArgs) + } + if parentPersPostArgs != "one two" { + t.Errorf("Expected parentPersPostArgs %q, got %q", "one two", parentPersPostArgs) + } + + if childPersPreArgs != "one two" { + t.Errorf("Expected childPersPreArgs %q, got %q", "one two", childPersPreArgs) + } + if childPreArgs != "one two" { + t.Errorf("Expected childPreArgs %q, got %q", "one two", childPreArgs) + } + if childRunArgs != "one two" { + t.Errorf("Expected childRunArgs %q, got %q", "one two", childRunArgs) + } + if childPostArgs != "one two" { + t.Errorf("Expected childPostArgs %q, got %q", "one two", childPostArgs) + } + if childPersPostArgs != "one two" { + t.Errorf("Expected childPersPostArgs %q, got %q", "one two", childPersPostArgs) + } +} + +func TestPersistentHooks_persistentPostRun_ordering(t *testing.T) { + var uses []string + nopRun := func(*Command, []string) {} + printRun := func(name string) func(*Command, []string) { + return func(cmd *Command, args []string) { + uses = append(uses, name) + } + } + + rootCmd := &Command{ + Use: "root", + TraverseChildrenHooks: true, + Run: nopRun, + PersistentPostRun: printRun("root"), + } + childCmd := &Command{ + Use: "child", + Run: nopRun, + PersistentPostRun: printRun("child"), + } + granchildCmd := &Command{ + Use: "grandchild", + Run: nopRun, + PersistentPostRun: printRun("grandchild"), + } + + childCmd.AddCommand(granchildCmd) + rootCmd.AddCommand(childCmd) + executeCommand(rootCmd, "child", "grandchild") + + if !reflect.DeepEqual(uses, []string{"grandchild", "child", "root"}) { + t.Fatalf("incorrect ordering: %v", uses) + } +} + +func TestPersistentHooks_errs(t *testing.T) { + nopRun := func(*Command, []string) {} + + testCases := []struct { + name string + setup func() *Command + args []string + expectedErr error + }{ + { + name: "PersistentPreRunE", + expectedErr: errors.New("some-error"), + args: []string{"child"}, + setup: func() *Command { + parentCmd := &Command{ + Use: "parent", + TraverseChildrenHooks: true, + PersistentPreRunE: func(_ *Command, args []string) error { + return errors.New("some-error") + }, + Run: nopRun, + } + childCmd := &Command{ + Use: "child", + Run: nopRun, + PersistentPreRunE: func(_ *Command, args []string) error { + t.Fatal("should not be invoked") + return nil + }, + } + parentCmd.AddCommand(childCmd) + + return parentCmd + }, + }, + { + name: "PersistentPostRunE", + expectedErr: errors.New("some-error"), + args: []string{"child"}, + setup: func() *Command { + parentCmd := &Command{ + Use: "parent", + TraverseChildrenHooks: true, + PersistentPostRunE: func(_ *Command, args []string) error { + t.Fatal("should not be invoked") + return nil + }, + Run: nopRun, + } + childCmd := &Command{ + Use: "child", + Run: nopRun, + PersistentPostRunE: func(_ *Command, args []string) error { + return errors.New("some-error") + }, + } + parentCmd.AddCommand(childCmd) + + return parentCmd + }, + }, + { + name: "PreRunE", + expectedErr: errors.New("some-error"), + args: []string{"parent"}, + setup: func() *Command { + return &Command{ + Use: "parent", + PreRunE: func(_ *Command, args []string) error { + return errors.New("some-error") + }, + Run: nopRun, + } + }, + }, + { + name: "RunE", + expectedErr: errors.New("some-error"), + args: []string{"parent"}, + setup: func() *Command { + return &Command{ + Use: "parent", + RunE: func(_ *Command, args []string) error { + return errors.New("some-error") + }, + } + }, + }, + { + name: "PostRunE", + expectedErr: errors.New("some-error"), + args: []string{"parent"}, + setup: func() *Command { + return &Command{ + Use: "parent", + PostRunE: func(_ *Command, args []string) error { + return errors.New("some-error") + }, + Run: nopRun, + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := executeCommand(tc.setup(), tc.args...) + + if actual, expected := fmt.Sprint(err), fmt.Sprint(tc.expectedErr); expected != actual { + t.Fatalf("expected err %v, got %v", expected, actual) + } + }) + } +} + // Related to https://github.com/spf13/cobra/issues/521. func TestGlobalNormFuncPropagation(t *testing.T) { normFunc := func(f *pflag.FlagSet, name string) pflag.NormalizedName {