Skip to content

Commit 911e497

Browse files
committed
Improved quoting.
1 parent 3469460 commit 911e497

File tree

4 files changed

+47
-11
lines changed

4 files changed

+47
-11
lines changed

embed/bcw2/init.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
// import _ "github.com/ncruces/go-sqlite3/embed/bcw2"
77
//
88
// [BEGIN CONCURRENT]: https://sqlite.org/src/doc/begin-concurrent/doc/begin_concurrent.md
9-
// [Wal2]: https://www.sqlite.org/cgi/src/doc/wal2/doc/wal2.md
9+
// [Wal2]: https://sqlite.org/cgi/src/doc/wal2/doc/wal2.md
1010
package bcw2
1111

1212
import (

quote.go

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package sqlite3
33
import (
44
"bytes"
55
"math"
6+
"reflect"
67
"strconv"
78
"strings"
89
"time"
@@ -13,6 +14,9 @@ import (
1314

1415
// Quote escapes and quotes a value
1516
// making it safe to embed in SQL text.
17+
// Strings with embedded NUL characters are truncated.
18+
//
19+
// https://sqlite.org/lang_corefunc.html#quote
1620
func Quote(value any) string {
1721
switch v := value.(type) {
1822
case nil:
@@ -42,8 +46,8 @@ func Quote(value any) string {
4246
return "'" + v.Format(time.RFC3339Nano) + "'"
4347

4448
case string:
45-
if strings.IndexByte(v, 0) >= 0 {
46-
break
49+
if i := strings.IndexByte(v, 0); i >= 0 {
50+
v = v[:i]
4751
}
4852

4953
buf := make([]byte, 2+len(v)+strings.Count(v, "'"))
@@ -75,22 +79,46 @@ func Quote(value any) string {
7579
return unsafe.String(&buf[0], len(buf))
7680

7781
case ZeroBlob:
78-
if v > ZeroBlob(1e9-3)/2 {
79-
break
80-
}
81-
8282
buf := bytes.Repeat([]byte("0"), int(3+2*int64(v)))
8383
buf[1] = '\''
8484
buf[0] = 'x'
8585
buf[len(buf)-1] = '\''
8686
return unsafe.String(&buf[0], len(buf))
8787
}
8888

89+
v := reflect.ValueOf(value)
90+
k := v.Kind()
91+
92+
if k == reflect.Interface || k == reflect.Pointer {
93+
if v.IsNil() {
94+
return "NULL"
95+
}
96+
v = v.Elem()
97+
k = v.Kind()
98+
}
99+
100+
switch {
101+
case v.CanInt():
102+
return strconv.FormatInt(v.Int(), 10)
103+
case v.CanUint():
104+
return strconv.FormatUint(v.Uint(), 10)
105+
case v.CanFloat():
106+
return Quote(v.Float())
107+
case k == reflect.Bool:
108+
return Quote(v.Bool())
109+
case k == reflect.String:
110+
return Quote(v.String())
111+
case (k == reflect.Slice || k == reflect.Array && v.CanAddr()) &&
112+
v.Type().Elem().Kind() == reflect.Uint8:
113+
return Quote(v.Bytes())
114+
}
115+
89116
panic(util.ValueErr)
90117
}
91118

92119
// QuoteIdentifier escapes and quotes an identifier
93120
// making it safe to embed in SQL text.
121+
// Strings with embedded NUL characters panic.
94122
func QuoteIdentifier(id string) string {
95123
if strings.IndexByte(id, 0) >= 0 {
96124
panic(util.ValueErr)

stmt.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ func (s *Stmt) BindValue(param int, value Value) error {
379379

380380
// DataCount resets the number of columns in a result set.
381381
//
382-
// https://www.sqlite.org/c3ref/data_count.html
382+
// https://sqlite.org/c3ref/data_count.html
383383
func (s *Stmt) DataCount() int {
384384
r := s.c.call("sqlite3_data_count",
385385
uint64(s.handle))

tests/quote_test.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package tests
22

33
import (
4+
"database/sql"
5+
"encoding/json"
46
"math"
57
"reflect"
68
"testing"
@@ -19,8 +21,8 @@ func TestQuote(t *testing.T) {
1921
{`a'bc`, "'a''bc'"},
2022
{"\x07bc", "'\abc'"},
2123
{"\x1c\n", "'\x1c\n'"},
24+
{"\xB0\x00\x0B", "'\xB0'"},
2225
{[]byte("\xB0\x00\x0B"), "x'B0000B'"},
23-
{"\xB0\x00\x0B", ""},
2426

2527
{0, "0"},
2628
{true, "1"},
@@ -33,7 +35,13 @@ func TestQuote(t *testing.T) {
3335
{int64(math.MaxInt64), "9223372036854775807"},
3436
{time.Unix(0, 0).UTC(), "'1970-01-01T00:00:00Z'"},
3537
{sqlite3.ZeroBlob(4), "x'00000000'"},
36-
{sqlite3.ZeroBlob(1e9), ""},
38+
{int8(0), "0"},
39+
{uint(0), "0"},
40+
{float32(0), "0"},
41+
{(*string)(nil), "NULL"},
42+
{json.Number("0"), "'0'"},
43+
{&sql.RawBytes{'0'}, "x'30'"},
44+
{t, ""}, // panic
3745
}
3846

3947
for _, tt := range tests {
@@ -62,7 +70,7 @@ func TestQuoteIdentifier(t *testing.T) {
6270
{`a'bc`, `"a'bc"`},
6371
{"\x07bc", "\"\abc\""},
6472
{"\x1c\n", "\"\x1c\n\""},
65-
{"\xB0\x00\x0B", ""},
73+
{"\xB0\x00\x0B", ""}, // panic
6674
}
6775

6876
for _, tt := range tests {

0 commit comments

Comments
 (0)