Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom unmarshaling of strings #29

Merged
merged 11 commits into from
Dec 12, 2023
Merged
7 changes: 7 additions & 0 deletions examples/custom/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
app:
environment: dev

server:
port: 443
read_timeout: 1m

73 changes: 73 additions & 0 deletions examples/custom/custom_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package custom

import (
"fmt"
"strings"

"github.com/kkyr/fig"
)

type ListenerType uint

const (
ListenerUnix ListenerType = iota
ListenerTCP
ListenerTLS
)

type Config struct {
App struct {
Environment string `fig:"environment" validate:"required"`
} `fig:"app"`
Server struct {
Host string `fig:"host" default:"0.0.0.0"`
Port int `fig:"port" default:"80"`
Listener ListenerType `fig:"listener_type" default:"tcp"`
} `fig:"server"`
}

func ExampleLoad() {
var cfg Config
err := fig.Load(&cfg)
if err != nil {
panic(err)
}

fmt.Println(cfg.App.Environment)
fmt.Println(cfg.Server.Host)
fmt.Println(cfg.Server.Port)
fmt.Println(cfg.Server.Listener)

// Output:
// dev
// 0.0.0.0
// 443
// tcp
}

func (l *ListenerType) UnmarshalString(v string) error {
switch strings.ToLower(v) {
case "unix":
*l = ListenerUnix
case "tcp":
*l = ListenerTCP
case "tls":
*l = ListenerTLS
default:
return fmt.Errorf("unknown listener type: %s", v)
}
return nil
}

func (l ListenerType) String() string {
switch l {
case ListenerUnix:
return "unix"
case ListenerTCP:
return "tcp"
case ListenerTLS:
return "tls"
default:
return "unknown"
}
}
82 changes: 80 additions & 2 deletions fig.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,42 @@ const (
DefaultTimeLayout = time.RFC3339
)

// StringUnmarshaler is an interface for custom unmarshaling of strings
//
// If a field with a local type asignment satisfies this interface, it allows the user
// to implment their own custom type unmarshaling method.
//
// Example:
kkyr marked this conversation as resolved.
Show resolved Hide resolved
//
// type ListenerType uint
//
// const (
// ListenerUnix ListenerType = iota
// ListenerTCP
// ListenerTLS
// )
//
// type Config struct {
// Listener ListenerType `fig:"listener_type" default:"unix"`
// }
//
// func (l *ListenerType) UnmarshalType(v string) error {
// switch strings.ToLower(v) {
// case "unix":
// *l = ListenerUnix
// case "tcp":
// *l = ListenerTCP
// case "tls":
// *l = ListenerTLS
// default:
// return fmt.Errorf("unknown listener type: %s", v)
// }
// return nil
// }
type StringUnmarshaler interface {
UnmarshalString(s string) error
}

// Load reads a configuration file and loads it into the given struct. The
// parameter `cfg` must be a pointer to a struct.
//
Expand Down Expand Up @@ -158,6 +194,7 @@ func (f *fig) decodeMap(m map[string]interface{}, result interface{}) error {
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToTimeHookFunc(f.timeLayout),
stringToRegexpHookFunc(),
stringToStringUnmarshalerHook(),
),
})
if err != nil {
Expand All @@ -183,6 +220,36 @@ func stringToRegexpHookFunc() mapstructure.DecodeHookFunc {
}
}

// stringToStringUnmarshalerHook returns a DecodeHookFunc that executes a custom method which
// satisfies the StringUnmarshaler interface on custom types.
func stringToStringUnmarshalerHook() mapstructure.DecodeHookFunc {
return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
if f.Kind() != reflect.String {
return data, nil
}

ds, ok := data.(string)
if !ok {
return data, nil
}
kkyr marked this conversation as resolved.
Show resolved Hide resolved

if reflect.PointerTo(t).Implements(reflect.TypeOf((*StringUnmarshaler)(nil)).Elem()) {
wneessen marked this conversation as resolved.
Show resolved Hide resolved
val := reflect.New(t).Interface()

if unmarshaler, ok := val.(StringUnmarshaler); ok {
err := unmarshaler.UnmarshalString(ds)
if err != nil {
return nil, err
}

return reflect.ValueOf(val).Elem().Interface(), nil
}
}

return data, nil
}
}

// processCfg processes a cfg struct after it has been loaded from
// the config file, by validating required fields and setting defaults
// where applicable.
Expand Down Expand Up @@ -246,12 +313,23 @@ func (f *fig) formatEnvKey(key string) string {
return strings.ToUpper(key)
}

