Skip to content

Commit

Permalink
Merge pull request #313 from illia-li/il/fix/marshal/float
Browse files Browse the repository at this point in the history
Fix `float` marshal, unmarshall functions
  • Loading branch information
dkropachev authored Oct 17, 2024
2 parents 53c3579 + 1f85aa2 commit 29787d4
Show file tree
Hide file tree
Showing 8 changed files with 383 additions and 110 deletions.
51 changes: 12 additions & 39 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/gocql/gocql/serialization/bigint"
"github.com/gocql/gocql/serialization/counter"
"github.com/gocql/gocql/serialization/cqlint"
"github.com/gocql/gocql/serialization/float"
"github.com/gocql/gocql/serialization/smallint"
"github.com/gocql/gocql/serialization/tinyint"
)
Expand Down Expand Up @@ -150,7 +151,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
case TypeCounter:
return marshalCounter(value)
case TypeFloat:
return marshalFloat(info, value)
return marshalFloat(value)
case TypeDouble:
return marshalDouble(info, value)
case TypeDecimal:
Expand Down Expand Up @@ -256,7 +257,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
case TypeTinyInt:
return unmarshalTinyInt(data, value)
case TypeFloat:
return unmarshalFloat(info, data, value)
return unmarshalFloat(data, value)
case TypeDouble:
return unmarshalDouble(info, data, value)
case TypeDecimal:
Expand Down Expand Up @@ -899,47 +900,19 @@ func decBool(v []byte) bool {
return v[0] != 0
}

func marshalFloat(info TypeInfo, value interface{}) ([]byte, error) {
switch v := value.(type) {
case Marshaler:
return v.MarshalCQL(info)
case unsetColumn:
return nil, nil
case float32:
return encInt(int32(math.Float32bits(v))), nil
}

if value == nil {
return nil, nil
}

rv := reflect.ValueOf(value)
switch rv.Type().Kind() {
case reflect.Float32:
return encInt(int32(math.Float32bits(float32(rv.Float())))), nil
func marshalFloat(value interface{}) ([]byte, error) {
data, err := float.Marshal(value)
if err != nil {
return nil, wrapMarshalError(err, "marshal error")
}
return nil, marshalErrorf("can not marshal %T into %s", value, info)
return data, nil
}

func unmarshalFloat(info TypeInfo, data []byte, value interface{}) error {
switch v := value.(type) {
case Unmarshaler:
return v.UnmarshalCQL(info, data)
case *float32:
*v = math.Float32frombits(uint32(decInt(data)))
return nil
}
rv := reflect.ValueOf(value)
if rv.Kind() != reflect.Ptr {
return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
}
rv = rv.Elem()
switch rv.Type().Kind() {
case reflect.Float32:
rv.SetFloat(float64(math.Float32frombits(uint32(decInt(data)))))
return nil
func unmarshalFloat(data []byte, value interface{}) error {
if err := float.Unmarshal(data, value); err != nil {
return wrapUnmarshalError(err, "unmarshal error")
}
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
return nil
}

func marshalDouble(info TypeInfo, value interface{}) ([]byte, error) {
Expand Down
24 changes: 0 additions & 24 deletions marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,6 @@ var marshalTests = []struct {
nil,
nil,
},
{
NativeType{proto: 2, typ: TypeFloat},
[]byte("\x40\x49\x0f\xdb"),
float32(3.14159265),
nil,
nil,
},
{
NativeType{proto: 2, typ: TypeDouble},
[]byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1"),
Expand Down Expand Up @@ -542,23 +535,6 @@ var marshalTests = []struct {
nil,
nil,
},
{
NativeType{proto: 2, typ: TypeFloat},
[]byte("\x40\x49\x0f\xdb"),
func() *float32 {
f := float32(3.14159265)
return &f
}(),
nil,
nil,
},
{
NativeType{proto: 2, typ: TypeFloat},
[]byte(nil),
(*float32)(nil),
nil,
nil,
},
{
NativeType{proto: 2, typ: TypeDouble},
[]byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1"),
Expand Down
24 changes: 24 additions & 0 deletions serialization/float/marshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package float

import (
"reflect"
)

func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case float32:
return EncFloat32(v)
case *float32:
return EncFloat32R(v)
default:
// Custom types (type MyFloat float32) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.TypeOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(reflect.ValueOf(v))
}
return EncReflectR(reflect.ValueOf(v))
}
}
54 changes: 54 additions & 0 deletions serialization/float/marshal_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package float

import (
"fmt"
"reflect"
"unsafe"
)

func EncFloat32(v float32) ([]byte, error) {
return encFloat32(v), nil
}

func EncFloat32R(v *float32) ([]byte, error) {
if v == nil {
return nil, nil
}
return encFloat32R(v), nil
}

func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Kind() {
case reflect.Float32:
return encFloat32(float32(v.Float())), nil
default:
return nil, fmt.Errorf("failed to marshal float: unsupported value type (%T)(%[1]v)", v.Interface())
}
}

func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
return EncReflect(v.Elem())
}

func encFloat32(v float32) []byte {
return encUint32(floatToUint(v))
}

func encFloat32R(v *float32) []byte {
return encUint32(floatToUintR(v))
}

func encUint32(v uint32) []byte {
return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}

func floatToUint(v float32) uint32 {
return *(*uint32)(unsafe.Pointer(&v))
}

func floatToUintR(v *float32) uint32 {
return *(*uint32)(unsafe.Pointer(v))
}
29 changes: 29 additions & 0 deletions serialization/float/unmarshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package float

import (
"fmt"
"reflect"
)

func Unmarshal(data []byte, value interface{}) error {
switch v := value.(type) {
case nil:
return nil
case *float32:
return DecFloat32(data, v)
case **float32:
return DecFloat32R(data, v)
default:
// Custom types (type MyFloat float32) can be deserialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
rt := rv.Type()
if rt.Kind() != reflect.Ptr {
return fmt.Errorf("failed to unmarshal float: unsupported value type (%T)(%[1]v)", v)
}
if rt.Elem().Kind() != reflect.Ptr {
return DecReflect(data, rv)
}
return DecReflectR(data, rv)
}
}
126 changes: 126 additions & 0 deletions serialization/float/unmarshal_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package float

import (
"fmt"
"reflect"
"unsafe"
)

var errWrongDataLen = fmt.Errorf("failed to unmarshal float: the length of the data should be 0 or 4")

func errNilReference(v interface{}) error {
return fmt.Errorf("failed to unmarshal float: can not unmarshal into nil reference(%T)(%[1]v)", v)
}

func DecFloat32(p []byte, v *float32) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 4:
*v = decFloat32(p)
default:
return errWrongDataLen
}
return nil
}

