Skip to content

Commit

Permalink
Add support for custom unmarshaling of strings (#29)
Browse files Browse the repository at this point in the history
* Add support for custom unmarshaling of strings

Implemented an interface for custom unmarshaling of strings which allows users to define their own custom type unmarshaling methods. Updated fig_test.go and fig.go to reflect these changes. This update provides a flexible way for users to handle configs with custom types.

* Add "listener_type" to server configuration

Added "listener_type" field to the server configuration in JSON, YAML, and TOML files. The new field helps initialising the ListenerType field, which is now outside the Server struct in the fig test go file, with the "tcp" value.

* Refactor listener placement in fig test

The "listener_type" field has moved out of the Server struct and is now directly under the server configuration in the fig test go file. This change simplifies the initialization of the ListenerType field with the "tcp" value in JSON, YAML, and TOML configuration files.

* Update setDefaultValue function in fig.go

The function setDefaultValue in fig.go has been modified to call setValue unless the value satisfies the StringUnmarshaler interface.

* Add custom configuration and test files

A new custom configuration file, config.yaml, and a corresponding test file, custom_test.go, have been created. This is to serve as example for the custom UnmarshalString interface

* Add ListenerType in Server struct in fig_test.go

This is to make sure that `Test_fig_Load_UseStrict` won't fail due to the previously changed server.* files in `testdata/`

* Moved string unmarshaling from setDefaultValue to setValue

This is so that setEnv can benefit from that functionality as well

* Add validity check and error wrapping in setValue function

Added a check to ensure the reflect.Value is valid before attempting unmarshal. Also, wrapped error message for failed unmarshalling for clearer debugging. These changes will enhance error handling and debugging.

* Remove 'listener_type' attribute from server configuration files again, following kkyr's suggestion

* Repositioned UnmarshalString function and updated error logging in fig_test.go

The UnmarshalString function for the ListenerType has been moved higher up in the fig_test.go code. Updates to error logging formats have also been made for better readability, while unnecessary attributes in the server configuration have been removed.

* Update UnmarshalString function and enhance error logging in fig.go

The UnmarshalString function in fig.go was repositioned for efficiency. An error message for unexpected issues during string unmarshalling was also added.
  • Loading branch information
wneessen authored Dec 12, 2023
1 parent 9c6777e commit 17cd345
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 6 deletions.
7 changes: 7 additions & 0 deletions examples/custom/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
app:
environment: dev

server:
port: 443
read_timeout: 1m

73 changes: 73 additions & 0 deletions examples/custom/custom_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package custom

import (
"fmt"
"strings"

"github.com/kkyr/fig"
)

type ListenerType uint

const (
ListenerUnix ListenerType = iota
ListenerTCP
ListenerTLS
)

type Config struct {
App struct {
Environment string `fig:"environment" validate:"required"`
} `fig:"app"`
Server struct {
Host string `fig:"host" default:"0.0.0.0"`
Port int `fig:"port" default:"80"`
Listener ListenerType `fig:"listener_type" default:"tcp"`
} `fig:"server"`
}

func ExampleLoad() {
var cfg Config
err := fig.Load(&cfg)
if err != nil {
panic(err)
}

fmt.Println(cfg.App.Environment)
fmt.Println(cfg.Server.Host)
fmt.Println(cfg.Server.Port)
fmt.Println(cfg.Server.Listener)

// Output:
// dev
// 0.0.0.0
// 443
// tcp
}

func (l *ListenerType) UnmarshalString(v string) error {
switch strings.ToLower(v) {
case "unix":
*l = ListenerUnix
case "tcp":
*l = ListenerTCP
case "tls":
*l = ListenerTLS
default:
return fmt.Errorf("unknown listener type: %s", v)
}
return nil
}

func (l ListenerType) String() string {
switch l {
case ListenerUnix:
return "unix"
case ListenerTCP:
return "tcp"
case ListenerTLS:
return "tls"
default:
return "unknown"
}
}
84 changes: 83 additions & 1 deletion fig.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,42 @@ const (
DefaultTimeLayout = time.RFC3339
)

// StringUnmarshaler is an interface for custom unmarshaling of strings
//
// If a field with a local type asignment satisfies this interface, it allows the user
// to implment their own custom type unmarshaling method.
//
// Example:
//
// type ListenerType uint
//
// const (
// ListenerUnix ListenerType = iota
// ListenerTCP
// ListenerTLS
// )
//
// type Config struct {
// Listener ListenerType `fig:"listener_type" default:"unix"`
// }
//
// func (l *ListenerType) UnmarshalType(v string) error {
// switch strings.ToLower(v) {
// case "unix":
// *l = ListenerUnix
// case "tcp":
// *l = ListenerTCP
// case "tls":
// *l = ListenerTLS
// default:
// return fmt.Errorf("unknown listener type: %s", v)
// }
// return nil
// }
type StringUnmarshaler interface {
UnmarshalString(s string) error
}

// Load reads a configuration file and loads it into the given struct. The
// parameter `cfg` must be a pointer to a struct.
//
Expand Down Expand Up @@ -158,6 +194,7 @@ func (f *fig) decodeMap(m map[string]interface{}, result interface{}) error {
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToTimeHookFunc(f.timeLayout),
stringToRegexpHookFunc(),
stringToStringUnmarshalerHook(),
),
})
if err != nil {
Expand All @@ -183,6 +220,36 @@ func stringToRegexpHookFunc() mapstructure.DecodeHookFunc {
}
}

