diff --git a/gonull.go b/gonull.go index 19795a8..9b3126e 100644 --- a/gonull.go +++ b/gonull.go @@ -3,6 +3,7 @@ package gonull import ( + "database/sql" "database/sql/driver" "encoding/json" "errors" @@ -43,6 +44,14 @@ func (n *Nullable[T]) Scan(value any) error { return nil } + if scanner, ok := interface{}(&n.Val).(sql.Scanner); ok { + if err := scanner.Scan(value); err != nil { + return err + } + n.Valid = true + return nil + } + var err error n.Val, err = convertToType[T](value) n.Valid = err == nil @@ -55,6 +64,10 @@ func (n Nullable[T]) Value() (driver.Value, error) { if !n.Valid { return nil, nil } + + if valuer, ok := interface{}(n.Val).(driver.Valuer); ok { + return valuer.Value() + } return n.Val, nil } diff --git a/gonull_test.go b/gonull_test.go index 8cc9ec1..4ffc6eb 100644 --- a/gonull_test.go +++ b/gonull_test.go @@ -3,6 +3,7 @@ package gonull import ( "database/sql/driver" "encoding/json" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -411,3 +412,74 @@ func TestPresent(t *testing.T) { assert.Equal(t, true, nullable3.Foo.Present) assert.Nil(t, nullable3.Foo.Val) } + +type testValuerScannerStruct struct { + b []byte +} + +//goland:noinspection GoMixedReceiverTypes +func (t testValuerScannerStruct) Value() (driver.Value, error) { + return t.b, nil +} + +//goland:noinspection GoMixedReceiverTypes +func (t *testValuerScannerStruct) Scan(src any) error { + if src == nil { + return nil + } + switch v := src.(type) { + case string: + t.b = []byte(v) + return nil + case []byte: + t.b = v + return nil + default: + return fmt.Errorf("unsupported type: %t", v) + } +} + +func TestValuerAndScanner(t *testing.T) { + valueNullable1 := Nullable[testValuerScannerStruct]{ + Val: testValuerScannerStruct{b: []byte("test output string")}, + Valid: true, + Present: true, + } + valueNullable2 := Nullable[testValuerScannerStruct]{ + Valid: false, + Present: true, + } + + valueResult1, valueErr1 := valueNullable1.Value() + assert.NoError(t, valueErr1) + assert.Equal(t, []byte("test output string"), valueResult1) + + valueResult2, valueErr2 := valueNullable2.Value() + assert.NoError(t, valueErr2) + assert.Equal(t, nil, valueResult2) + + scannerData1 := []byte("test input string") + + var scannerNullable1 Nullable[testValuerScannerStruct] + var scannerNullable2 Nullable[testValuerScannerStruct] + + scannerErr1 := scannerNullable1.Scan(scannerData1) + assert.NoError(t, scannerErr1) + assert.Equal(t, Nullable[testValuerScannerStruct]{ + Present: true, + Valid: true, + Val: testValuerScannerStruct{ + b: []byte("test input string"), + }, + }, scannerNullable1) + + scannerErr2 := scannerNullable2.Scan(nil) + assert.NoError(t, scannerErr2) + assert.Equal(t, Nullable[testValuerScannerStruct]{ + Present: true, + Valid: false, + Val: testValuerScannerStruct{ + b: []byte(nil), + }, + }, scannerNullable2) +}