func DecFloat32R(p []byte, v **float32) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(float32)
}
case 4:
*v = decFloat32R(p)
default:
return errWrongDataLen
}
return nil
}

func DecReflect(p []byte, v reflect.Value) error {
if v.IsNil() {
return errNilReference(v)
}

switch v = v.Elem(); v.Kind() {
case reflect.Float32:
return decReflectFloat32(p, v)
default:
return fmt.Errorf("failed to unmarshal float: unsupported value type (%T)(%[1]v)", v.Interface())
}
}

func DecReflectR(p []byte, v reflect.Value) error {
if v.IsNil() {
return errNilReference(v)
}

switch v.Type().Elem().Elem().Kind() {
case reflect.Float32:
return decReflectFloat32R(p, v)
default:
return fmt.Errorf("failed to unmarshal float: unsupported value type (%T)(%[1]v)", v.Interface())
}
}

func decReflectFloat32(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetFloat(0)
case 4:
v.SetFloat(float64(decFloat32(p)))
default:
return errWrongDataLen
}
return nil
}

func decReflectFloat32R(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 4:
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetFloat(float64(decFloat32(p)))
v.Elem().Set(val)
default:
return errWrongDataLen
}
return nil
}

func decReflectNullableR(p []byte, v reflect.Value) reflect.Value {
if p == nil {
return reflect.Zero(v.Elem().Type())
}
return reflect.New(v.Type().Elem().Elem())
}

func decFloat32(p []byte) float32 {
return uint32ToFloat(decUint32(p))
}

func decFloat32R(p []byte) *float32 {
return uint32ToFloatR(decUint32(p))
}

func uint32ToFloat(v uint32) float32 {
return *(*float32)(unsafe.Pointer(&v))
}

func uint32ToFloatR(v uint32) *float32 {
return (*float32)(unsafe.Pointer(&v))
}

func decUint32(p []byte) uint32 {
return uint32(p[0])<<24 | uint32(p[1])<<16 | uint32(p[2])<<8 | uint32(p[3])
}
Loading

0 comments on commit 29787d4

Please sign in to comment.