diff --git a/backend/flags/flags.go b/backend/flags/flags.go index a255115..e3a9a5e 100644 --- a/backend/flags/flags.go +++ b/backend/flags/flags.go @@ -33,6 +33,16 @@ func (b *Backend) LoadStruct(ctx context.Context, cfg *confita.StructConfig) err continue } + // Check if value type implements flag.Value interface and process value accordingly + valuePtr := f.Value + if f.Value.Kind() != reflect.Ptr && f.Value.CanAddr() { + valuePtr = f.Value.Addr() + } + if iface, ok := valuePtr.Interface().(flag.Value); ok { + b.flags.Var(iface, f.Key, f.Description) + continue + } + // Display all the flags and their default values but override the field only if the user has explicitely // set the flag. k := f.Value.Kind() diff --git a/backend/flags/flags_test.go b/backend/flags/flags_test.go index c4c3ea1..cad1810 100644 --- a/backend/flags/flags_test.go +++ b/backend/flags/flags_test.go @@ -3,6 +3,7 @@ package flags import ( "context" "flag" + "fmt" "os" "testing" "time" @@ -12,6 +13,36 @@ import ( "github.com/stretchr/testify/require" ) +type logLevel int + +var ( + logLevelDebug logLevel = 1 + logLevelInfo logLevel = 2 +) + +func (l *logLevel) Set(val string) error { + switch val { + case "debug": + *l = logLevelDebug + case "info": + *l = logLevelInfo + default: + return fmt.Errorf("unknown log level: %s", val) + } + return nil +} + +func (l logLevel) String() string { + switch l { + case logLevelDebug: + return "debug" + case logLevelInfo: + return "info" + default: + return "" + } +} + func runHelper(t *testing.T, cfg interface{}, args ...string) { t.Helper() @@ -30,9 +61,10 @@ func TestFlags(t *testing.T) { D int `config:"d"` E uint `config:"e"` F float32 `config:"f"` + G logLevel `config:"g"` } var cfg config - runHelper(t, &cfg, "-a=hello", "-b=true", "-c=10s", "-d=-100", "-e=1", "-f=100.01") + runHelper(t, &cfg, "-a=hello", "-b=true", "-c=10s", "-d=-100", "-e=1", "-f=100.01", "-g=info") require.Equal(t, config{ A: "hello", B: true, @@ -40,6 +72,7 @@ func TestFlags(t *testing.T) { D: -100, E: 1, F: 100.01, + G: logLevelInfo, }, cfg) }) @@ -51,14 +84,16 @@ func TestFlags(t *testing.T) { Ddef int `config:"d-def,short=dd"` Edef uint `config:"e-def,short=ed"` Fdef float32 `config:"f-def,short=fd"` + Gdef logLevel `config:"g-def,short=gd"` } cfg := &config{ Adef: "hello", Bdef: true, Cdef: 10 * time.Second, Ddef: -100, + Gdef: logLevelInfo, } - runHelper(t, cfg, "-a-def=bye", "-b-def=false", "-c-def=15s", "-d-def=-200", "-e-def=400", "-f-def=2.33") + runHelper(t, cfg, "-a-def=bye", "-b-def=false", "-c-def=15s", "-d-def=-200", "-e-def=400", "-f-def=2.33", "-g-def=debug") require.Equal(t, &config{ Adef: "bye", @@ -67,6 +102,7 @@ func TestFlags(t *testing.T) { Ddef: -200, Edef: 400, Fdef: 2.33, + Gdef: logLevelDebug, }, cfg) }) }