Skip to content

Commit

Permalink
fix: postgres syntax errors for pointers and slices #877
Browse files Browse the repository at this point in the history
  • Loading branch information
rfarrjr committed Jan 22, 2025
1 parent dbae5e6 commit d44f743
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 14 deletions.
70 changes: 68 additions & 2 deletions dialect/pgdialect/array.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pgdialect
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"math"
"reflect"
Expand All @@ -11,6 +12,7 @@ import (

"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/internal/parser"
"github.com/uptrace/bun/schema"
)

Expand Down Expand Up @@ -109,6 +111,8 @@ func (d *Dialect) arrayAppender(typ reflect.Type) schema.AppenderFunc {
}

return func(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
brackets := "{}"

kind := v.Kind()
switch kind {
case reflect.Ptr, reflect.Slice:
Expand All @@ -121,18 +125,35 @@ func (d *Dialect) arrayAppender(typ reflect.Type) schema.AppenderFunc {
v = v.Elem()
}

b = append(b, "'{"...)
if kind == reflect.Slice {
elemType := v.Type().Elem()

for elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}

if elemType.Kind() == reflect.Struct {
brackets = "[]"
}
}

b = append(b, '\'', brackets[0])

ln := v.Len()
for i := 0; i < ln; i++ {
elem := v.Index(i)

for elem.Kind() == reflect.Ptr {
elem = elem.Elem()
}

if i > 0 {
b = append(b, ',')
}
b = appendElem(fmter, b, elem)
}

b = append(b, "}'"...)
b = append(b, brackets[1], '\'')

return b
}
Expand All @@ -143,8 +164,12 @@ func (d *Dialect) arrayElemAppender(typ reflect.Type) schema.AppenderFunc {
return arrayAppendDriverValue
}
switch typ.Kind() {
case reflect.Ptr:
return d.arrayElemAppender(typ.Elem())
case reflect.String:
return appendStringElemValue
case reflect.Struct:
return arrayAppendJSONValue
case reflect.Slice:
if typ.Elem().Kind() == reflect.Uint8 {
return appendBytesElemValue
Expand All @@ -153,6 +178,47 @@ func (d *Dialect) arrayElemAppender(typ reflect.Type) schema.AppenderFunc {
return schema.Appender(d, typ)
}

func appendJSONUnquoted(b, jsonb []byte) []byte {
p := parser.New(jsonb)
for p.Valid() {
c := p.Read()
switch c {
case '"':
b = append(b, '"')
case '\'':
b = append(b, "''"...)
case '\000':
continue
case '\\':
if err := p.SkipPrefix([]byte("u0000")); err == nil {
b = append(b, `\\u0000`...)
} else {
b = append(b, '\\')
if p.Valid() {
b = append(b, p.Read())
}
}
default:
b = append(b, c)
}
}

return b
}

func arrayAppendJSONValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
bb, err := json.Marshal(v.Interface())
if err != nil {
return dialect.AppendError(b, err)
}

if len(bb) > 0 && bb[len(bb)-1] == '\n' {
bb = bb[:len(bb)-1]
}

return appendJSONUnquoted(b, bb)
}

func appendStringElemValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
return appendStringElem(b, v.String())
}
Expand Down
45 changes: 33 additions & 12 deletions dialect/pgdialect/array_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,18 @@ type arrayParser struct {

elem []byte
err error

isJson bool
}

