diff --git a/dump_test.go b/dump_test.go index 835e5e2..72443f3 100644 --- a/dump_test.go +++ b/dump_test.go @@ -949,3 +949,21 @@ TT.DY: d ` assert.Equal(t, expected, out.String()) } + +type TextMarshal string + +func (t TextMarshal) MarshallText() (string, error) { + return "'" + string(t) + "'", nil +} + +func (t TextMarshal) Type() string { + return "MyType" +} + +func TestImplementsInterface(t *testing.T) { + out := &bytes.Buffer{} + + require.NoError(t, dump.NewEncoder(out).Fdump(TextMarshal("foo"))) + + assert.Equal(t, ": 'foo'\n", out.String()) +} diff --git a/encoder.go b/encoder.go index 9789a99..437955c 100644 --- a/encoder.go +++ b/encoder.go @@ -30,6 +30,11 @@ type Encoder struct { writer io.Writer } +type TextMarshaler interface { + MarshallText() (string, error) + Type() string +} + // NewDefaultEncoder instanciate a go-dump encoder func NewDefaultEncoder() *Encoder { return NewEncoder(new(bytes.Buffer)) @@ -106,8 +111,25 @@ func (e *Encoder) fdumpInterface(w map[string]interface{}, i interface{}, roots w[prefix+k] = "" return nil } - switch f.Kind() { - case reflect.Struct: + + marshaler, convertible := i.(TextMarshaler) + + switch { + case convertible: + value, err := marshaler.MarshallText() + if err != nil { + return err + } + + if e.ExtraFields.Type { + nodeType := append(roots, "__Type__") + nodeTypeFormatted := strings.Join(sliceFormat(nodeType, e.Formatters), e.Separator) + w[nodeTypeFormatted] = marshaler.Type() + } + + return e.fdumpInterface(w, value, roots) + + case f.Kind() == reflect.Struct: if e.ExtraFields.Type { nodeType := append(roots, "__Type__") nodeTypeFormatted := strings.Join(sliceFormat(nodeType, e.Formatters), e.Separator) @@ -120,12 +142,12 @@ func (e *Encoder) fdumpInterface(w map[string]interface{}, i interface{}, roots if err := e.fdumpStruct(w, f, croots); err != nil { return err } - case reflect.Array, reflect.Slice: + case f.Kind() == reflect.Array, f.Kind() == reflect.Slice: if err := e.fDumpArray(w, i, roots); err != nil { return err } return nil - case reflect.Map: + case f.Kind() == reflect.Map: if e.ExtraFields.Type { nodeType := append(roots, "__Type__") nodeTypeFormatted := strings.Join(sliceFormat(nodeType, e.Formatters), e.Separator)