diff --git a/flags.go b/flags.go index 4bf4755..e0fd53c 100644 --- a/flags.go +++ b/flags.go @@ -65,6 +65,8 @@ func (f Flags) SetDefault(long string, val any) bool { return false } +// BuildFlags creates a slice of Flags from a struct. +// It supports nested structs and will only generate flags if it finds a 'short' or 'long' tag. func BuildFlags(obj any) Flags { v := reflect.ValueOf(obj) if v.Kind() != reflect.Ptr { @@ -74,19 +76,39 @@ func BuildFlags(obj any) Flags { if v.Kind() != reflect.Struct { panic(fmt.Errorf("expected a struct, got %s", v.Kind())) } + + return buildFlagsRecursive(v) +} + +func buildFlagsRecursive(v reflect.Value) Flags { t := v.Type() + var flags Flags - var err error - flags := make([]Flag, t.NumField()) for i := 0; i < t.NumField(); i++ { - flags[i], err = buildFlag(v.Field(i), t.Field(i)) - if err != nil { - panic(err) + field := t.Field(i) + fieldValue := v.Field(i) + + // Only process fields with a 'short' or 'long' tag + if hasTag(field.Tag, "short") || hasTag(field.Tag, "long") { + flag, err := buildFlag(fieldValue, field) + if err != nil { + panic(err) + } + flags = append(flags, flag) + } else if fieldValue.Kind() == reflect.Struct { + // If the field is a struct, recurse into it + embeddedFlags := buildFlagsRecursive(fieldValue) + flags = append(flags, embeddedFlags...) } } return flags } +func hasTag(tag reflect.StructTag, key string) bool { + _, ok := tag.Lookup(key) + return ok +} + func buildFlag(val reflect.Value, sf reflect.StructField) (Flag, error) { const ( tagNameLong = "long"