func newArrayParser(b []byte) *arrayParser {
p := new(arrayParser)

if len(b) < 2 || b[0] != '{' || b[len(b)-1] != '}' {
if len(b) < 2 || (b[0] != '{' && b[0] != '[') || (b[len(b)-1] != '}' && b[len(b)-1] != ']') {
p.err = fmt.Errorf("pgdialect: can't parse array: %q", b)
return p
}
p.isJson = b[0] == '['

p.p.Reset(b[1 : len(b)-1])
return p
Expand Down Expand Up @@ -51,7 +54,7 @@ func (p *arrayParser) readNext() error {
}

switch ch {
case '}':
case '}', ']':
return io.EOF
case '"':
b, err := p.p.ReadSubstring(ch)
Expand All @@ -78,16 +81,34 @@ func (p *arrayParser) readNext() error {
p.elem = rng
return nil
default:
lit := p.p.ReadLiteral(ch)
if bytes.Equal(lit, []byte("NULL")) {
lit = nil
if ch == '{' && p.isJson {
json, err := p.p.ReadJSON()
if err != nil {
return err
}

for {
if p.p.Peek() == ',' || p.p.Peek() == ' ' {
p.p.Advance()
} else {
break
}
}

p.elem = json
return nil
} else {
lit := p.p.ReadLiteral(ch)
if bytes.Equal(lit, []byte("NULL")) {
lit = nil
}

if p.p.Peek() == ',' {
p.p.Advance()
}

p.elem = lit
return nil
}

if p.p.Peek() == ',' {
p.p.Advance()
}

p.elem = lit
return nil
}
}
4 changes: 4 additions & 0 deletions dialect/pgdialect/array_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ func TestArrayParser(t *testing.T) {
{`{"1","2"}`, []string{"1", "2"}},
{`{"{1}","{2}"}`, []string{"{1}", "{2}"}},
{`{[1,2),[3,4)}`, []string{"[1,2)", "[3,4)"}},

{`[]`, []string{}},
{`[{"'\"[]"}]`, []string{`{"'\"[]"}`}},
{`[{"id": 1}, {"id":2, "name":"bob"}]`, []string{"{\"id\": 1}", "{\"id\":2, \"name\":\"bob\"}"}},
}

for i, test := range tests {
Expand Down
120 changes: 120 additions & 0 deletions dialect/pgdialect/array_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package pgdialect

import (
"reflect"
"testing"

"github.com/uptrace/bun/schema"
)

type tag struct {
ID int32
Label string
}

func ptr[T any](v T) *T {
return &v
}

func TestArrayAppend(t *testing.T) {
tcases := []struct {
input interface{}
out string
}{
{
input: []byte{1, 2},
out: `'{1,2}'`,
},
{
input: []*byte{ptr(byte(1)), ptr(byte(2))},
out: `'{1,2}'`,
},
{
input: []int{1, 2},
out: `'{1,2}'`,
},
{
input: []*int{ptr(1), ptr(2)},
out: `'{1,2}'`,
},
{
input: []string{"foo", "bar"},
out: `'{"foo","bar"}'`,
},
{
input: []*string{ptr("foo"), ptr("bar")},
out: `'{"foo","bar"}'`,
},
{
input: []tag{{1, "foo1"}, {2, "bar"}},
out: `'[{"ID":1,"Label":"foo1"},{"ID":2,"Label":"bar"}]'`,
},
{
input: &[]tag{{1, "foo2"}, {2, "bar"}},
out: `'[{"ID":1,"Label":"foo2"},{"ID":2,"Label":"bar"}]'`,
},
{
input: &[]*tag{{1, "foo3"}, {2, "bar"}},
out: `'[{"ID":1,"Label":"foo3"},{"ID":2,"Label":"bar"}]'`,
},
{
input: ptr(&[]*tag{{1, "foo4"}, {2, "bar"}}),
out: `'[{"ID":1,"Label":"foo4"},{"ID":2,"Label":"bar"}]'`,
},
{
input: []**tag{ptr(ptr(tag{1, "foo5"})), ptr(ptr(tag{2, "bar"}))},
out: `'[{"ID":1,"Label":"foo5"},{"ID":2,"Label":"bar"}]'`,
},
{
input: &[]**tag{ptr(ptr(tag{1, "foo6"})), ptr(ptr(tag{2, "bar"}))},
out: `'[{"ID":1,"Label":"foo6"},{"ID":2,"Label":"bar"}]'`,
},
{
input: ptr(&[]**tag{ptr(ptr(tag{1, "foo7"})), ptr(ptr(tag{2, "bar"}))}),
out: `'[{"ID":1,"Label":"foo7"},{"ID":2,"Label":"bar"}]'`,
},
{
input: ptr(ptr(&[]**tag{ptr(ptr(tag{1, "foo8"})), ptr(ptr(tag{2, "bar"}))})),
out: `'[{"ID":1,"Label":"foo8"},{"ID":2,"Label":"bar"}]'`,
},
}

for _, tcase := range tcases {
out, err := Array(tcase.input).AppendQuery(schema.NewFormatter(New()), []byte{})
if err != nil {
t.Fatal(err)
}

if string(out) != tcase.out {
t.Errorf("expected output to be %s, was %s", tcase.out, string(out))
}
}
}

func Test_appendJSONUnquoted(t *testing.T) {
type args struct {
b []byte
jsonb []byte
}
tests := []struct {
name string
args args
want []byte
}{
{
name: "Test foo",
args: args{
b: []byte{},
jsonb: []byte(`{"foo":'\bar\u0000\'}\0`),
},
want: []byte(`{"foo":''\bar\\u0000\'}\0`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := appendJSONUnquoted(tt.args.b, tt.args.jsonb); !reflect.DeepEqual(got, tt.want) {
t.Errorf("appendJSONUnquoted() = %v, want %v", got, tt.want)
}
})
}
}
36 changes: 36 additions & 0 deletions dialect/pgdialect/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,39 @@ func (p *pgparser) ReadRange(ch byte) ([]byte, error) {

return p.buf, nil
}

func (p *pgparser) ReadJSON() ([]byte, error) {
p.Unread()

c, err := p.ReadByte()
if err != nil {
return nil, err
}

p.buf = p.buf[:0]

depth := 0
for {
switch c {
case '{':
depth++
case '}':
depth--
}

p.buf = append(p.buf, c)

if depth == 0 {
break
}

next, err := p.ReadByte()
if err != nil {
return nil, err
}

c = next
}

return p.buf, nil
}

0 comments on commit d44f743

Please sign in to comment.