Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom unmarshaling of strings #29

Merged
merged 11 commits into from
Dec 12, 2023
67 changes: 67 additions & 0 deletions fig.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,42 @@ const (
DefaultTimeLayout = time.RFC3339
)

// StringUnmarshaler is an interface for custom unmarshaling of strings
//
// If a field with a local type asignment satisfies this interface, it allows the user
// to implment their own custom type unmarshaling method.
//
// Example:
kkyr marked this conversation as resolved.
Show resolved Hide resolved
//
// type ListenerType uint
//
// const (
// ListenerUnix ListenerType = iota
// ListenerTCP
// ListenerTLS
// )
//
// type Config struct {
// Listener ListenerType `fig:"listener_type" default:"unix"`
// }
//
// func (l *ListenerType) UnmarshalType(v string) error {
// switch strings.ToLower(v) {
// case "unix":
// *l = ListenerUnix
// case "tcp":
// *l = ListenerTCP
// case "tls":
// *l = ListenerTLS
// default:
// return fmt.Errorf("unknown listener type: %s", v)
// }
// return nil
// }
type StringUnmarshaler interface {
UnmarshalString(s string) error
}

// Load reads a configuration file and loads it into the given struct. The
// parameter `cfg` must be a pointer to a struct.
//
Expand Down Expand Up @@ -158,6 +194,7 @@ func (f *fig) decodeMap(m map[string]interface{}, result interface{}) error {
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToTimeHookFunc(f.timeLayout),
stringToRegexpHookFunc(),
stringToStringUnmarshalerHook(),
),
})
if err != nil {
Expand All @@ -183,6 +220,36 @@ func stringToRegexpHookFunc() mapstructure.DecodeHookFunc {
}
}

// stringToStringUnmarshalerHook returns a DecodeHookFunc that executes a custom method which
// satisfies the StringUnmarshaler interface on custom types.
func stringToStringUnmarshalerHook() mapstructure.DecodeHookFunc {
return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
if f.Kind() != reflect.String {
return data, nil
}

ds, ok := data.(string)
if !ok {
return data, nil
}
kkyr marked this conversation as resolved.
Show resolved Hide resolved

if reflect.PointerTo(t).Implements(reflect.TypeOf((*StringUnmarshaler)(nil)).Elem()) {
wneessen marked this conversation as resolved.
Show resolved Hide resolved
val := reflect.New(t).Interface()

if unmarshaler, ok := val.(StringUnmarshaler); ok {
err := unmarshaler.UnmarshalString(ds)
if err != nil {
return nil, err
}

return reflect.ValueOf(val).Elem().Interface(), nil
}
}

return data, nil
}
}

// processCfg processes a cfg struct after it has been loaded from
// the config file, by validating required fields and setting defaults
// where applicable.
Expand Down
53 changes: 49 additions & 4 deletions fig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ type Item struct {
Path string `fig:"path" validate:"required"`
}

type ListenerType uint

const (
ListenerUnix ListenerType = iota
ListenerTCP
ListenerTLS
)

func validPodConfig() Pod {
var pod Pod

Expand Down Expand Up @@ -249,6 +257,9 @@ func Test_fig_Load_Defaults(t *testing.T) {
Application struct {
BuildDate time.Time `fig:"build_date" default:"2020-01-01T12:00:00Z"`
}
Server struct {
Listener ListenerType `fig:"listener_type" default:"unix"`
wneessen marked this conversation as resolved.
Show resolved Hide resolved
}
}

var want Server
Expand All @@ -259,6 +270,7 @@ func Test_fig_Load_Defaults(t *testing.T) {
want.Logger.Production = false
want.Logger.Metadata.Keys = []string{"ts"}
want.Application.BuildDate = time.Date(2020, 1, 1, 12, 0, 0, 0, time.UTC)
want.Server.Listener = ListenerTCP
wneessen marked this conversation as resolved.
Show resolved Hide resolved

var cfg Server
err := Load(&cfg, File(f), Dirs(filepath.Join("testdata", "valid")))
Expand Down Expand Up @@ -590,17 +602,19 @@ func Test_fig_decodeMap(t *testing.T) {
"log_level": "debug",
"severity": "5",
"server": map[string]interface{}{
"ports": []int{443, 80},
"secure": 1,
"ports": []int{443, 80},
"secure": 1,
"listener_type": "tls",
},
}

var cfg struct {
Level string `fig:"log_level"`
Severity int `fig:"severity" validate:"required"`
Server struct {
Ports []string `fig:"ports" default:"[443]"`
Secure bool
Ports []string `fig:"ports" default:"[443]"`
Secure bool
Listener ListenerType `fig:"listener_type" default:"unix"`
} `fig:"server"`
}

Expand All @@ -623,6 +637,10 @@ func Test_fig_decodeMap(t *testing.T) {
if cfg.Server.Secure == false {
t.Error("cfg.Server.Secure == false")
}

if cfg.Server.Listener != ListenerTLS {
t.Errorf("cfg.Server.Listener: want: %s, got: %s", ListenerTLS, cfg.Server.Listener)
}
}

func Test_fig_processCfg(t *testing.T) {
Expand Down Expand Up @@ -1263,3 +1281,30 @@ func setenv(t *testing.T, key, value string) {
t.Helper()
t.Setenv(key, value)
}

func (l *ListenerType) UnmarshalString(v string) error {
switch strings.ToLower(v) {
case "unix":
*l = ListenerUnix
case "tcp":
*l = ListenerTCP
case "tls":
*l = ListenerTLS
default:
return fmt.Errorf("unknown listener type: %s", v)
}
return nil
}
wneessen marked this conversation as resolved.
Show resolved Hide resolved

func (l ListenerType) String() string {
switch l {
case ListenerUnix:
return "unix"
case ListenerTCP:
return "tcp"
case ListenerTLS:
return "tls"
default:
return "unknown"
}
}
wneessen marked this conversation as resolved.
Show resolved Hide resolved