diff --git a/gonull.go b/gonull.go index dfe4a0e..15377a4 100644 --- a/gonull.go +++ b/gonull.go @@ -7,6 +7,7 @@ import ( "database/sql/driver" "encoding/json" "errors" + "fmt" "reflect" ) @@ -68,7 +69,54 @@ func (n Nullable[T]) Value() (driver.Value, error) { if valuer, ok := interface{}(n.Val).(driver.Valuer); ok { return valuer.Value() } - return n.Val, nil + + return convertToDriverValue(n.Val) +} + +func convertToDriverValue(v any) (driver.Value, error) { + if valuer, ok := v.(driver.Valuer); ok { + return valuer.Value() + } + + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Pointer: + if rv.IsNil() { + return nil, nil + } + return convertToDriverValue(rv.Elem().Interface()) + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return rv.Int(), nil + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + return int64(rv.Uint()), nil + + case reflect.Uint64: + u64 := rv.Uint() + if u64 >= 1<<63 { + return nil, fmt.Errorf("uint64 values with high bit set are not supported") + } + return int64(u64), nil + + case reflect.Float32, reflect.Float64: + return rv.Float(), nil + + case reflect.Bool: + return rv.Bool(), nil + + case reflect.Slice: + if rv.Type().Elem().Kind() == reflect.Uint8 { + return rv.Bytes(), nil + } + return nil, fmt.Errorf("unsupported slice type: %s", rv.Type().Elem().Kind()) + + case reflect.String: + return rv.String(), nil + + default: + return nil, fmt.Errorf("unsupported type: %T", v) + } } // UnmarshalJSON implements the json.Unmarshaler interface for Nullable, allowing it to be used as a nullable field in JSON operations. diff --git a/gonull_test.go b/gonull_test.go index fb63f06..bc113e0 100644 --- a/gonull_test.go +++ b/gonull_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "reflect" "testing" "github.com/stretchr/testify/assert" @@ -504,3 +505,99 @@ func TestNullableOrElse(t *testing.T) { var empty Nullable[string] assert.Equal(t, "world", empty.OrElse("world")) } + +type customValuer struct { + value any + err error +} + +func (cv customValuer) Value() (driver.Value, error) { + return cv.value, cv.err +} + +func TestConvertToDriverValue(t *testing.T) { + var ( + intVal int = 123 + int8Val int8 = 12 + int16Val int16 = 1234 + int32Val int32 = 12345 + int64Val int64 = 123456 + uintVal uint = 123 + uint8Val uint8 = 12 + uint16Val uint16 = 1234 + uint32Val uint32 = 12345 + uint64Val uint64 = 1 << 62 + float32Val float32 = 12.34 + float64Val float64 = 123.456 + boolVal bool = true + stringVal string = "test" + byteSlice []byte = []byte("byte slice") + ptrToInt *int = &intVal + nilPtr *int = nil + valuerSuccess customValuer = customValuer{value: "valuer value", err: nil} + valuerError customValuer = customValuer{err: errors.New("valuer error")} + unsupportedSlice = []int{1, 2, 3} + ) + + tests := []struct { + name string + value any + want driver.Value + wantErr bool + }{ + {"Int", intVal, int64(intVal), false}, + {"Int8", int8Val, int64(int8Val), false}, + {"Int16", int16Val, int64(int16Val), false}, + {"Int32", int32Val, int64(int32Val), false}, + {"Int64", int64Val, int64(int64Val), false}, + {"Uint", uintVal, int64(uintVal), false}, + {"Uint8", uint8Val, int64(uint8Val), false}, + {"Uint16", uint16Val, int64(uint16Val), false}, + {"Uint32", uint32Val, int64(uint32Val), false}, + {"Uint64", uint64Val, int64(uint64Val), false}, + {"Float32", float32Val, float64(float32Val), false}, + {"Float64", float64Val, float64(float64Val), false}, + {"Bool", boolVal, boolVal, false}, + {"String", stringVal, stringVal, false}, + {"ByteSlice", byteSlice, byteSlice, false}, + {"PointerToInt", ptrToInt, int64(*ptrToInt), false}, + {"NilPointer", nilPtr, nil, false}, + {"UnsupportedType", struct{}{}, nil, true}, + {"Uint64HighBitSet", uint64(1 << 63), nil, true}, // Uint64 with high bit set + {"ValuerInterfaceSuccess", valuerSuccess, "valuer value", false}, + {"ValuerInterfaceError", valuerError, nil, true}, + {"UnsupportedSliceType", unsupportedSlice, nil, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := convertToDriverValue(tt.value) + if (err != nil) != tt.wantErr { + t.Errorf("convertToDriverValue() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("convertToDriverValue() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNullableValue_Uint32(t *testing.T) { + uint32Val := uint32(12345) + nullableUint32 := NewNullable(uint32Val) + + convertedValue, err := nullableUint32.Value() + + if err != nil { + t.Fatalf("Nullable[uint32].Value() returned an error: %v", err) + } + + if _, ok := convertedValue.(int64); !ok { + t.Fatalf("Nullable[uint32].Value() returned a non-int64 type: %T", convertedValue) + } + + if int64(uint32Val) != convertedValue.(int64) { + t.Errorf("Nullable[uint32].Value() returned %v, want %v", convertedValue, uint32Val) + } +}