Skip to content

Commit

Permalink
Improve API to get flag completion function
Browse files Browse the repository at this point in the history
The new API is simpler and matches the `c.RegisterFlagCompletionFunc()`
API.  By removing the global function `GetFlagCompletion()` we are more
future proof if we ever move from a global map of flag completion
functions to something associated with the command.

The commit also makes this API work with persistent flags by using
`c.Flag(flagName)` instead of `c.Flags().Lookup(flagName)`.

The commit also adds unit tests.

Signed-off-by: Marc Khouzam <[email protected]>
  • Loading branch information
marckhouzam committed Nov 2, 2023
1 parent 890302a commit 74f1665
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 12 deletions.
19 changes: 7 additions & 12 deletions completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,25 +145,20 @@ func (c *Command) RegisterFlagCompletionFunc(flagName string, f func(cmd *Comman
return nil
}

// GetFlagCompletion returns the completion function for the given flag, if available.
func GetFlagCompletion(flag *pflag.Flag) (func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective), bool) {
// GetFlagCompletionFunc returns the completion function for the given flag of the command, if available.
func (c *Command) GetFlagCompletionFunc(flagName string) (func(*Command, []string, string) ([]string, ShellCompDirective), bool) {
flag := c.Flag(flagName)
if flag == nil {
return nil, false
}

flagCompletionMutex.RLock()
defer flagCompletionMutex.RUnlock()

completionFunc, exists := flagCompletionFunctions[flag]
return completionFunc, exists
}

// GetFlagCompletionByName returns the completion function for the given flag in the command by name, if available.
func (c *Command) GetFlagCompletionByName(flagName string) (func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective), bool) {
flag := c.Flags().Lookup(flagName)
if flag == nil {
return nil, false
}

return GetFlagCompletion(flag)
}

// Returns a string listing the different directive enabled in the specified parameter
func (d ShellCompDirective) string() string {
var directives []string
Expand Down
90 changes: 90 additions & 0 deletions completions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3427,3 +3427,93 @@ Completion ended with directive: ShellCompDirectiveNoFileComp
})
}
}

func TestGetFlagCompletion(t *testing.T) {
rootCmd := &Command{Use: "root", Run: emptyRun}

rootCmd.Flags().String("rootflag", "", "root flag")
_ = rootCmd.RegisterFlagCompletionFunc("rootflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) {
return []string{"rootvalue"}, ShellCompDirectiveKeepOrder
})

rootCmd.PersistentFlags().String("persistentflag", "", "persistent flag")
_ = rootCmd.RegisterFlagCompletionFunc("persistentflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) {
return []string{"persistentvalue"}, ShellCompDirectiveDefault
})

childCmd := &Command{Use: "child", Run: emptyRun}

childCmd.Flags().String("childflag", "", "child flag")
_ = childCmd.RegisterFlagCompletionFunc("childflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) {
return []string{"childvalue"}, ShellCompDirectiveNoFileComp | ShellCompDirectiveNoSpace
})

rootCmd.AddCommand(childCmd)

testcases := []struct {
desc string
cmd *Command
flagName string
exists bool
comps []string
directive ShellCompDirective
}{
{
desc: "get flag completion function for command",
cmd: rootCmd,
flagName: "rootflag",
exists: true,
comps: []string{"rootvalue"},
directive: ShellCompDirectiveKeepOrder,
},
{
desc: "get persistent flag completion function for command",
cmd: rootCmd,
flagName: "persistentflag",
exists: true,
comps: []string{"persistentvalue"},
directive: ShellCompDirectiveDefault,
},
{
desc: "get flag completion function for child command",
cmd: childCmd,
flagName: "childflag",
exists: true,
comps: []string{"childvalue"},
directive: ShellCompDirectiveNoFileComp | ShellCompDirectiveNoSpace,
},
{
desc: "get persistent flag completion function for child command",
cmd: childCmd,
flagName: "persistentflag",
exists: true,
comps: []string{"persistentvalue"},
directive: ShellCompDirectiveDefault,
},
{
desc: "cannot get flag completion function for local parent flag",
cmd: childCmd,
flagName: "rootflag",
exists: false,
},
}

for _, tc := range testcases {
t.Run(tc.desc, func(t *testing.T) {
compFunc, exists := tc.cmd.GetFlagCompletionFunc(tc.flagName)
if tc.exists != exists {
t.Errorf("Unexpected result looking for flag completion function")
}

if exists {
comps, directive := compFunc(tc.cmd, []string{}, "")
if strings.Join(tc.comps, " ") != strings.Join(comps, " ") {
t.Errorf("Unexpected completions %q", comps)
}
if tc.directive != directive {
t.Errorf("Unexpected directive %q", directive)
}
}
})
}
}

0 comments on commit 74f1665

Please sign in to comment.