Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom unmarshaling of strings #29

Merged
merged 11 commits into from
Dec 12, 2023
Merged
7 changes: 7 additions & 0 deletions examples/custom/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
app:
environment: dev

server:
port: 443
read_timeout: 1m

73 changes: 73 additions & 0 deletions examples/custom/custom_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package custom

import (
"fmt"
"strings"

"github.com/kkyr/fig"
)

type ListenerType uint

const (
ListenerUnix ListenerType = iota
ListenerTCP
ListenerTLS
)

type Config struct {
App struct {
Environment string `fig:"environment" validate:"required"`
} `fig:"app"`
Server struct {
Host string `fig:"host" default:"0.0.0.0"`
Port int `fig:"port" default:"80"`
Listener ListenerType `fig:"listener_type" default:"tcp"`
} `fig:"server"`
}

func ExampleLoad() {
var cfg Config
err := fig.Load(&cfg)
if err != nil {
panic(err)
}

fmt.Println(cfg.App.Environment)
fmt.Println(cfg.Server.Host)
fmt.Println(cfg.Server.Port)
fmt.Println(cfg.Server.Listener)

// Output:
// dev
// 0.0.0.0
// 443
// tcp
}

func (l *ListenerType) UnmarshalString(v string) error {
switch strings.ToLower(v) {
case "unix":
*l = ListenerUnix
case "tcp":
*l = ListenerTCP
case "tls":
*l = ListenerTLS
default:
return fmt.Errorf("unknown listener type: %s", v)
}
return nil
}

func (l ListenerType) String() string {
switch l {
case ListenerUnix:
return "unix"
case ListenerTCP:
return "tcp"
case ListenerTLS:
return "tls"
default:
return "unknown"
}
}
84 changes: 83 additions & 1 deletion fig.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,42 @@
DefaultTimeLayout = time.RFC3339
)

// StringUnmarshaler is an interface for custom unmarshaling of strings
//
// If a field with a local type asignment satisfies this interface, it allows the user
// to implment their own custom type unmarshaling method.
//
// Example:
kkyr marked this conversation as resolved.
Show resolved Hide resolved
//
// type ListenerType uint
//
// const (
// ListenerUnix ListenerType = iota
// ListenerTCP
// ListenerTLS
// )
//
// type Config struct {
// Listener ListenerType `fig:"listener_type" default:"unix"`
// }
//
// func (l *ListenerType) UnmarshalType(v string) error {
// switch strings.ToLower(v) {
// case "unix":
// *l = ListenerUnix
// case "tcp":
// *l = ListenerTCP
// case "tls":
// *l = ListenerTLS
// default:
// return fmt.Errorf("unknown listener type: %s", v)
// }
// return nil
// }
type StringUnmarshaler interface {
UnmarshalString(s string) error
}

// Load reads a configuration file and loads it into the given struct. The
// parameter `cfg` must be a pointer to a struct.
//
Expand Down Expand Up @@ -158,6 +194,7 @@
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToTimeHookFunc(f.timeLayout),
stringToRegexpHookFunc(),
stringToStringUnmarshalerHook(),
),
})
if err != nil {
Expand All @@ -183,6 +220,36 @@
}
}

// stringToStringUnmarshalerHook returns a DecodeHookFunc that executes a custom method which
// satisfies the StringUnmarshaler interface on custom types.
func stringToStringUnmarshalerHook() mapstructure.DecodeHookFunc {
return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
if f.Kind() != reflect.String {
return data, nil
}

ds, ok := data.(string)
if !ok {
return data, nil
}
kkyr marked this conversation as resolved.
Show resolved Hide resolved

if reflect.PointerTo(t).Implements(reflect.TypeOf((*StringUnmarshaler)(nil)).Elem()) {
wneessen marked this conversation as resolved.
Show resolved Hide resolved
val := reflect.New(t).Interface()

if unmarshaler, ok := val.(StringUnmarshaler); ok {
err := unmarshaler.UnmarshalString(ds)
if err != nil {
return nil, err
}

return reflect.ValueOf(val).Elem().Interface(), nil
}
}

return data, nil
}
}

