Skip to content

Commit

Permalink
Merge branch 'main' into add_or_else
Browse files Browse the repository at this point in the history
  • Loading branch information
LukaGiorgadze authored Jan 22, 2024
2 parents 54e874d + 032d23a commit d80b30d
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 1 deletion.
50 changes: 49 additions & 1 deletion gonull.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"reflect"
)

Expand Down Expand Up @@ -68,7 +69,54 @@ func (n Nullable[T]) Value() (driver.Value, error) {
if valuer, ok := interface{}(n.Val).(driver.Valuer); ok {
return valuer.Value()
}
return n.Val, nil

return convertToDriverValue(n.Val)
}

func convertToDriverValue(v any) (driver.Value, error) {
if valuer, ok := v.(driver.Valuer); ok {
return valuer.Value()
}

rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Pointer:
if rv.IsNil() {
return nil, nil
}
return convertToDriverValue(rv.Elem().Interface())

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return rv.Int(), nil

case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
return int64(rv.Uint()), nil

case reflect.Uint64:
u64 := rv.Uint()
if u64 >= 1<<63 {
return nil, fmt.Errorf("uint64 values with high bit set are not supported")
}
return int64(u64), nil

case reflect.Float32, reflect.Float64:
return rv.Float(), nil

case reflect.Bool:
return rv.Bool(), nil

case reflect.Slice:
if rv.Type().Elem().Kind() == reflect.Uint8 {
return rv.Bytes(), nil
}
return nil, fmt.Errorf("unsupported slice type: %s", rv.Type().Elem().Kind())

case reflect.String:
return rv.String(), nil

default:
return nil, fmt.Errorf("unsupported type: %T", v)
}
}

// UnmarshalJSON implements the json.Unmarshaler interface for Nullable, allowing it to be used as a nullable field in JSON operations.
Expand Down
97 changes: 97 additions & 0 deletions gonull_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -504,3 +505,99 @@ func TestNullableOrElse(t *testing.T) {
var empty Nullable[string]
assert.Equal(t, "world", empty.OrElse("world"))
}

type customValuer struct {
value any
err error
}

func (cv customValuer) Value() (driver.Value, error) {
return cv.value, cv.err
}

func TestConvertToDriverValue(t *testing.T) {
var (
intVal int = 123
int8Val int8 = 12
int16Val int16 = 1234
int32Val int32 = 12345
int64Val int64 = 123456
uintVal uint = 123
uint8Val uint8 = 12
uint16Val uint16 = 1234
uint32Val uint32 = 12345
uint64Val uint64 = 1 << 62
float32Val float32 = 12.34
float64Val float64 = 123.456
boolVal bool = true
stringVal string = "test"
byteSlice []byte = []byte("byte slice")
ptrToInt *int = &intVal
nilPtr *int = nil
valuerSuccess customValuer = customValuer{value: "valuer value", err: nil}
valuerError customValuer = customValuer{err: errors.New("valuer error")}
unsupportedSlice = []int{1, 2, 3}
)

tests := []struct {
name string
value any
want driver.Value
wantErr bool
}{
{"Int", intVal, int64(intVal), false},
{"Int8", int8Val, int64(int8Val), false},
{"Int16", int16Val, int64(int16Val), false},
{"Int32", int32Val, int64(int32Val), false},
{"Int64", int64Val, int64(int64Val), false},
{"Uint", uintVal, int64(uintVal), false},
{"Uint8", uint8Val, int64(uint8Val), false},
{"Uint16", uint16Val, int64(uint16Val), false},
{"Uint32", uint32Val, int64(uint32Val), false},
{"Uint64", uint64Val, int64(uint64Val), false},
{"Float32", float32Val, float64(float32Val), false},
{"Float64", float64Val, float64(float64Val), false},
{"Bool", boolVal, boolVal, false},
{"String", stringVal, stringVal, false},
{"ByteSlice", byteSlice, byteSlice, false},
{"PointerToInt", ptrToInt, int64(*ptrToInt), false},
{"NilPointer", nilPtr, nil, false},
{"UnsupportedType", struct{}{}, nil, true},
{"Uint64HighBitSet", uint64(1 << 63), nil, true}, // Uint64 with high bit set
{"ValuerInterfaceSuccess", valuerSuccess, "valuer value", false},
{"ValuerInterfaceError", valuerError, nil, true},
{"UnsupportedSliceType", unsupportedSlice, nil, true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := convertToDriverValue(tt.value)
if (err != nil) != tt.wantErr {
t.Errorf("convertToDriverValue() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("convertToDriverValue() = %v, want %v", got, tt.want)
}
})
}
}

func TestNullableValue_Uint32(t *testing.T) {
uint32Val := uint32(12345)
nullableUint32 := NewNullable(uint32Val)

convertedValue, err := nullableUint32.Value()

if err != nil {
t.Fatalf("Nullable[uint32].Value() returned an error: %v", err)
}

if _, ok := convertedValue.(int64); !ok {
t.Fatalf("Nullable[uint32].Value() returned a non-int64 type: %T", convertedValue)
}

if int64(uint32Val) != convertedValue.(int64) {
t.Errorf("Nullable[uint32].Value() returned %v, want %v", convertedValue, uint32Val)
}
}

0 comments on commit d80b30d

Please sign in to comment.