Skip to content

Commit

Permalink
x/schema: test fallback conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
widmogrod committed May 11, 2024
1 parent ab496ad commit abe5d83
Show file tree
Hide file tree
Showing 2 changed files with 242 additions and 46 deletions.
141 changes: 95 additions & 46 deletions x/schema/go.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,15 @@ func ToGoPrimitive(x Schema) (any, error) {
}

func ToGoG[A any](x Schema) (res A, err error) {
//defer func() {
// if r := recover(); r != nil {
// if e, ok := r.(error); ok {
// err = fmt.Errorf("schema.ToGoG: panic recover; %w", e)
// } else {
// err = fmt.Errorf("schema.ToGoG: panic recover; %#v", e)
// }
// }
//}()
defer func() {
if r := recover(); r != nil {
if e, ok := r.(error); ok {
err = fmt.Errorf("schema.ToGoG: panic recover; %w", e)
} else {
err = fmt.Errorf("schema.ToGoG: panic recover; %#v", e)
}
}
}()

res = ToGo[A](x)
return
Expand All @@ -174,48 +174,77 @@ func ToGo[A any](x Schema) A {
switch any(result).(type) {
case int:
return any(int(y)).(A)
//case int8:
// return any(int8(y)).(A)
//case int16:
// return any(int16(y)).(A)
//case int32:
// return any(int32(y)).(A)
//case int64:
// return any(int64(y)).(A)
//case uint:
// return any(uint(y)).(A)
//case uint8:
// return any(uint8(y)).(A)
//case uint16:
// return any(uint16(y)).(A)
//case uint32:
// return any(uint32(y)).(A)
//case uint64:
// return any(uint64(y)).(A)
//case float32:
// return any(float32(y)).(A)
case int8:
return any(int8(y)).(A)
case int16:
return any(int16(y)).(A)
case int32:
return any(int32(y)).(A)
case int64:
return any(int64(y)).(A)
case uint:
return any(uint(y)).(A)
case uint8:
return any(uint8(y)).(A)
case uint16:
return any(uint16(y)).(A)
case uint32:
return any(uint32(y)).(A)
case uint64:
return any(uint64(y)).(A)
case float32:
return any(float32(y)).(A)
case float64:
return any(float64(y)).(A)
}
}

return value.(A)
}

v := reflect.TypeOf(new(A)).Elem()
original := shape.MkRefNameFromReflect(v)

s, found := shape.LookupShape(original)
if !found {
panic(fmt.Errorf("schema.FromGo: shape.RefName not found %s; %w", v.String(), shape.ErrShapeNotFound))
if found {
s = shape.IndexWith(s, original)

value, err := ToGoReflect(s, x, v)
if err != nil {
panic(fmt.Errorf("schema.ToGo: %w", err))
}

return value.Interface().(A)
}

s = shape.IndexWith(s, original)
str, ok := x.(*String)
if ok {
// to properly fallback, type needs to have MarshalJSON/UnmarshalJSON methods
res := unmarshalFallback(reflect.ValueOf(new(A)), str, *new(A))
val := res.(*A)
return *val
}

panic(fmt.Errorf("schema.ToGo: cannot build type %T", *new(A)))
}