// processCfg processes a cfg struct after it has been loaded from
// the config file, by validating required fields and setting defaults
// where applicable.
Expand Down Expand Up @@ -257,9 +324,24 @@

// setValue sets fv to val. it attempts to convert val to the correct
// type based on the field's kind. if conversion fails an error is
// returned.
// returned. If fv satisfies the StringUnmarshaler interface it will
// execute the corresponding StringUnmarshaler.UnmarshalString method
// on the value.
// fv must be settable else this panics.
func (f *fig) setValue(fv reflect.Value, val string) error {

Check failure on line 331 in fig.go

View workflow job for this annotation

GitHub Actions / lint

cognitive complexity 35 of func `(*fig).setValue` is high (> 30) (gocognit)
if fv.IsValid() && reflect.PointerTo(fv.Type()).Implements(reflect.TypeOf((*StringUnmarshaler)(nil)).Elem()) {
vi := reflect.New(fv.Type()).Interface()
if unmarshaler, ok := vi.(StringUnmarshaler); ok {
err := unmarshaler.UnmarshalString(val)
if err != nil {
return fmt.Errorf("could not unmarshal string %q: %w", val, err)
}
fv.Set(reflect.ValueOf(vi).Elem())
return nil
}
return fmt.Errorf("unexpected error while trying to unmarshal string")
kkyr marked this conversation as resolved.
Show resolved Hide resolved
}

switch fv.Kind() {
case reflect.Ptr:
if fv.IsNil() {
Expand Down
38 changes: 34 additions & 4 deletions fig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,28 @@ type Item struct {
Path string `fig:"path" validate:"required"`
}

type ListenerType uint

const (
ListenerUnix ListenerType = iota
ListenerTCP
ListenerTLS
)

func (l *ListenerType) UnmarshalString(v string) error {
switch strings.ToLower(v) {
case "unix":
*l = ListenerUnix
case "tcp":
*l = ListenerTCP
case "tls":
*l = ListenerTLS
default:
return fmt.Errorf("unknown listener type: %s", v)
}
return nil
}

func validPodConfig() Pod {
var pod Pod

Expand Down Expand Up @@ -249,6 +271,7 @@ func Test_fig_Load_Defaults(t *testing.T) {
Application struct {
BuildDate time.Time `fig:"build_date" default:"2020-01-01T12:00:00Z"`
}
Listener ListenerType `fig:"listener_type" default:"unix"`
}

var want Server
Expand All @@ -259,6 +282,7 @@ func Test_fig_Load_Defaults(t *testing.T) {
want.Logger.Production = false
want.Logger.Metadata.Keys = []string{"ts"}
want.Application.BuildDate = time.Date(2020, 1, 1, 12, 0, 0, 0, time.UTC)
want.Listener = ListenerUnix

var cfg Server
err := Load(&cfg, File(f), Dirs(filepath.Join("testdata", "valid")))
Expand Down Expand Up @@ -590,17 +614,19 @@ func Test_fig_decodeMap(t *testing.T) {
"log_level": "debug",
"severity": "5",
"server": map[string]interface{}{
"ports": []int{443, 80},
"secure": 1,
"ports": []int{443, 80},
"secure": 1,
"listener_type": "tls",
},
}

var cfg struct {
Level string `fig:"log_level"`
Severity int `fig:"severity" validate:"required"`
Server struct {
Ports []string `fig:"ports" default:"[443]"`
Secure bool
Ports []string `fig:"ports" default:"[443]"`
Secure bool
Listener ListenerType `fig:"listener_type" default:"unix"`
} `fig:"server"`
}

Expand All @@ -623,6 +649,10 @@ func Test_fig_decodeMap(t *testing.T) {
if cfg.Server.Secure == false {
t.Error("cfg.Server.Secure == false")
}

if cfg.Server.Listener != ListenerTLS {
t.Errorf("cfg.Server.Listener: want: %d, got: %d", ListenerTLS, cfg.Server.Listener)
}
}

func Test_fig_processCfg(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion testdata/valid/server.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
host = "0.0.0.0"

[logger]
log_level = "debug"
log_level = "debug"
Loading