diff --git a/decode.go b/decode.go index 79117cae..18f47f20 100644 --- a/decode.go +++ b/decode.go @@ -24,42 +24,44 @@ import ( // Decoder reads and decodes YAML values from an input stream. type Decoder struct { - reader io.Reader - referenceReaders []io.Reader - anchorNodeMap map[string]ast.Node - anchorValueMap map[string]reflect.Value - customUnmarshalerMap map[reflect.Type]func(interface{}, []byte) error - toCommentMap CommentMap - opts []DecodeOption - referenceFiles []string - referenceDirs []string - isRecursiveDir bool - isResolvedReference bool - validator StructValidator - disallowUnknownField bool - disallowDuplicateKey bool - useOrderedMap bool - useJSONUnmarshaler bool - parsedFile *ast.File - streamIndex int + reader io.Reader + referenceReaders []io.Reader + anchorNodeMap map[string]ast.Node + anchorValueMap map[string]reflect.Value + customUnmarshalerMap map[reflect.Type]func(interface{}, []byte) error + customInterfaceUnmarshalerMap map[reflect.Type]func(interface{}, func(interface{}) error) error + toCommentMap CommentMap + opts []DecodeOption + referenceFiles []string + referenceDirs []string + isRecursiveDir bool + isResolvedReference bool + validator StructValidator + disallowUnknownField bool + disallowDuplicateKey bool + useOrderedMap bool + useJSONUnmarshaler bool + parsedFile *ast.File + streamIndex int } // NewDecoder returns a new decoder that reads from r. func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder { return &Decoder{ - reader: r, - anchorNodeMap: map[string]ast.Node{}, - anchorValueMap: map[string]reflect.Value{}, - customUnmarshalerMap: map[reflect.Type]func(interface{}, []byte) error{}, - opts: opts, - referenceReaders: []io.Reader{}, - referenceFiles: []string{}, - referenceDirs: []string{}, - isRecursiveDir: false, - isResolvedReference: false, - disallowUnknownField: false, - disallowDuplicateKey: false, - useOrderedMap: false, + reader: r, + anchorNodeMap: map[string]ast.Node{}, + anchorValueMap: map[string]reflect.Value{}, + customUnmarshalerMap: map[reflect.Type]func(interface{}, []byte) error{}, + customInterfaceUnmarshalerMap: map[reflect.Type]func(interface{}, func(interface{}) error) error{}, + opts: opts, + referenceReaders: []io.Reader{}, + referenceFiles: []string{}, + referenceDirs: []string{}, + isRecursiveDir: false, + isResolvedReference: false, + disallowUnknownField: false, + disallowDuplicateKey: false, + useOrderedMap: false, } } @@ -656,7 +658,6 @@ func (d *Decoder) unmarshalerFromCustomUnmarshalerMap(t reflect.Type) (func(inte if unmarshaler, exists := d.customUnmarshalerMap[t]; exists { return unmarshaler, exists } - globalCustomUnmarshalerMu.Lock() defer globalCustomUnmarshalerMu.Unlock() if unmarshaler, exists := globalCustomUnmarshalerMap[t]; exists { @@ -665,11 +666,40 @@ func (d *Decoder) unmarshalerFromCustomUnmarshalerMap(t reflect.Type) (func(inte return nil, false } +func (d *Decoder) existsTypeInCustomInterfaceUnmarshalerMap(t reflect.Type) bool { + if _, exists := d.customInterfaceUnmarshalerMap[t]; exists { + return true + } + + globalCustomInterfaceUnmarshalerMu.Lock() + defer globalCustomInterfaceUnmarshalerMu.Unlock() + if _, exists := globalCustomInterfaceUnmarshalerMap[t]; exists { + return true + } + return false +} + +func (d *Decoder) unmarshalerFromCustomInterfaceUnmarshalerMap(t reflect.Type) (func(interface{}, func(interface{}) error) error, bool) { + if unmarshaler, exists := d.customInterfaceUnmarshalerMap[t]; exists { + return unmarshaler, exists + } + + globalCustomInterfaceUnmarshalerMu.Lock() + defer globalCustomInterfaceUnmarshalerMu.Unlock() + if unmarshaler, exists := globalCustomInterfaceUnmarshalerMap[t]; exists { + return unmarshaler, exists + } + return nil, false +} + func (d *Decoder) canDecodeByUnmarshaler(dst reflect.Value) bool { ptrValue := dst.Addr() if d.existsTypeInCustomUnmarshalerMap(ptrValue.Type()) { return true } + if d.existsTypeInCustomInterfaceUnmarshalerMap(ptrValue.Type()) { + return true + } iface := ptrValue.Interface() switch iface.(type) { case BytesUnmarshalerContext: @@ -704,6 +734,21 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr } return nil } + if unmarshaler, exists := d.unmarshalerFromCustomInterfaceUnmarshalerMap(ptrValue.Type()); exists { + if err := unmarshaler(ptrValue.Interface(), func(v interface{}) error { + rv := reflect.ValueOf(v) + if rv.Type().Kind() != reflect.Ptr { + return errors.ErrDecodeRequiredPointerType + } + if err := d.decodeValue(ctx, rv.Elem(), src); err != nil { + return errors.Wrapf(err, "failed to decode value") + } + return nil + }); err != nil { + return errors.Wrapf(err, "failed to UnmarshalYAML") + } + return nil + } iface := ptrValue.Interface() if unmarshaler, ok := iface.(BytesUnmarshalerContext); ok { diff --git a/decode_test.go b/decode_test.go index cabfd33c..09629e74 100644 --- a/decode_test.go +++ b/decode_test.go @@ -1940,6 +1940,113 @@ func TestDecoder_CustomUnmarshaler(t *testing.T) { t.Fatalf("failed to switch to custom unmarshaler. got: %q", v.Foo) } }) + + t.Run("override interface type", func(t *testing.T) { + type I interface{} + type T struct { + Foo string `yaml:"foo"` + } + var i I + src := []byte(`foo: "bar"`) + if err := yaml.UnmarshalWithOptions(src, &i, yaml.CustomUnmarshaler[I](func(dst *I, b []byte) error { + var v T + if err := yaml.Unmarshal(b, &v); err != nil { + t.Fatal(err) + } + if v.Foo != "bar" { + t.Fatalf("failed to use unmarshal function. got %q", v.Foo) + } + *dst = &v + return nil + })); err != nil { + t.Fatal(err) + } + if v, ok := i.(*T); ok { + if v.Foo != "bar" { + t.Fatalf("failed to decode with custom interface unmarshaler. got: %q", v.Foo) + } + } else { + t.Fatalf("failed to switch to custom interface unmarshaler.") + } + }) +} + +func TestDecoder_CustomInterfaceUnmarshaler(t *testing.T) { + t.Run("override struct type", func(t *testing.T) { + type T struct { + Foo string `yaml:"foo"` + } + src := []byte(`foo: "bar"`) + var v T + if err := yaml.UnmarshalWithOptions(src, &v, yaml.CustomInterfaceUnmarshaler[T](func(dst *T, f func(interface{}) error) error { + var m map[string]string + if err := f(&m); err != nil { + t.Fatal(err) + } + if m["foo"] != "bar" { + t.Fatalf("failed to use unmarshal function. got %q", m["foo"]) + } + dst.Foo = "bazbaz" // assign another value to target + return nil + })); err != nil { + t.Fatal(err) + } + if v.Foo != "bazbaz" { + t.Fatalf("failed to switch to custom interface unmarshaler. got: %v", v.Foo) + } + }) + t.Run("override bytes type", func(t *testing.T) { + type T struct { + Foo []byte `yaml:"foo"` + } + src := []byte(`foo: "bar"`) + var v T + if err := yaml.UnmarshalWithOptions(src, &v, yaml.CustomInterfaceUnmarshaler[[]byte](func(dst *[]byte, f func(interface{}) error) error { + var str string + if err := f(&str); err != nil { + t.Fatal(err) + } + if str != "bar" { + t.Fatalf("failed to use unmarshal function. got %q", str) + } + *dst = []byte("bazbaz") + return nil + })); err != nil { + t.Fatal(err) + } + if !bytes.Equal(v.Foo, []byte("bazbaz")) { + t.Fatalf("failed to switch to custom interface unmarshaler. got: %q", v.Foo) + } + }) + + t.Run("override interface type", func(t *testing.T) { + type I interface{} + type T struct { + Foo string `yaml:"foo"` + } + var i I + src := []byte(`foo: "bar"`) + if err := yaml.UnmarshalWithOptions(src, &i, yaml.CustomInterfaceUnmarshaler[I](func(dst *I, f func(interface{}) error) error { + var v T + if err := f(&v); err != nil { + t.Fatal(err) + } + if v.Foo != "bar" { + t.Fatalf("failed to use unmarshal function. got %q", v.Foo) + } + *dst = &v + return nil + })); err != nil { + t.Fatal(err) + } + if v, ok := i.(*T); ok { + if v.Foo != "bar" { + t.Fatalf("failed to decode with custom interface unmarshaler. got: %q", v.Foo) + } + } else { + t.Fatalf("failed to switch to custom interface unmarshaler.") + } + }) } type unmarshalContext struct { diff --git a/encode.go b/encode.go index 3b9b2981..e8cf6820 100644 --- a/encode.go +++ b/encode.go @@ -27,20 +27,21 @@ const ( // Encoder writes YAML values to an output stream. type Encoder struct { - writer io.Writer - opts []EncodeOption - indent int - indentSequence bool - singleQuote bool - isFlowStyle bool - isJSONStyle bool - useJSONMarshaler bool - anchorCallback func(*ast.AnchorNode, interface{}) error - anchorPtrToNameMap map[uintptr]string - customMarshalerMap map[reflect.Type]func(interface{}) ([]byte, error) - useLiteralStyleIfMultiline bool - commentMap map[*Path][]*Comment - written bool + writer io.Writer + opts []EncodeOption + indent int + indentSequence bool + singleQuote bool + isFlowStyle bool + isJSONStyle bool + useJSONMarshaler bool + anchorCallback func(*ast.AnchorNode, interface{}) error + anchorPtrToNameMap map[uintptr]string + customMarshalerMap map[reflect.Type]func(interface{}) ([]byte, error) + customInterfaceMarshalerMap map[reflect.Type]func(interface{}) (interface{}, error) + useLiteralStyleIfMultiline bool + commentMap map[*Path][]*Comment + written bool line int column int @@ -53,14 +54,15 @@ type Encoder struct { // The Encoder should be closed after use to flush all data to w. func NewEncoder(w io.Writer, opts ...EncodeOption) *Encoder { return &Encoder{ - writer: w, - opts: opts, - indent: DefaultIndentSpaces, - anchorPtrToNameMap: map[uintptr]string{}, - customMarshalerMap: map[reflect.Type]func(interface{}) ([]byte, error){}, - line: 1, - column: 1, - offset: 0, + writer: w, + opts: opts, + indent: DefaultIndentSpaces, + anchorPtrToNameMap: map[uintptr]string{}, + customMarshalerMap: map[reflect.Type]func(interface{}) ([]byte, error){}, + customInterfaceMarshalerMap: map[reflect.Type]func(interface{}) (interface{}, error){}, + line: 1, + column: 1, + offset: 0, } } @@ -301,6 +303,32 @@ func (e *Encoder) marshalerFromCustomMarshalerMap(t reflect.Type) (func(interfac return nil, false } +func (e *Encoder) existsTypeInCustomInterfaceMarshalerMap(t reflect.Type) bool { + if _, exists := e.customInterfaceMarshalerMap[t]; exists { + return true + } + + globalCustomInterfaceMarshalerMu.Lock() + defer globalCustomInterfaceMarshalerMu.Unlock() + if _, exists := globalCustomInterfaceMarshalerMap[t]; exists { + return true + } + return false +} + +func (e *Encoder) marshalerFromCustomInterfaceMarshalerMap(t reflect.Type) (func(interface{}) (interface{}, error), bool) { + if marshaler, exists := e.customInterfaceMarshalerMap[t]; exists { + return marshaler, exists + } + + globalCustomInterfaceMarshalerMu.Lock() + defer globalCustomInterfaceMarshalerMu.Unlock() + if marshaler, exists := globalCustomInterfaceMarshalerMap[t]; exists { + return marshaler, exists + } + return nil, false +} + func (e *Encoder) canEncodeByMarshaler(v reflect.Value) bool { if !v.CanInterface() { return false @@ -308,6 +336,9 @@ func (e *Encoder) canEncodeByMarshaler(v reflect.Value) bool { if e.existsTypeInCustomMarshalerMap(v.Type()) { return true } + if e.existsTypeInCustomInterfaceMarshalerMap(v.Type()) { + return true + } iface := v.Interface() switch iface.(type) { case BytesMarshalerContext: @@ -344,6 +375,13 @@ func (e *Encoder) encodeByMarshaler(ctx context.Context, v reflect.Value, column } return node, nil } + if marshaler, exists := e.marshalerFromCustomInterfaceMarshalerMap(v.Type()); exists { + marshalV, err := marshaler(iface) + if err != nil { + return nil, errors.Wrapf(err, "failed to MarshalYAML") + } + return e.encodeValue(ctx, reflect.ValueOf(marshalV), column) + } if marshaler, ok := iface.(BytesMarshalerContext); ok { doc, err := marshaler.MarshalYAML(ctx) diff --git a/encode_test.go b/encode_test.go index 3ff6f1c1..70a7be5f 100644 --- a/encode_test.go +++ b/encode_test.go @@ -1251,6 +1251,40 @@ func TestEncoder_CustomMarshaler(t *testing.T) { }) } +func TestEncoder_CustomInterfaceMarshaler(t *testing.T) { + t.Run("override struct type", func(t *testing.T) { + type T struct { + Foo string `yaml:"foo"` + } + b, err := yaml.MarshalWithOptions(&T{Foo: "bar"}, yaml.CustomInterfaceMarshaler[T](func(v T) (interface{}, error) { + return "override", nil + })) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(b, []byte("override\n")) { + t.Fatalf("failed to switch to custom marshaler. got: %q", b) + } + }) + t.Run("override bytes type", func(t *testing.T) { + type T struct { + Foo []byte `yaml:"foo"` + } + b, err := yaml.MarshalWithOptions(&T{Foo: []byte("bar")}, yaml.CustomInterfaceMarshaler[[]byte](func(v []byte) (interface{}, error) { + if !bytes.Equal(v, []byte("bar")) { + t.Fatalf("failed to get src buffer: %q", v) + } + return "override", nil + })) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(b, []byte("foo: override\n")) { + t.Fatalf("failed to switch to custom marshaler. got: %q", b) + } + }) +} + func TestEncoder_MultipleDocuments(t *testing.T) { var buf bytes.Buffer enc := yaml.NewEncoder(&buf) diff --git a/option.go b/option.go index eab5d43a..7c4370dd 100644 --- a/option.go +++ b/option.go @@ -109,6 +109,20 @@ func CustomUnmarshaler[T any](unmarshaler func(*T, []byte) error) DecodeOption { } } +// CustomInterfaceUnmarshaler overrides any decoding process for the type specified in generics. +// +// NOTE: If RegisterCustomUnmarshaler and CustomUnmarshaler of DecodeOption are specified for the same type, +// the CustomUnmarshaler specified in DecodeOption takes precedence. +func CustomInterfaceUnmarshaler[T any](unmarshaler func(*T, func(interface{}) error) error) DecodeOption { + return func(d *Decoder) error { + var typ *T + d.customInterfaceUnmarshalerMap[reflect.TypeOf(typ)] = func(v interface{}, f func(interface{}) error) error { + return unmarshaler(v.(*T), f) + } + return nil + } +} + // EncodeOption functional option type for Encoder type EncodeOption func(e *Encoder) error @@ -195,6 +209,21 @@ func CustomMarshaler[T any](marshaler func(T) ([]byte, error)) EncodeOption { } } +// CustomInterfaceMarshaler overrides any encoding process for the type specified in generics. +// +// NOTE: If type T implements MarshalYAML for pointer receiver, the type specified in CustomMarshaler must be *T. +// If RegisterCustomMarshaler and CustomMarshaler of EncodeOption are specified for the same type, +// the CustomMarshaler specified in EncodeOption takes precedence. +func CustomInterfaceMarshaler[T any](marshaler func(T) (interface{}, error)) EncodeOption { + return func(e *Encoder) error { + var typ T + e.customInterfaceMarshalerMap[reflect.TypeOf(typ)] = func(v interface{}) (interface{}, error) { + return marshaler(v.(T)) + } + return nil + } +} + // CommentPosition type of the position for comment. type CommentPosition int diff --git a/yaml.go b/yaml.go index 25b1056f..0cd2ec59 100644 --- a/yaml.go +++ b/yaml.go @@ -250,10 +250,14 @@ func JSONToYAML(bytes []byte) ([]byte, error) { } var ( - globalCustomMarshalerMu sync.Mutex - globalCustomUnmarshalerMu sync.Mutex - globalCustomMarshalerMap = map[reflect.Type]func(interface{}) ([]byte, error){} - globalCustomUnmarshalerMap = map[reflect.Type]func(interface{}, []byte) error{} + globalCustomMarshalerMu sync.Mutex + globalCustomUnmarshalerMu sync.Mutex + globalCustomInterfaceMarshalerMu sync.Mutex + globalCustomInterfaceUnmarshalerMu sync.Mutex + globalCustomMarshalerMap = map[reflect.Type]func(interface{}) ([]byte, error){} + globalCustomUnmarshalerMap = map[reflect.Type]func(interface{}, []byte) error{} + globalCustomInterfaceMarshalerMap = map[reflect.Type]func(interface{}) (interface{}, error){} + globalCustomInterfaceUnmarshalerMap = map[reflect.Type]func(interface{}, func(interface{}) error) error{} ) // RegisterCustomMarshaler overrides any encoding process for the type specified in generics. @@ -286,3 +290,34 @@ func RegisterCustomUnmarshaler[T any](unmarshaler func(*T, []byte) error) { return unmarshaler(v.(*T), b) } } + +// RegisterCustomInterfaceMarshaler overrides any encoding process for the type specified in generics. +// If you want to switch the behavior for each encoder, use `CustomMarshaler` defined as EncodeOption. +// +// NOTE: If type T implements MarshalYAML for pointer receiver, the type specified in RegisterCustomMarshaler must be *T. +// If RegisterCustomMarshaler and CustomMarshaler of EncodeOption are specified for the same type, +// the CustomMarshaler specified in EncodeOption takes precedence. +func RegisterCustomInterfaceMarshaler[T any](marshaler func(T) (interface{}, error)) { + globalCustomInterfaceMarshalerMu.Lock() + defer globalCustomInterfaceMarshalerMu.Unlock() + + var typ T + globalCustomInterfaceMarshalerMap[reflect.TypeOf(typ)] = func(v interface{}) (interface{}, error) { + return marshaler(v.(T)) + } +} + +// RegisterCustomInterfaceUnmarshaler overrides any decoding process for the type specified in generics. +// If you want to switch the behavior for each decoder, use `CustomUnmarshaler` defined as DecodeOption. +// +// NOTE: If RegisterCustomUnmarshaler and CustomUnmarshaler of DecodeOption are specified for the same type, +// the CustomUnmarshaler specified in DecodeOption takes precedence. +func RegisterCustomInterfaceUnmarshaler[T any](unmarshaler func(*T, func(interface{}) error) error) { + globalCustomInterfaceUnmarshalerMu.Lock() + defer globalCustomInterfaceUnmarshalerMu.Unlock() + + var typ *T + globalCustomInterfaceUnmarshalerMap[reflect.TypeOf(typ)] = func(v interface{}, f func(interface{}) error) error { + return unmarshaler(v.(*T), f) + } +} diff --git a/yaml_test.go b/yaml_test.go index 4446d31a..8da679ee 100644 --- a/yaml_test.go +++ b/yaml_test.go @@ -1282,3 +1282,44 @@ func TestRegisterCustomUnmarshaler(t *testing.T) { t.Fatalf("failed to decode. got %q", v.Foo) } } + +func TestRegisterCustomInterfaceMarshaler(t *testing.T) { + type T struct { + Foo []byte `yaml:"foo"` + } + yaml.RegisterCustomInterfaceMarshaler[T](func(_ T) (interface{}, error) { + return "override", nil + }) + b, err := yaml.Marshal(&T{Foo: []byte("bar")}) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(b, []byte("override\n")) { + t.Fatalf("failed to register custom interface marshaler. got: %q", b) + } +} + +func TestRegisterCustomInterfaceUnmarshaler(t *testing.T) { + type T struct { + Foo []byte `yaml:"foo"` + } + yaml.RegisterCustomInterfaceUnmarshaler[T](func(v *T, unmarshaler func(interface{}) error) error { + m := map[string]string{} + if err := unmarshaler(&m); err != nil { + return err + } + if m["foo"] != "bar" { + t.Fatalf("failed to use unmarshal function. got %q", m["foo"]) + return nil + } + v.Foo = []byte("override") + return nil + }) + var v T + if err := yaml.Unmarshal([]byte(`foo: "bar"`), &v); err != nil { + t.Fatal(err) + } + if !bytes.Equal(v.Foo, []byte("override")) { + t.Fatalf("failed to decode. got %q", v.Foo) + } +}