Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add flag prerequisite check #2153

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 84 additions & 4 deletions flag_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,22 @@ package cobra

import (
"fmt"
"os"
"sort"
"strconv"
"strings"

"github.com/ghetzel/go-stockutil/sliceutil"
flag "github.com/spf13/pflag"
)

const (
requiredAsGroup = "cobra_annotation_required_if_others_set"
oneRequired = "cobra_annotation_one_required"
mutuallyExclusive = "cobra_annotation_mutually_exclusive"
requiredAsGroup = "cobra_annotation_required_if_others_set"
oneRequired = "cobra_annotation_one_required"
mutuallyExclusive = "cobra_annotation_mutually_exclusive"
_ANNOTATION_PREREQUISITE = "cobra_annotation_prerequisite"
_ANNOTATION_DEPENDANT = "cobra_annotation_dependant"
_ANNOTATION_PREREQUISITE_MISSING_ACTION = "cobra_annotation_prerequisite_missing_action"
)

// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
Expand Down Expand Up @@ -76,6 +82,46 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
}
}

const (
PREREQUISITE_MISSING_IGNORE = 0
PREREQUISITE_MISSING_PANIC = 1
PREREQUISITE_MISSING_MURMUR = 2
)

func (c *Command) MarkFlagsPrerequisite(prerequisites, dependants []string, l int) {
if len(prerequisites) <= 0 || len(dependants) <= 0 {
panic("Prerequisite or dependant flags cannot be empty")
}
var invalidPrerequisites []interface{} = sliceutil.Intersect(prerequisites, dependants)
if len(invalidPrerequisites) > 0 {
panic(fmt.Sprintf("Flags %v cannot be prerequisite of themselves", invalidPrerequisites))
}
c.mergePersistentFlags()
for _, flagNameDep := range dependants { // check dependant flags
fDep := c.Flags().Lookup(flagNameDep)
if fDep == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark its prerequisites", flagNameDep))
}
if _, ok := fDep.Annotations[_ANNOTATION_PREREQUISITE]; !ok {
fDep.Annotations = map[string][]string{}
}
fDep.Annotations[_ANNOTATION_PREREQUISITE] = append(fDep.Annotations[_ANNOTATION_PREREQUISITE], prerequisites...)
if errAnno := c.Flags().SetAnnotation(flagNameDep, _ANNOTATION_PREREQUISITE_MISSING_ACTION, []string{strconv.Itoa(l)}); errAnno != nil {
panic(errAnno)
}
}
for _, flagNamePre := range prerequisites { // check prerequisite flags
fPre := c.Flags().Lookup(flagNamePre)
if fPre == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as prerequisite of other flags", flagNamePre))
}
if _, ok := fPre.Annotations[_ANNOTATION_DEPENDANT]; !ok {
fPre.Annotations = map[string][]string{}
}
fPre.Annotations[_ANNOTATION_DEPENDANT] = append(fPre.Annotations[_ANNOTATION_DEPENDANT], dependants...)
}
}

// ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the
// first error encountered.
func (c *Command) ValidateFlagGroups() error {
Expand All @@ -90,12 +136,13 @@ func (c *Command) ValidateFlagGroups() error {
groupStatus := map[string]map[string]bool{}
oneRequiredGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
prerequisiteGroupStatus := map[string]map[string]bool{}
flags.VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
processFlagForGroupAnnotation(flags, pflag, _ANNOTATION_DEPENDANT, prerequisiteGroupStatus)
})

if err := validateRequiredFlagGroups(groupStatus); err != nil {
return err
}
Expand All @@ -105,6 +152,9 @@ func (c *Command) ValidateFlagGroups() error {
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
return err
}
if errPrereq := c.validatePrerequisiteFlagGroups(prerequisiteGroupStatus); errPrereq != nil {
return errPrereq
}
return nil
}

Expand Down Expand Up @@ -206,6 +256,36 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
return nil
}

func (c *Command) validatePrerequisiteFlagGroups(data map[string]map[string]bool) error {
for flagNameDep, flagPrereqExistence := range data {
fDep := c.Flags().Lookup(flagNameDep)
isPrerequisiteFound := false
flagsetPrerequisite := fDep.Annotations[_ANNOTATION_PREREQUISITE]
for _, flagNamePre := range flagsetPrerequisite {
if v, b := flagPrereqExistence[flagNamePre]; b && v {
isPrerequisiteFound = true
break
}
}
if !isPrerequisiteFound {
iPrereqMissingAction, _ := strconv.Atoi(fDep.Annotations[_ANNOTATION_PREREQUISITE_MISSING_ACTION][0])
errMsg := fmt.Sprintf("flag %q is only effective if any of the flags in the group %v is set", flagNameDep, flagsetPrerequisite)
switch iPrereqMissingAction {
case PREREQUISITE_MISSING_PANIC:
return fmt.Errorf(errMsg)
case PREREQUISITE_MISSING_MURMUR:
fmt.Fprintln(os.Stderr, errMsg)
fallthrough
case PREREQUISITE_MISSING_IGNORE:
fallthrough
default:

}
}
}
return nil
}

func sortedKeys(m map[string]map[string]bool) []string {
keys := make([]string, len(m))
i := 0
Expand Down
10 changes: 10 additions & 0 deletions flag_groups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func TestValidateFlagGroups(t *testing.T) {
flagGroupsRequired []string
flagGroupsOneRequired []string
flagGroupsExclusive []string
flagGroupsPrerequisite [][]string
subCmdFlagGroupsRequired []string
subCmdFlagGroupsOneRequired []string
subCmdFlagGroupsExclusive []string
Expand Down Expand Up @@ -159,6 +160,12 @@ func TestValidateFlagGroups(t *testing.T) {
subCmdFlagGroupsRequired: []string{"e subonly"},
args: []string{"--e=foo"},
},
{
desc: "prerequisite",
flagGroupsPrerequisite: [][]string{{"a", "b"}, {"c", "d"}},
args: []string{"--c=sam80180", "--d=bar"},
expectErr: `flag "c" is only effective if any of the flags in the group [a b] is set`,
},
}
for _, tc := range testcases {
t.Run(tc.desc, func(t *testing.T) {
Expand All @@ -173,6 +180,9 @@ func TestValidateFlagGroups(t *testing.T) {
for _, flagGroup := range tc.flagGroupsExclusive {
c.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...)
}
if tc.flagGroupsPrerequisite != nil && len(tc.flagGroupsPrerequisite) == 2 {
c.MarkFlagsPrerequisite(tc.flagGroupsPrerequisite[0], tc.flagGroupsPrerequisite[1], PREREQUISITE_MISSING_PANIC)
}
for _, flagGroup := range tc.subCmdFlagGroupsRequired {
sub.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...)
}
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ go 1.15

require (
github.com/cpuguy83/go-md2man/v2 v2.0.3
github.com/ghetzel/go-stockutil v1.11.4
github.com/inconshreveable/mousetrap v1.1.0
github.com/spf13/pflag v1.0.5
golang.org/x/sys v0.16.0 // indirect
gopkg.in/yaml.v3 v3.0.1
)
Loading