Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run Before actions after setting up subcommand #2028

Merged
merged 3 commits into from
Jan 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 36 additions & 16 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"path/filepath"
"reflect"
"slices"
"sort"
"strings"
"unicode"
Expand Down Expand Up @@ -561,23 +562,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
}
}

if cmd.Before != nil && !cmd.Root().shellCompletion {
if bctx, err := cmd.Before(ctx, cmd); err != nil {
deferErr = cmd.handleExitCoder(ctx, err)
return deferErr
} else if bctx != nil {
ctx = bctx
}
}

tracef("running flag actions (cmd=%[1]q)", cmd.Name)

if err := cmd.runFlagActions(ctx); err != nil {
return err
}

var subCmd *Command

if args.Present() {
tracef("checking positional args %[1]q (cmd=%[2]q)", args, cmd.Name)

Expand Down Expand Up @@ -613,11 +598,46 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
}
}

// If a subcommand has been resolved, let it handle the remaining execution.
if subCmd != nil {
tracef("running sub-command %[1]q with arguments %[2]q (cmd=%[3]q)", subCmd.Name, cmd.Args(), cmd.Name)
return subCmd.Run(ctx, cmd.Args().Slice())
}

// This code path is the innermost command execution. Here we actually
// perform the command action.
//
// First, resolve the chain of nested commands up to the parent.
var cmdChain []*Command
for p := cmd; p != nil; p = p.parent {
cmdChain = append(cmdChain, p)
}
slices.Reverse(cmdChain)

// Run Before actions in order.
for _, cmd := range cmdChain {
if cmd.Before == nil {
continue
}
if bctx, err := cmd.Before(ctx, cmd); err != nil {
deferErr = cmd.handleExitCoder(ctx, err)
return deferErr
} else if bctx != nil {
ctx = bctx
}
}

// Run flag actions in order.
// These take a context, so this has to happen after Before actions.
for _, cmd := range cmdChain {
tracef("running flag actions (cmd=%[1]q)", cmd.Name)
if err := cmd.runFlagActions(ctx); err != nil {
deferErr = cmd.handleExitCoder(ctx, err)
return deferErr
}
}

// Run the command action.
if cmd.Action == nil {
cmd.Action = helpCommandAction
} else {
Expand Down
39 changes: 38 additions & 1 deletion command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1438,6 +1438,43 @@ func TestCommand_BeforeFunc(t *testing.T) {
assert.Zero(t, counts.SubCommand, "Subcommand executed when NOT expected")
}

func TestCommand_BeforeFuncPersistentFlag(t *testing.T) {
counts := &opCounts{}
beforeError := fmt.Errorf("fail")
var err error

cmd := &Command{
Before: func(_ context.Context, cmd *Command) (context.Context, error) {
counts.Before++
s := cmd.String("opt")
if s != "value" {
return nil, beforeError
}
return nil, nil
},
Commands: []*Command{
{
Name: "sub",
Action: func(context.Context, *Command) error {
counts.SubCommand++
return nil
},
},
},
Flags: []Flag{
&StringFlag{Name: "opt"},
},
Writer: io.Discard,
}

// Check that --opt value is available in root command Before hook,
// even when it was set on the subcommand.
err = cmd.Run(buildTestContext(t), []string{"command", "sub", "--opt", "value"})
require.NoError(t, err)
assert.Equal(t, 1, counts.Before, "Before() not executed when expected")
assert.Equal(t, 1, counts.SubCommand, "Subcommand not executed when expected")
}

func TestCommand_BeforeAfterFuncShellCompletion(t *testing.T) {
t.Skip("TODO: is '--generate-shell-completion' (flag) still supported?")

Expand Down Expand Up @@ -2649,7 +2686,7 @@ func TestFlagAction(t *testing.T) {
{
name: "mixture",
args: []string{"app", "--f_string=app", "--f_uint=1", "--f_int_slice=1,2,3", "--f_duration=1h30m20s", "c1", "--f_string=c1", "sub1", "--f_string=sub1"},
exp: "app 1h30m20s [1 2 3] 1 c1 sub1 ",
exp: "sub1 1h30m20s [1 2 3] 1 sub1 sub1 ",
},
{
name: "flag_string_map",
Expand Down
Loading