value, err := ToGoReflect(s, x, v)
if err != nil {
panic(fmt.Errorf("schema.ToGo: %w", err))
func unmarshalFallback(ref reflect.Value, str *String, typ any) any {
marshal := ref.MethodByName("MarshalJSON")
unmarshal := ref.MethodByName("UnmarshalJSON")
if marshal.IsZero() && unmarshal.IsZero() {
panic(fmt.Errorf("schema.ToGo: shape.RefName not found for %T", typ))
}

return value.Interface().(A)
res := unmarshal.Call([]reflect.Value{reflect.ValueOf([]byte(*str))})
if len(res) != 1 {
panic(fmt.Errorf("schema.ToGo: %T.UnmarshalJSON() expected 1 return value, got %d", typ, len(res)))
}

if res[0].IsZero() {
return ref.Interface()
}

panic(fmt.Errorf("schema.ToGo: %T.UnmarshalJSON() error: %w", typ, res[0].Interface().(error)))
}

func FromGo[A any](x A) Schema {
Expand All @@ -224,11 +253,30 @@ func FromGo[A any](x A) Schema {
}

s, found := shape.LookupShapeReflectAndIndex[A]()
if !found {
panic(fmt.Errorf("schema.FromGo: shape.RefName not found for %T; %w", *new(A), shape.ErrShapeNotFound))
if found {
return FromGoReflect(s, reflect.ValueOf(x))
}

return marshalFallback(reflect.ValueOf(x), *new(A))
}

func marshalFallback(ref reflect.Value, typ any) Schema {
marshal := ref.MethodByName("MarshalJSON")
unmarshal := ref.MethodByName("UnmarshalJSON")
if marshal.IsZero() && unmarshal.IsZero() {
panic(fmt.Errorf("schema.FromGo: shape.RefName not found for %T; %w", typ, shape.ErrShapeNotFound))
}

res := marshal.Call(nil)
if len(res) != 2 {
panic(fmt.Errorf("schema.FromGo: %T.MarshalJSON() expected 2 return values, got %d", typ, len(res)))
}

if res[1].IsZero() {
return MkString(string(res[0].Bytes()))
}

return FromGoReflect(s, reflect.ValueOf(x))
panic(fmt.Errorf("schema.FromGo: %T.MarshalJSON() error: %w", typ, res[1].Interface().(error)))
}

func FromGoReflect(xschema shape.Shape, yreflect reflect.Value) Schema {
Expand All @@ -239,15 +287,16 @@ func FromGoReflect(xschema shape.Shape, yreflect reflect.Value) Schema {
},
func(x *shape.RefName) Schema {
y, found := shape.LookupShape(x)
if !found {
panic(fmt.Errorf("schema.FromGoReflect: shape.RefName not found %s; %w",
shape.ToGoTypeName(x, shape.WithPkgImportName()),
shape.ErrShapeNotFound))
}
if found {
y = shape.IndexWith(y, x)

y = shape.IndexWith(y, x)
return FromGoReflect(y, yreflect)
}

return FromGoReflect(y, yreflect)
// Convert types that are not registered in shape registry, or don't have schema mapping, like time.Time, etc.
// to String, but only when they have MarshalJSON/UnmarshalJSON methods.
// Because JSON is quite popular format, this should cover most of the cases.
return marshalFallback(yreflect, shape.ToGoTypeName(x, shape.WithPkgImportName()))
},
func(x *shape.PointerLike) Schema {
if yreflect.IsNil() {
Expand Down
147 changes: 147 additions & 0 deletions x/schema/go_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package schema

import (
"encoding/json"
"github.com/google/go-cmp/cmp"
"testing"
"testing/quick"
)

func TestNative(t *testing.T) {
t.Run("int", func(t *testing.T) {
assertTypeConversion(t, 1)
})
t.Run("int8", func(t *testing.T) {
if err := quick.Check(func(x int8) bool {
assertTypeConversion(t, x)
return true
}, nil); err != nil {
t.Error(err)
}
})
t.Run("int16", func(t *testing.T) {
if err := quick.Check(func(x int16) bool {
assertTypeConversion(t, x)
return true
}, nil); err != nil {
t.Error(err)
}
})
t.Run("int32", func(t *testing.T) {
if err := quick.Check(func(x int32) bool {
assertTypeConversion(t, x)
return true
}, nil); err != nil {
t.Error(err)
}
})
t.Run("int64", func(t *testing.T) {
t.Skip("boundary conversion issue because *Number is float64")
if err := quick.Check(func(x int64) bool {
assertTypeConversion(t, x)
return true
}, nil); err != nil {
t.Error(err)
}
})
t.Run("uint", func(t *testing.T) {
t.Skip("boundary conversion issue because *Number is float64")
if err := quick.Check(func(x uint) bool {
assertTypeConversion(t, x)
return true
}, nil); err != nil {
t.Error(err)
}
})

t.Run("uint8", func(t *testing.T) {
if err := quick.Check(func(x uint8) bool {
assertTypeConversion(t, x)
return true
}, nil); err != nil {
t.Error(err)
}
})
t.Run("uint16", func(t *testing.T) {
if err := quick.Check(func(x uint16) bool {
assertTypeConversion(t, x)
return true
}, nil); err != nil {
t.Error(err)
}

})
t.Run("uint32", func(t *testing.T) {
if err := quick.Check(func(x uint32) bool {
assertTypeConversion(t, x)
return true
}, nil); err != nil {
t.Error(err)
}
})
t.Run("uint64", func(t *testing.T) {
t.Skip("boundary conversion issue because *Number is float64")
if err := quick.Check(func(x uint64) bool {
assertTypeConversion(t, x)
return true
}, nil); err != nil {
t.Error(err)
}
})
t.Run("float32", func(t *testing.T) {
if err := quick.Check(func(x float32) bool {
assertTypeConversion(t, x)
return true
}, nil); err != nil {
t.Error(err)
}
})
t.Run("float64", func(t *testing.T) {
if err := quick.Check(func(x float64) bool {
assertTypeConversion(t, x)
return true
}, nil); err != nil {
t.Error(err)
}
})
t.Run("string", func(t *testing.T) {
if err := quick.Check(func(x string) bool {
assertTypeConversion(t, x)
return true
}, nil); err != nil {
t.Error(err)
}
})
t.Run("[]byte", func(t *testing.T) {
if err := quick.Check(func(x []byte) bool {
assertTypeConversion(t, x)
return true
}, nil); err != nil {
t.Error(err)
}
})
}

func TestNonNative(t *testing.T) {
t.Run("json.RawMessage", func(t *testing.T) {
assertTypeConversion(t, json.RawMessage(`{"hello": "world"}`))
})
t.Run("time.Time", func(t *testing.T) {
assertTypeConversion(t, "2021-01-01T00:00:00Z")
})
}

func assertTypeConversion[A any](t *testing.T, value A) {
expected := value
t.Logf("expected = %+#v", expected)

schemed := FromGo[A](expected)
t.Logf(" FromGo = %+#v", schemed)

result := ToGo[A](schemed)
t.Logf(" ToGo = %+#v", result)

if diff := cmp.Diff(expected, result); diff != "" {
t.Error(diff)
}
}

0 comments on commit abe5d83

Please sign in to comment.