diff --git a/command.go b/command.go index e80432ade2..ceefc348b4 100644 --- a/command.go +++ b/command.go @@ -9,6 +9,7 @@ import ( "os" "path/filepath" "reflect" + "slices" "sort" "strings" "unicode" @@ -561,23 +562,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { } } - if cmd.Before != nil && !cmd.Root().shellCompletion { - if bctx, err := cmd.Before(ctx, cmd); err != nil { - deferErr = cmd.handleExitCoder(ctx, err) - return deferErr - } else if bctx != nil { - ctx = bctx - } - } - - tracef("running flag actions (cmd=%[1]q)", cmd.Name) - - if err := cmd.runFlagActions(ctx); err != nil { - return err - } - var subCmd *Command - if args.Present() { tracef("checking positional args %[1]q (cmd=%[2]q)", args, cmd.Name) @@ -613,11 +598,46 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { } } + // If a subcommand has been resolved, let it handle the remaining execution. if subCmd != nil { tracef("running sub-command %[1]q with arguments %[2]q (cmd=%[3]q)", subCmd.Name, cmd.Args(), cmd.Name) return subCmd.Run(ctx, cmd.Args().Slice()) } + // This code path is the innermost command execution. Here we actually + // perform the command action. + // + // First, resolve the chain of nested commands up to the parent. + var cmdChain []*Command + for p := cmd; p != nil; p = p.parent { + cmdChain = append(cmdChain, p) + } + slices.Reverse(cmdChain) + + // Run Before actions in order. + for _, cmd := range cmdChain { + if cmd.Before == nil { + continue + } + if bctx, err := cmd.Before(ctx, cmd); err != nil { + deferErr = cmd.handleExitCoder(ctx, err) + return deferErr + } else if bctx != nil { + ctx = bctx + } + } + + // Run flag actions in order. + // These take a context, so this has to happen after Before actions. + for _, cmd := range cmdChain { + tracef("running flag actions (cmd=%[1]q)", cmd.Name) + if err := cmd.runFlagActions(ctx); err != nil { + deferErr = cmd.handleExitCoder(ctx, err) + return deferErr + } + } + + // Run the command action. if cmd.Action == nil { cmd.Action = helpCommandAction } else { diff --git a/command_test.go b/command_test.go index d9bc3605da..2b39cb7fef 100644 --- a/command_test.go +++ b/command_test.go @@ -1438,6 +1438,43 @@ func TestCommand_BeforeFunc(t *testing.T) { assert.Zero(t, counts.SubCommand, "Subcommand executed when NOT expected") } +func TestCommand_BeforeFuncPersistentFlag(t *testing.T) { + counts := &opCounts{} + beforeError := fmt.Errorf("fail") + var err error + + cmd := &Command{ + Before: func(_ context.Context, cmd *Command) (context.Context, error) { + counts.Before++ + s := cmd.String("opt") + if s != "value" { + return nil, beforeError + } + return nil, nil + }, + Commands: []*Command{ + { + Name: "sub", + Action: func(context.Context, *Command) error { + counts.SubCommand++ + return nil + }, + }, + }, + Flags: []Flag{ + &StringFlag{Name: "opt"}, + }, + Writer: io.Discard, + } + + // Check that --opt value is available in root command Before hook, + // even when it was set on the subcommand. + err = cmd.Run(buildTestContext(t), []string{"command", "sub", "--opt", "value"}) + require.NoError(t, err) + assert.Equal(t, 1, counts.Before, "Before() not executed when expected") + assert.Equal(t, 1, counts.SubCommand, "Subcommand not executed when expected") +} + func TestCommand_BeforeAfterFuncShellCompletion(t *testing.T) { t.Skip("TODO: is '--generate-shell-completion' (flag) still supported?") @@ -2649,7 +2686,7 @@ func TestFlagAction(t *testing.T) { { name: "mixture", args: []string{"app", "--f_string=app", "--f_uint=1", "--f_int_slice=1,2,3", "--f_duration=1h30m20s", "c1", "--f_string=c1", "sub1", "--f_string=sub1"}, - exp: "app 1h30m20s [1 2 3] 1 c1 sub1 ", + exp: "sub1 1h30m20s [1 2 3] 1 sub1 sub1 ", }, { name: "flag_string_map",