// stringToStringUnmarshalerHook returns a DecodeHookFunc that executes a custom method which
// satisfies the StringUnmarshaler interface on custom types.
func stringToStringUnmarshalerHook() mapstructure.DecodeHookFunc {
return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
if f.Kind() != reflect.String {
return data, nil
}

ds, ok := data.(string)
if !ok {
return data, nil
}

if reflect.PointerTo(t).Implements(reflect.TypeOf((*StringUnmarshaler)(nil)).Elem()) {
val := reflect.New(t).Interface()

if unmarshaler, ok := val.(StringUnmarshaler); ok {
err := unmarshaler.UnmarshalString(ds)
if err != nil {
return nil, err
}

return reflect.ValueOf(val).Elem().Interface(), nil
}
}

return data, nil
}
}

// processCfg processes a cfg struct after it has been loaded from
// the config file, by validating required fields and setting defaults
// where applicable.
Expand Down Expand Up @@ -257,9 +324,24 @@ func (f *fig) setDefaultValue(fv reflect.Value, val string) error {

// setValue sets fv to val. it attempts to convert val to the correct
// type based on the field's kind. if conversion fails an error is
// returned.
// returned. If fv satisfies the StringUnmarshaler interface it will
// execute the corresponding StringUnmarshaler.UnmarshalString method
// on the value.
// fv must be settable else this panics.
func (f *fig) setValue(fv reflect.Value, val string) error {

Check failure on line 331 in fig.go

View workflow job for this annotation

GitHub Actions / lint

cognitive complexity 35 of func `(*fig).setValue` is high (> 30) (gocognit)
if fv.IsValid() && reflect.PointerTo(fv.Type()).Implements(reflect.TypeOf((*StringUnmarshaler)(nil)).Elem()) {
vi := reflect.New(fv.Type()).Interface()
if unmarshaler, ok := vi.(StringUnmarshaler); ok {
err := unmarshaler.UnmarshalString(val)
if err != nil {
return fmt.Errorf("could not unmarshal string %q: %w", val, err)
}
fv.Set(reflect.ValueOf(vi).Elem())
return nil
}
return fmt.Errorf("unexpected error while trying to unmarshal string")
}

switch fv.Kind() {
case reflect.Ptr:
if fv.IsNil() {
Expand Down
38 changes: 34 additions & 4 deletions fig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,28 @@ type Item struct {
Path string `fig:"path" validate:"required"`
}

type ListenerType uint

const (
ListenerUnix ListenerType = iota
ListenerTCP
ListenerTLS
)

func (l *ListenerType) UnmarshalString(v string) error {
switch strings.ToLower(v) {
case "unix":
*l = ListenerUnix
case "tcp":
*l = ListenerTCP
case "tls":
*l = ListenerTLS
default:
return fmt.Errorf("unknown listener type: %s", v)
}
return nil
}

func validPodConfig() Pod {
var pod Pod

Expand Down Expand Up @@ -249,6 +271,7 @@ func Test_fig_Load_Defaults(t *testing.T) {
Application struct {
BuildDate time.Time `fig:"build_date" default:"2020-01-01T12:00:00Z"`
}
Listener ListenerType `fig:"listener_type" default:"unix"`
}

var want Server
Expand All @@ -259,6 +282,7 @@ func Test_fig_Load_Defaults(t *testing.T) {
want.Logger.Production = false
want.Logger.Metadata.Keys = []string{"ts"}
want.Application.BuildDate = time.Date(2020, 1, 1, 12, 0, 0, 0, time.UTC)
want.Listener = ListenerUnix

var cfg Server
err := Load(&cfg, File(f), Dirs(filepath.Join("testdata", "valid")))
Expand Down Expand Up @@ -590,17 +614,19 @@ func Test_fig_decodeMap(t *testing.T) {
"log_level": "debug",
"severity": "5",
"server": map[string]interface{}{
"ports": []int{443, 80},
"secure": 1,
"ports": []int{443, 80},
"secure": 1,
"listener_type": "tls",
},
}

var cfg struct {
Level string `fig:"log_level"`
Severity int `fig:"severity" validate:"required"`
Server struct {
Ports []string `fig:"ports" default:"[443]"`
Secure bool
Ports []string `fig:"ports" default:"[443]"`
Secure bool
Listener ListenerType `fig:"listener_type" default:"unix"`
} `fig:"server"`
}

Expand All @@ -623,6 +649,10 @@ func Test_fig_decodeMap(t *testing.T) {
if cfg.Server.Secure == false {
t.Error("cfg.Server.Secure == false")
}

if cfg.Server.Listener != ListenerTLS {
t.Errorf("cfg.Server.Listener: want: %d, got: %d", ListenerTLS, cfg.Server.Listener)
}
}

func Test_fig_processCfg(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion testdata/valid/server.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
host = "0.0.0.0"

[logger]
log_level = "debug"
log_level = "debug"

0 comments on commit 17cd345

Please sign in to comment.