diff --git a/marshal.go b/marshal.go index ba53d6350..c40c32a25 100644 --- a/marshal.go +++ b/marshal.go @@ -23,6 +23,7 @@ import ( "github.com/gocql/gocql/serialization/bigint" "github.com/gocql/gocql/serialization/counter" "github.com/gocql/gocql/serialization/cqlint" + "github.com/gocql/gocql/serialization/float" "github.com/gocql/gocql/serialization/smallint" "github.com/gocql/gocql/serialization/tinyint" ) @@ -150,7 +151,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { case TypeCounter: return marshalCounter(value) case TypeFloat: - return marshalFloat(info, value) + return marshalFloat(value) case TypeDouble: return marshalDouble(info, value) case TypeDecimal: @@ -256,7 +257,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error { case TypeTinyInt: return unmarshalTinyInt(data, value) case TypeFloat: - return unmarshalFloat(info, data, value) + return unmarshalFloat(data, value) case TypeDouble: return unmarshalDouble(info, data, value) case TypeDecimal: @@ -899,47 +900,19 @@ func decBool(v []byte) bool { return v[0] != 0 } -func marshalFloat(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case float32: - return encInt(int32(math.Float32bits(v))), nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Float32: - return encInt(int32(math.Float32bits(float32(rv.Float())))), nil +func marshalFloat(value interface{}) ([]byte, error) { + data, err := float.Marshal(value) + if err != nil { + return nil, wrapMarshalError(err, "marshal error") } - return nil, marshalErrorf("can not marshal %T into %s", value, info) + return data, nil } -func unmarshalFloat(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) - case *float32: - *v = math.Float32frombits(uint32(decInt(data))) - return nil - } - rv := reflect.ValueOf(value) - if rv.Kind() != reflect.Ptr { - return unmarshalErrorf("can not unmarshal into non-pointer %T", value) - } - rv = rv.Elem() - switch rv.Type().Kind() { - case reflect.Float32: - rv.SetFloat(float64(math.Float32frombits(uint32(decInt(data))))) - return nil +func unmarshalFloat(data []byte, value interface{}) error { + if err := float.Unmarshal(data, value); err != nil { + return wrapUnmarshalError(err, "unmarshal error") } - return unmarshalErrorf("can not unmarshal %s into %T", info, value) + return nil } func marshalDouble(info TypeInfo, value interface{}) ([]byte, error) { diff --git a/marshal_test.go b/marshal_test.go index 86f629510..cf9ca4b24 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -118,13 +118,6 @@ var marshalTests = []struct { nil, nil, }, - { - NativeType{proto: 2, typ: TypeFloat}, - []byte("\x40\x49\x0f\xdb"), - float32(3.14159265), - nil, - nil, - }, { NativeType{proto: 2, typ: TypeDouble}, []byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1"), @@ -542,23 +535,6 @@ var marshalTests = []struct { nil, nil, }, - { - NativeType{proto: 2, typ: TypeFloat}, - []byte("\x40\x49\x0f\xdb"), - func() *float32 { - f := float32(3.14159265) - return &f - }(), - nil, - nil, - }, - { - NativeType{proto: 2, typ: TypeFloat}, - []byte(nil), - (*float32)(nil), - nil, - nil, - }, { NativeType{proto: 2, typ: TypeDouble}, []byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1"), diff --git a/serialization/float/marshal.go b/serialization/float/marshal.go new file mode 100644 index 000000000..91f4141b4 --- /dev/null +++ b/serialization/float/marshal.go @@ -0,0 +1,24 @@ +package float + +import ( + "reflect" +) + +func Marshal(value interface{}) ([]byte, error) { + switch v := value.(type) { + case nil: + return nil, nil + case float32: + return EncFloat32(v) + case *float32: + return EncFloat32R(v) + default: + // Custom types (type MyFloat float32) can be serialized only via `reflect` package. + // Later, when generic-based serialization is introduced we can do that via generics. + rv := reflect.TypeOf(value) + if rv.Kind() != reflect.Ptr { + return EncReflect(reflect.ValueOf(v)) + } + return EncReflectR(reflect.ValueOf(v)) + } +} diff --git a/serialization/float/marshal_utils.go b/serialization/float/marshal_utils.go new file mode 100644 index 000000000..ce20fcf1b --- /dev/null +++ b/serialization/float/marshal_utils.go @@ -0,0 +1,54 @@ +package float + +import ( + "fmt" + "reflect" + "unsafe" +) + +func EncFloat32(v float32) ([]byte, error) { + return encFloat32(v), nil +} + +func EncFloat32R(v *float32) ([]byte, error) { + if v == nil { + return nil, nil + } + return encFloat32R(v), nil +} + +func EncReflect(v reflect.Value) ([]byte, error) { + switch v.Kind() { + case reflect.Float32: + return encFloat32(float32(v.Float())), nil + default: + return nil, fmt.Errorf("failed to marshal float: unsupported value type (%T)(%[1]v)", v.Interface()) + } +} + +func EncReflectR(v reflect.Value) ([]byte, error) { + if v.IsNil() { + return nil, nil + } + return EncReflect(v.Elem()) +} + +func encFloat32(v float32) []byte { + return encUint32(floatToUint(v)) +} + +func encFloat32R(v *float32) []byte { + return encUint32(floatToUintR(v)) +} + +func encUint32(v uint32) []byte { + return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} +} + +func floatToUint(v float32) uint32 { + return *(*uint32)(unsafe.Pointer(&v)) +} + +func floatToUintR(v *float32) uint32 { + return *(*uint32)(unsafe.Pointer(v)) +} diff --git a/serialization/float/unmarshal.go b/serialization/float/unmarshal.go new file mode 100644 index 000000000..1d809b3e2 --- /dev/null +++ b/serialization/float/unmarshal.go @@ -0,0 +1,29 @@ +package float + +import ( + "fmt" + "reflect" +) + +func Unmarshal(data []byte, value interface{}) error { + switch v := value.(type) { + case nil: + return nil + case *float32: + return DecFloat32(data, v) + case **float32: + return DecFloat32R(data, v) + default: + // Custom types (type MyFloat float32) can be deserialized only via `reflect` package. + // Later, when generic-based serialization is introduced we can do that via generics. + rv := reflect.ValueOf(value) + rt := rv.Type() + if rt.Kind() != reflect.Ptr { + return fmt.Errorf("failed to unmarshal float: unsupported value type (%T)(%[1]v)", v) + } + if rt.Elem().Kind() != reflect.Ptr { + return DecReflect(data, rv) + } + return DecReflectR(data, rv) + } +} diff --git a/serialization/float/unmarshal_utils.go b/serialization/float/unmarshal_utils.go new file mode 100644 index 000000000..d4ad55eb7 --- /dev/null +++ b/serialization/float/unmarshal_utils.go @@ -0,0 +1,126 @@ +package float + +import ( + "fmt" + "reflect" + "unsafe" +) + +var errWrongDataLen = fmt.Errorf("failed to unmarshal float: the length of the data should be 0 or 4") + +func errNilReference(v interface{}) error { + return fmt.Errorf("failed to unmarshal float: can not unmarshal into nil reference(%T)(%[1]v)", v) +} + +func DecFloat32(p []byte, v *float32) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + *v = 0 + case 4: + *v = decFloat32(p) + default: + return errWrongDataLen + } + return nil +} + +func DecFloat32R(p []byte, v **float32) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + *v = new(float32) + } + case 4: + *v = decFloat32R(p) + default: + return errWrongDataLen + } + return nil +} + +func DecReflect(p []byte, v reflect.Value) error { + if v.IsNil() { + return errNilReference(v) + } + + switch v = v.Elem(); v.Kind() { + case reflect.Float32: + return decReflectFloat32(p, v) + default: + return fmt.Errorf("failed to unmarshal float: unsupported value type (%T)(%[1]v)", v.Interface()) + } +} + +func DecReflectR(p []byte, v reflect.Value) error { + if v.IsNil() { + return errNilReference(v) + } + + switch v.Type().Elem().Elem().Kind() { + case reflect.Float32: + return decReflectFloat32R(p, v) + default: + return fmt.Errorf("failed to unmarshal float: unsupported value type (%T)(%[1]v)", v.Interface()) + } +} + +func decReflectFloat32(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + v.SetFloat(0) + case 4: + v.SetFloat(float64(decFloat32(p))) + default: + return errWrongDataLen + } + return nil +} + +func decReflectFloat32R(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + v.Elem().Set(decReflectNullableR(p, v)) + case 4: + val := reflect.New(v.Type().Elem().Elem()) + val.Elem().SetFloat(float64(decFloat32(p))) + v.Elem().Set(val) + default: + return errWrongDataLen + } + return nil +} + +func decReflectNullableR(p []byte, v reflect.Value) reflect.Value { + if p == nil { + return reflect.Zero(v.Elem().Type()) + } + return reflect.New(v.Type().Elem().Elem()) +} + +func decFloat32(p []byte) float32 { + return uint32ToFloat(decUint32(p)) +} + +func decFloat32R(p []byte) *float32 { + return uint32ToFloatR(decUint32(p)) +} + +func uint32ToFloat(v uint32) float32 { + return *(*float32)(unsafe.Pointer(&v)) +} + +func uint32ToFloatR(v uint32) *float32 { + return (*float32)(unsafe.Pointer(&v)) +} + +func decUint32(p []byte) uint32 { + return uint32(p[0])<<24 | uint32(p[1])<<16 | uint32(p[2])<<8 | uint32(p[3]) +} diff --git a/tests/serialization/marshal_8_float_corrupt_test.go b/tests/serialization/marshal_8_float_corrupt_test.go new file mode 100644 index 000000000..c149460df --- /dev/null +++ b/tests/serialization/marshal_8_float_corrupt_test.go @@ -0,0 +1,59 @@ +package serialization_test + +import ( + "testing" + + "github.com/gocql/gocql" + "github.com/gocql/gocql/internal/tests/serialization" + "github.com/gocql/gocql/internal/tests/serialization/mod" + "github.com/gocql/gocql/serialization/float" +) + +func TestMarshalFloatMustFail(t *testing.T) { + type testSuite struct { + name string + marshal func(interface{}) ([]byte, error) + unmarshal func(bytes []byte, i interface{}) error + } + + tType := gocql.NewNativeType(4, gocql.TypeFloat, "") + + testSuites := [2]testSuite{ + { + name: "serialization.float", + marshal: float.Marshal, + unmarshal: float.Unmarshal, + }, + { + name: "glob", + marshal: func(i interface{}) ([]byte, error) { + return gocql.Marshal(tType, i) + }, + unmarshal: func(bytes []byte, i interface{}) error { + return gocql.Unmarshal(tType, bytes, i) + }, + }, + } + + for _, tSuite := range testSuites { + unmarshal := tSuite.unmarshal + + t.Run(tSuite.name, func(t *testing.T) { + + serialization.NegativeUnmarshalSet{ + Data: []byte("\x80\x00\x00\x00\x00"), + Values: mod.Values{float32(0)}.AddVariants(mod.All...), + }.Run("big_data", t, unmarshal) + + serialization.NegativeUnmarshalSet{ + Data: []byte("\x80"), + Values: mod.Values{float32(0)}.AddVariants(mod.All...), + }.Run("small_data1", t, unmarshal) + + serialization.NegativeUnmarshalSet{ + Data: []byte("\x80\x00\x00"), + Values: mod.Values{float32(0)}.AddVariants(mod.All...), + }.Run("small_data2", t, unmarshal) + }) + } +} diff --git a/tests/serialization/marshal_8_float_test.go b/tests/serialization/marshal_8_float_test.go index 9cd4bed7b..728ea4012 100644 --- a/tests/serialization/marshal_8_float_test.go +++ b/tests/serialization/marshal_8_float_test.go @@ -7,58 +7,90 @@ import ( "github.com/gocql/gocql" "github.com/gocql/gocql/internal/tests/serialization" "github.com/gocql/gocql/internal/tests/serialization/mod" + "github.com/gocql/gocql/serialization/float" ) func TestMarshalFloat(t *testing.T) { + type testSuite struct { + name string + marshal func(interface{}) ([]byte, error) + unmarshal func(bytes []byte, i interface{}) error + } + tType := gocql.NewNativeType(4, gocql.TypeFloat, "") - marshal := func(i interface{}) ([]byte, error) { return gocql.Marshal(tType, i) } - unmarshal := func(bytes []byte, i interface{}) error { - return gocql.Unmarshal(tType, bytes, i) + testSuites := [2]testSuite{ + { + name: "serialization.float", + marshal: float.Marshal, + unmarshal: float.Unmarshal, + }, + { + name: "glob", + marshal: func(i interface{}) ([]byte, error) { + return gocql.Marshal(tType, i) + }, + unmarshal: func(bytes []byte, i interface{}) error { + return gocql.Unmarshal(tType, bytes, i) + }, + }, } - serialization.PositiveSet{ - Data: nil, - Values: mod.Values{(*float32)(nil)}.AddVariants(mod.CustomType), - }.Run("[nil]nullable", t, marshal, unmarshal) - - serialization.PositiveSet{ - Data: nil, - Values: mod.Values{float32(0)}.AddVariants(mod.CustomType), - }.Run("[nil]unmarshal", t, nil, unmarshal) - - serialization.PositiveSet{ - Data: make([]byte, 0), - Values: mod.Values{float32(0)}.AddVariants(mod.All...), - }.Run("[]unmarshal", t, nil, unmarshal) - - serialization.PositiveSet{ - Data: []byte("\x00\x00\x00\x00"), - Values: mod.Values{float32(0)}.AddVariants(mod.All...), - }.Run("zeros", t, marshal, unmarshal) - - serialization.PositiveSet{ - Data: []byte("\x7f\x7f\xff\xff"), - Values: mod.Values{float32(math.MaxFloat32)}.AddVariants(mod.All...), - }.Run("max", t, marshal, unmarshal) - - serialization.PositiveSet{ - Data: []byte("\x00\x00\x00\x01"), - Values: mod.Values{float32(math.SmallestNonzeroFloat32)}.AddVariants(mod.All...), - }.Run("smallest", t, marshal, unmarshal) - - serialization.PositiveSet{ - Data: []byte("\x7f\x80\x00\x00"), - Values: mod.Values{float32(math.Inf(1))}.AddVariants(mod.All...), - }.Run("inf+", t, marshal, unmarshal) - - serialization.PositiveSet{ - Data: []byte("\xff\x80\x00\x00"), - Values: mod.Values{float32(math.Inf(-1))}.AddVariants(mod.All...), - }.Run("inf-", t, marshal, unmarshal) - - serialization.PositiveSet{ - Data: []byte("\x7f\xc0\x00\x00"), - Values: mod.Values{float32(math.NaN())}.AddVariants(mod.All...), - }.Run("nan", t, marshal, unmarshal) + for _, tSuite := range testSuites { + marshal := tSuite.marshal + unmarshal := tSuite.unmarshal + + t.Run(tSuite.name, func(t *testing.T) { + + serialization.PositiveSet{ + Data: nil, + Values: mod.Values{(*float32)(nil)}.AddVariants(mod.CustomType), + }.Run("[nil]nullable", t, marshal, unmarshal) + + serialization.PositiveSet{ + Data: nil, + Values: mod.Values{float32(0)}.AddVariants(mod.CustomType), + }.Run("[nil]unmarshal", t, nil, unmarshal) + + serialization.PositiveSet{ + Data: make([]byte, 0), + Values: mod.Values{float32(0)}.AddVariants(mod.All...), + }.Run("[]unmarshal", t, nil, unmarshal) + + serialization.PositiveSet{ + Data: []byte("\x00\x00\x00\x00"), + Values: mod.Values{float32(0)}.AddVariants(mod.All...), + }.Run("zeros", t, marshal, unmarshal) + + serialization.PositiveSet{ + Data: []byte("\x7f\x7f\xff\xff"), + Values: mod.Values{float32(math.MaxFloat32)}.AddVariants(mod.All...), + }.Run("max", t, marshal, unmarshal) + + serialization.PositiveSet{ + Data: []byte("\x00\x00\x00\x01"), + Values: mod.Values{float32(math.SmallestNonzeroFloat32)}.AddVariants(mod.All...), + }.Run("smallest", t, marshal, unmarshal) + + serialization.PositiveSet{ + Data: []byte("\x7f\x80\x00\x00"), + Values: mod.Values{float32(math.Inf(1))}.AddVariants(mod.All...), + }.Run("inf+", t, marshal, unmarshal) + + serialization.PositiveSet{ + Data: []byte("\xff\x80\x00\x00"), + Values: mod.Values{float32(math.Inf(-1))}.AddVariants(mod.All...), + }.Run("inf-", t, marshal, unmarshal) + + serialization.PositiveSet{ + Data: []byte("\x7f\xc0\x00\x00"), + Values: mod.Values{float32(math.NaN())}.AddVariants(mod.All...), + }.Run("nan", t, marshal, unmarshal) + + serialization.PositiveSet{ + Data: []byte("\x40\x49\x0f\xdb"), + Values: mod.Values{float32(3.14159265)}.AddVariants(mod.All...), + }.Run("pi", t, marshal, unmarshal) + }) + } }