// setDefaultValue calls setValue but disallows booleans from
// being set.
// setDefaultValue calls setValue unless the value satisfies the
// StringUnmarshaler interface or is of a boolean type.
func (f *fig) setDefaultValue(fv reflect.Value, val string) error {
if fv.Kind() == reflect.Bool {
return fmt.Errorf("unsupported type: %v", fv.Kind())
}
if reflect.PointerTo(fv.Type()).Implements(reflect.TypeOf((*StringUnmarshaler)(nil)).Elem()) {
vi := reflect.New(fv.Type()).Interface()
if unmarshaler, ok := vi.(StringUnmarshaler); ok {
err := unmarshaler.UnmarshalString(val)
if err != nil {
return err
wneessen marked this conversation as resolved.
Show resolved Hide resolved
}
fv.Set(reflect.ValueOf(vi).Elem())
}
return nil
kkyr marked this conversation as resolved.
Show resolved Hide resolved
}
wneessen marked this conversation as resolved.
Show resolved Hide resolved
return f.setValue(fv, val)
}

Expand Down
54 changes: 49 additions & 5 deletions fig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ type Item struct {
Path string `fig:"path" validate:"required"`
}

type ListenerType uint

const (
ListenerUnix ListenerType = iota
ListenerTCP
ListenerTLS
)

func validPodConfig() Pod {
var pod Pod

Expand Down Expand Up @@ -249,6 +257,7 @@ func Test_fig_Load_Defaults(t *testing.T) {
Application struct {
BuildDate time.Time `fig:"build_date" default:"2020-01-01T12:00:00Z"`
}
Listener ListenerType `fig:"listener_type" default:"unix"`
}

var want Server
Expand All @@ -259,6 +268,7 @@ func Test_fig_Load_Defaults(t *testing.T) {
want.Logger.Production = false
want.Logger.Metadata.Keys = []string{"ts"}
want.Application.BuildDate = time.Date(2020, 1, 1, 12, 0, 0, 0, time.UTC)
want.Listener = ListenerTCP
wneessen marked this conversation as resolved.
Show resolved Hide resolved

var cfg Server
err := Load(&cfg, File(f), Dirs(filepath.Join("testdata", "valid")))
Expand Down Expand Up @@ -365,7 +375,8 @@ func Test_fig_Load_UseStrict(t *testing.T) {
for _, f := range []string{"server.yaml", "server.json", "server.toml"} {
t.Run(f, func(t *testing.T) {
type Server struct {
Host string `fig:"host"`
Host string `fig:"host"`
Listener ListenerType `fig:"listener_type"`
}

var cfg Server
Expand Down Expand Up @@ -590,17 +601,19 @@ func Test_fig_decodeMap(t *testing.T) {
"log_level": "debug",
"severity": "5",
"server": map[string]interface{}{
"ports": []int{443, 80},
"secure": 1,
"ports": []int{443, 80},
"secure": 1,
"listener_type": "tls",
},
}

var cfg struct {
Level string `fig:"log_level"`
Severity int `fig:"severity" validate:"required"`
Server struct {
Ports []string `fig:"ports" default:"[443]"`
Secure bool
Ports []string `fig:"ports" default:"[443]"`
Secure bool
Listener ListenerType `fig:"listener_type" default:"unix"`
} `fig:"server"`
}

Expand All @@ -623,6 +636,10 @@ func Test_fig_decodeMap(t *testing.T) {
if cfg.Server.Secure == false {
t.Error("cfg.Server.Secure == false")
}

if cfg.Server.Listener != ListenerTLS {
t.Errorf("cfg.Server.Listener: want: %s, got: %s", ListenerTLS, cfg.Server.Listener)
}
}

func Test_fig_processCfg(t *testing.T) {
Expand Down Expand Up @@ -1263,3 +1280,30 @@ func setenv(t *testing.T, key, value string) {
t.Helper()
t.Setenv(key, value)
}

func (l *ListenerType) UnmarshalString(v string) error {
switch strings.ToLower(v) {
case "unix":
*l = ListenerUnix
case "tcp":
*l = ListenerTCP
case "tls":
*l = ListenerTLS
default:
return fmt.Errorf("unknown listener type: %s", v)
}
return nil
}
wneessen marked this conversation as resolved.
Show resolved Hide resolved

func (l ListenerType) String() string {
switch l {
case ListenerUnix:
return "unix"
case ListenerTCP:
return "tcp"
case ListenerTLS:
return "tls"
default:
return "unknown"
}
}
wneessen marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 2 additions & 1 deletion testdata/valid/server.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
"host": "0.0.0.0",
"logger": {
"log_level": "debug"
}
},
"listener_type": "tcp"
}
3 changes: 2 additions & 1 deletion testdata/valid/server.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
host = "0.0.0.0"
listener_type = "tcp"

[logger]
log_level = "debug"
log_level = "debug"
1 change: 1 addition & 0 deletions testdata/valid/server.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
host: "0.0.0.0"
listener_type: "tcp"

logger:
log_level: "debug"