Skip to content

Commit 36bbd67

Browse files
authored
Add ColumnTypeScanType to driver (#199).
1 parent 7f5ea54 commit 36bbd67

File tree

3 files changed

+185
-6
lines changed

3 files changed

+185
-6
lines changed

driver/driver.go

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ import (
8181
"fmt"
8282
"io"
8383
"net/url"
84+
"reflect"
8485
"strings"
8586
"time"
8687
"unsafe"
@@ -579,8 +580,22 @@ type rows struct {
579580
names []string
580581
types []string
581582
nulls []bool
583+
scans []scantype
582584
}
583585

586+
type scantype byte
587+
588+
const (
589+
_ANY scantype = iota
590+
_INT scantype = scantype(sqlite3.INTEGER)
591+
_REAL scantype = scantype(sqlite3.FLOAT)
592+
_TEXT scantype = scantype(sqlite3.TEXT)
593+
_BLOB scantype = scantype(sqlite3.BLOB)
594+
_NULL scantype = scantype(sqlite3.NULL)
595+
_BOOL scantype = iota
596+
_TIME
597+
)
598+
584599
var (
585600
// Ensure these interfaces are implemented:
586601
_ driver.RowsColumnTypeDatabaseTypeName = &rows{}
@@ -604,22 +619,42 @@ func (r *rows) Columns() []string {
604619
return r.names
605620
}
606621

607-
func (r *rows) loadTypes() {
622+
func (r *rows) loadColumnMetadata() {
608623
if r.nulls == nil {
609624
count := r.Stmt.ColumnCount()
610625
nulls := make([]bool, count)
611626
types := make([]string, count)
627+
scans := make([]scantype, count)
612628
for i := range nulls {
613629
if col := r.Stmt.ColumnOriginName(i); col != "" {
614630
types[i], _, nulls[i], _, _, _ = r.Stmt.Conn().TableColumnMetadata(
615631
r.Stmt.ColumnDatabaseName(i),
616632
r.Stmt.ColumnTableName(i),
617633
col)
618634
types[i] = strings.ToUpper(types[i])
635+
// These types are only used before we have rows,
636+
// and otherwise as type hints.
637+
// The first few ensure STRICT tables are strictly typed.
638+
// The other two are type hints for booleans and time.
639+
switch types[i] {
640+
case "INT", "INTEGER":
641+
scans[i] = _INT
642+
case "REAL":
643+
scans[i] = _REAL
644+
case "TEXT":
645+
scans[i] = _TEXT
646+
case "BLOB":
647+
scans[i] = _BLOB
648+
case "BOOLEAN":
649+
scans[i] = _BOOL
650+
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
651+
scans[i] = _TIME
652+
}
619653
}
620654
}
621655
r.nulls = nulls
622656
r.types = types
657+
r.scans = scans
623658
}
624659
}
625660

@@ -636,7 +671,7 @@ func (r *rows) declType(index int) string {
636671
}
637672

638673
func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
639-
r.loadTypes()
674+
r.loadColumnMetadata()
640675
decltype := r.types[index]
641676
if len := len(decltype); len > 0 && decltype[len-1] == ')' {
642677
if i := strings.LastIndexByte(decltype, '('); i >= 0 {
@@ -647,13 +682,57 @@ func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
647682
}
648683

649684
func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
650-
r.loadTypes()
685+
r.loadColumnMetadata()
651686
if r.nulls[index] {
652687
return false, true
653688
}
654689
return true, false
655690
}
656691

692+
func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) {
693+
r.loadColumnMetadata()
694+
scan := r.scans[index]
695+
696+
if r.Stmt.Busy() {
697+
// SQLite is dynamically typed and we now have a row.
698+
// Always use the type of the value itself,
699+
// unless the scan type is more specific
700+
// and can scan the actual value.
701+
val := scantype(r.Stmt.ColumnType(index))
702+
useValType := true
703+
switch {
704+
case scan == _TIME && val != _BLOB && val != _NULL:
705+
t := r.Stmt.ColumnTime(index, r.tmRead)
706+
useValType = t == time.Time{}
707+
case scan == _BOOL && val == _INT:
708+
i := r.Stmt.ColumnInt64(index)
709+
useValType = i != 0 && i != 1
710+
case scan == _BLOB && val == _NULL:
711+
useValType = false
712+
}
713+
if useValType {
714+
scan = val
715+
}
716+
}
717+
718+
switch scan {
719+
case _INT:
720+
return reflect.TypeOf(int64(0))
721+
case _REAL:
722+
return reflect.TypeOf(float64(0))
723+
case _TEXT:
724+
return reflect.TypeOf("")
725+
case _BLOB:
726+
return reflect.TypeOf([]byte{})
727+
case _BOOL:
728+
return reflect.TypeOf(false)
729+
case _TIME:
730+
return reflect.TypeOf(time.Time{})
731+
default:
732+
return reflect.TypeOf((*any)(nil)).Elem()
733+
}
734+
}
735+
657736
func (r *rows) Next(dest []driver.Value) error {
658737
old := r.Stmt.Conn().SetInterrupt(r.ctx)
659738
defer r.Stmt.Conn().SetInterrupt(old)

driver/driver_test.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"errors"
88
"math"
99
"net/url"
10+
"reflect"
1011
"testing"
1112
"time"
1213

@@ -365,3 +366,104 @@ func Test_time(t *testing.T) {
365366
})
366367
}
367368
}
369+
370+
func Test_ColumnType_ScanType(t *testing.T) {
371+
var (
372+
INT = reflect.TypeOf(int64(0))
373+
REAL = reflect.TypeOf(float64(0))
374+
TEXT = reflect.TypeOf("")
375+
BLOB = reflect.TypeOf([]byte{})
376+
BOOL = reflect.TypeOf(false)
377+
TIME = reflect.TypeOf(time.Time{})
378+
ANY = reflect.TypeOf((*any)(nil)).Elem()
379+
)
380+
381+
t.Parallel()
382+
tmp := memdb.TestDB(t)
383+
384+
db, err := sql.Open("sqlite3", tmp)
385+
if err != nil {
386+
t.Fatal(err)
387+
}
388+
defer db.Close()
389+
390+
_, err = db.Exec(`
391+
CREATE TABLE test (
392+
col_int INTEGER,
393+
col_real REAL,
394+
col_text TEXT,
395+
col_blob BLOB,
396+
col_bool BOOLEAN,
397+
col_time DATETIME,
398+
col_decimal DECIMAL
399+
);
400+
INSERT INTO test VALUES
401+
(1, 1, 1, 1, 1, 1, 1),
402+
(2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0),
403+
('1', '1', '1', '1', '1', '1', '1'),
404+
('x', 'x', 'x', 'x', 'x', 'x', 'x'),
405+
(x'', x'', x'', x'', x'', x'', x''),
406+
('2006-01-02T15:04:05Z', '2006-01-02T15:04:05Z', '2006-01-02T15:04:05Z', '2006-01-02T15:04:05Z',
407+
'2006-01-02T15:04:05Z', '2006-01-02T15:04:05Z', '2006-01-02T15:04:05Z'),
408+
(TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE),
409+
(NULL, NULL, NULL, NULL, NULL, NULL, NULL);
410+
`)
411+
if err != nil {
412+
t.Fatal(err)
413+
}
414+
415+
rows, err := db.Query(`SELECT * FROM test`)
416+
if err != nil {
417+
t.Fatal(err)
418+
}
419+
defer rows.Close()
420+
421+
cols, err := rows.ColumnTypes()
422+
if err != nil {
423+
t.Fatal(err)
424+
}
425+
426+
want := [][]reflect.Type{
427+
{INT, REAL, TEXT, BLOB, BOOL, TIME, ANY},
428+
{INT, REAL, TEXT, INT, BOOL, TIME, INT},
429+
{INT, REAL, TEXT, REAL, INT, TIME, INT},
430+
{INT, REAL, TEXT, TEXT, BOOL, TIME, INT},
431+
{TEXT, TEXT, TEXT, TEXT, TEXT, TEXT, TEXT},
432+
{BLOB, BLOB, BLOB, BLOB, BLOB, BLOB, BLOB},
433+
{TEXT, TEXT, TEXT, TEXT, TEXT, TIME, TEXT},
434+
{INT, REAL, TEXT, INT, BOOL, TIME, INT},
435+
{ANY, ANY, ANY, BLOB, ANY, ANY, ANY},
436+
}
437+
for j, c := range cols {
438+
got := c.ScanType()
439+
if got != want[0][j] {
440+
t.Errorf("want %v, got %v, at column %d", want[0][j], got, j)
441+
}
442+
}
443+
444+
dest := make([]any, len(cols))
445+
for i := 1; rows.Next(); i++ {
446+
cols, err := rows.ColumnTypes()
447+
if err != nil {
448+
t.Fatal(err)
449+
}
450+
451+
for j, c := range cols {
452+
got := c.ScanType()
453+
if got != want[i][j] {
454+
t.Errorf("want %v, got %v, at row %d column %d", want[i][j], got, i, j)
455+
}
456+
dest[j] = reflect.New(got).Interface()
457+
}
458+
459+
err = rows.Scan(dest...)
460+
if err != nil {
461+
t.Error(err)
462+
}
463+
}
464+
465+
err = rows.Err()
466+
if err != nil {
467+
t.Fatal(err)
468+
}
469+
}

driver/time.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
package driver
22

3-
import (
4-
"time"
5-
)
3+
import "time"
64

75
// Convert a string in [time.RFC3339Nano] format into a [time.Time]
86
// if it roundtrips back to the same string.

0 commit comments

Comments
 (0)