diff --git a/gonull.go b/gonull.go index e8dd993..c43e1c9 100644 --- a/gonull.go +++ b/gonull.go @@ -99,22 +99,27 @@ func zeroValue[T any]() T { // convertToType is a helper function that attempts to convert the given value to type T. // This function is used by Scan to properly handle value conversion, ensuring that Nullable values are always of the correct type. +// ErrUnsupportedConversion is returned when a conversion cannot be made to the generic type T. func convertToType[T any](value interface{}) (T, error) { - switch v := value.(type) { - case T: - return v, nil - case int64: - // This case handles the situation when the input value is of type int64. - // It attempts to convert the int64 value to the target numeric type T if possible. - // If the conversion is successful, it returns the converted value of type T and a nil error. - // If the conversion is not possible, the function will continue to the next case (return an error). - switch t := reflect.Zero(reflect.TypeOf((*T)(nil)).Elem()).Interface().(type) { - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - if reflect.TypeOf(t).ConvertibleTo(reflect.TypeOf((*T)(nil)).Elem()) { - return reflect.ValueOf(value).Convert(reflect.TypeOf((*T)(nil)).Elem()).Interface().(T), nil - } + var zero T + if value == nil { + return zero, nil + } + + if reflect.TypeOf(value) == reflect.TypeOf(zero) { + return value.(T), nil + } + + // Check if the value is a numeric type and if T is also a numeric type. + valueType := reflect.TypeOf(value) + targetType := reflect.TypeOf(zero) + if valueType.Kind() >= reflect.Int && valueType.Kind() <= reflect.Float64 && + targetType.Kind() >= reflect.Int && targetType.Kind() <= reflect.Float64 { + if valueType.ConvertibleTo(targetType) { + convertedValue := reflect.ValueOf(value).Convert(targetType) + return convertedValue.Interface().(T), nil } } - var zero T + return zero, ErrUnsupportedConversion } diff --git a/gonull_test.go b/gonull_test.go index d4e3bed..c03615d 100644 --- a/gonull_test.go +++ b/gonull_test.go @@ -276,11 +276,8 @@ func TestNullableScanWithCustomEnum(t *testing.T) { model := TestModel{ID: 1, Field: gonull.NewNullable(TestEnumA)} err := model.Field.Scan(sqlReturnedValue) - if err != nil { - assert.Error(t, err, "Scan failed with unsupported type conversion") - } else { - assert.Equal(t, TestEnumA, model.Field.Val, "Scanned value does not match expected enum value") - } + assert.NoError(t, err, "Scan failed with unsupported type conversion") + assert.Equal(t, TestEnumA, model.Field.Val, "Scanned value does not match expected enum value") }