Skip to content

Commit

Permalink
feat: add flag prerequisite check
Browse files Browse the repository at this point in the history
Add function for checking whether flags are effective only if some other flags are set
  • Loading branch information
sam80180 committed Jan 18, 2024
1 parent bcfcff7 commit 1bfcd78
Show file tree
Hide file tree
Showing 5 changed files with 788 additions and 5 deletions.
88 changes: 84 additions & 4 deletions flag_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,22 @@ package cobra

import (
"fmt"
"os"
"sort"
"strconv"
"strings"

"github.com/ghetzel/go-stockutil/sliceutil"
flag "github.com/spf13/pflag"
)

const (
requiredAsGroup = "cobra_annotation_required_if_others_set"
oneRequired = "cobra_annotation_one_required"
mutuallyExclusive = "cobra_annotation_mutually_exclusive"
requiredAsGroup = "cobra_annotation_required_if_others_set"
oneRequired = "cobra_annotation_one_required"
mutuallyExclusive = "cobra_annotation_mutually_exclusive"
_ANNOTATION_PREREQUISITE = "cobra_annotation_prerequisite"
_ANNOTATION_DEPENDANT = "cobra_annotation_dependant"
_ANNOTATION_PREREQUISITE_MISSING_ACTION = "cobra_annotation_prerequisite_missing_action"
)

// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
Expand Down Expand Up @@ -76,6 +82,46 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
}
}

const (
PREREQUISITE_MISSING_IGNORE = 0
PREREQUISITE_MISSING_PANIC = 1
PREREQUISITE_MISSING_MURMUR = 2
)

func (c *Command) MarkFlagsPrerequisite(prerequisites, dependants []string, l int) {
if len(prerequisites) <= 0 || len(dependants) <= 0 {
panic("Prerequisite or dependant flags cannot be empty")
}
var invalidPrerequisites []interface{} = sliceutil.Intersect(prerequisites, dependants)
if len(invalidPrerequisites) > 0 {
panic(fmt.Sprintf("Flags %v cannot be prerequisite of themselves", invalidPrerequisites))
}
c.mergePersistentFlags()
for _, flagNameDep := range dependants { // check dependant flags
fDep := c.Flags().Lookup(flagNameDep)
if fDep == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark its prerequisites", flagNameDep))
}
if _, ok := fDep.Annotations[_ANNOTATION_PREREQUISITE]; !ok {
fDep.Annotations = map[string][]string{}
}
fDep.Annotations[_ANNOTATION_PREREQUISITE] = append(fDep.Annotations[_ANNOTATION_PREREQUISITE], prerequisites...)
if errAnno := c.Flags().SetAnnotation(flagNameDep, _ANNOTATION_PREREQUISITE_MISSING_ACTION, []string{strconv.Itoa(l)}); errAnno != nil {
panic(errAnno)
}
}
for _, flagNamePre := range prerequisites { // check prerequisite flags
fPre := c.Flags().Lookup(flagNamePre)
if fPre == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as prerequisite of other flags", flagNamePre))
}
if _, ok := fPre.Annotations[_ANNOTATION_DEPENDANT]; !ok {
fPre.Annotations = map[string][]string{}
}
fPre.Annotations[_ANNOTATION_DEPENDANT] = append(fPre.Annotations[_ANNOTATION_DEPENDANT], dependants...)
}
}

// ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the
// first error encountered.
func (c *Command) ValidateFlagGroups() error {
Expand All @@ -90,12 +136,13 @@ func (c *Command) ValidateFlagGroups() error {
groupStatus := map[string]map[string]bool{}
oneRequiredGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
prerequisiteGroupStatus := map[string]map[string]bool{}
flags.VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
processFlagForGroupAnnotation(flags, pflag, _ANNOTATION_DEPENDANT, prerequisiteGroupStatus)
})

if err := validateRequiredFlagGroups(groupStatus); err != nil {
return err
}
Expand All @@ -105,6 +152,9 @@ func (c *Command) ValidateFlagGroups() error {
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
return err
}
if errPrereq := c.validatePrerequisiteFlagGroups(prerequisiteGroupStatus); errPrereq != nil {
return errPrereq
}
return nil
}

Expand Down Expand Up @@ -206,6 +256,36 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
return nil
}

func (c *Command) validatePrerequisiteFlagGroups(data map[string]map[string]bool) error {
for flagNameDep, flagPrereqExistence := range data {
fDep := c.Flags().Lookup(flagNameDep)
isPrerequisiteFound := false
flagsetPrerequisite := fDep.Annotations[_ANNOTATION_PREREQUISITE]
for _, flagNamePre := range flagsetPrerequisite {
if v, b := flagPrereqExistence[flagNamePre]; b && v {
isPrerequisiteFound = true
break
}
}
if !isPrerequisiteFound {
iPrereqMissingAction, _ := strconv.Atoi(fDep.Annotations[_ANNOTATION_PREREQUISITE_MISSING_ACTION][0])
errMsg := fmt.Sprintf("flag %q is only effective if any of the flags in the group %v is set", flagNameDep, flagsetPrerequisite)
switch iPrereqMissingAction {
case PREREQUISITE_MISSING_PANIC:
return fmt.Errorf(errMsg)
case PREREQUISITE_MISSING_MURMUR:
fmt.Fprintln(os.Stderr, errMsg)
fallthrough
case PREREQUISITE_MISSING_IGNORE:
fallthrough
default:

}
}
}
return nil
}

func sortedKeys(m map[string]map[string]bool) []string {
keys := make([]string, len(m))
i := 0
Expand Down
10 changes: 10 additions & 0 deletions flag_groups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func TestValidateFlagGroups(t *testing.T) {
flagGroupsRequired []string
flagGroupsOneRequired []string
flagGroupsExclusive []string
flagGroupsPrerequisite [][]string
subCmdFlagGroupsRequired []string
subCmdFlagGroupsOneRequired []string
subCmdFlagGroupsExclusive []string
Expand Down Expand Up @@ -159,6 +160,12 @@ func TestValidateFlagGroups(t *testing.T) {
subCmdFlagGroupsRequired: []string{"e subonly"},
args: []string{"--e=foo"},
},
{
desc: "prerequisite",
flagGroupsPrerequisite: [][]string{{"a", "b"}, {"c", "d"}},
args: []string{"--c=sam80180", "--d=bar"},
expectErr: `flag "c" is only effective if any of the flags in the group [a b] is set`,
},
}
for _, tc := range testcases {
t.Run(tc.desc, func(t *testing.T) {
Expand All @@ -173,6 +180,9 @@ func TestValidateFlagGroups(t *testing.T) {
for _, flagGroup := range tc.flagGroupsExclusive {
c.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...)
}
if tc.flagGroupsPrerequisite != nil && len(tc.flagGroupsPrerequisite) == 2 {
c.MarkFlagsPrerequisite(tc.flagGroupsPrerequisite[0], tc.flagGroupsPrerequisite[1], PREREQUISITE_MISSING_PANIC)
}
for _, flagGroup := range tc.subCmdFlagGroupsRequired {
sub.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...)
}
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ go 1.15

require (
github.com/cpuguy83/go-md2man/v2 v2.0.3
github.com/ghetzel/go-stockutil v1.11.4
github.com/inconshreveable/mousetrap v1.1.0
github.com/spf13/pflag v1.0.5
golang.org/x/sys v0.16.0 // indirect
gopkg.in/yaml.v3 v3.0.1
)
Loading

0 comments on commit 1bfcd78

Please sign in to comment.