Skip to content

Commit

Permalink
Fix decoding nested composite types.
Browse files Browse the repository at this point in the history
When using the pgtype.Value Set method, we have to use interface{} slices all
the way down. Create "assigner" functions that create an []interface slice for
composite and array types that we can use from the top level encoder functions.

Relates to #29.
  • Loading branch information
jschaf committed Apr 21, 2021
1 parent b25f3a5 commit 21dddee
Show file tree
Hide file tree
Showing 10 changed files with 320 additions and 51 deletions.
4 changes: 4 additions & 0 deletions example/complex_params/query.sql
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
-- name: ParamNested1 :one
SELECT pggen.arg('dimensions')::dimensions;

-- name: ParamNested2 :one
SELECT pggen.arg('image')::product_image_type;

89 changes: 85 additions & 4 deletions example/complex_params/query.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 32 additions & 1 deletion example/complex_params/query.sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestNewQuerier_ParamNested1(t *testing.T) {
assert.Equal(t, wantDim, row)
})

t.Run("ArrayNested2Batch", func(t *testing.T) {
t.Run("ParamNested1Batch", func(t *testing.T) {
batch := &pgx.Batch{}
q.ParamNested1Batch(batch, wantDim)
results := conn.SendBatch(ctx, batch)
Expand All @@ -35,3 +35,34 @@ func TestNewQuerier_ParamNested1(t *testing.T) {
assert.Equal(t, wantDim, row)
})
}

func TestNewQuerier_ParamNested2(t *testing.T) {
conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"})
defer cleanup()

q := NewQuerier(conn)
ctx := context.Background()

wantDim := Dimensions{Width: 77, Height: 77}
wantImg := ProductImageType{
Source: "src",
Dimensions: wantDim,
}

t.Run("ParamNested2", func(t *testing.T) {
row, err := q.ParamNested2(ctx, wantImg)
require.NoError(t, err)
assert.Equal(t, wantImg, row)
})

t.Run("ParamNested2Batch", func(t *testing.T) {
batch := &pgx.Batch{}
q.ParamNested2Batch(batch, wantImg)
results := conn.SendBatch(ctx, batch)
row, err := q.ParamNested2Scan(results)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, wantImg, row)
})
}
37 changes: 24 additions & 13 deletions internal/codegen/golang/declarer.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,28 +46,39 @@ func (d DeclarerSet) ListAll() []Declarer {
// the input parameters. Returns nil if no declarers are needed.
func FindInputDeclarers(typ gotype.Type) DeclarerSet {
decls := NewDeclarerSet()
findInputDeclsHelper(typ, decls, false)
findOutputDeclsHelper(typ, decls, false) // inputs depend on output transcoders
// Only top level types need the encoder. Descendant types need the assigner.
switch typ := typ.(type) {
case gotype.CompositeType:
decls.AddAll(NewCompositeEncoderDeclarer(typ))
case gotype.ArrayType:
switch typ.Elem.(type) {
case gotype.CompositeType, gotype.EnumType:
decls.AddAll(NewArrayEncoderDeclarer(typ))
}
}
findInputDeclsHelper(typ, decls /*isFirst*/, true)
// Inputs depend on output transcoders.
findOutputDeclsHelper(typ, decls /*hadCompositeParent*/, false)
return decls
}

func findInputDeclsHelper(typ gotype.Type, decls DeclarerSet, hadCompositeParent bool) {
func findInputDeclsHelper(typ gotype.Type, decls DeclarerSet, isFirst bool) {
switch typ := typ.(type) {
case gotype.CompositeType:
decls.AddAll(
NewTextEncoderDeclarer(),
NewCompositeEncoderDeclarer(typ),
)
decls.AddAll(NewTextEncoderDeclarer())
if !isFirst {
decls.AddAll(NewCompositeAssignerDeclarer(typ))
}
for _, childType := range typ.FieldTypes {
findInputDeclsHelper(childType, decls, true)
findInputDeclsHelper(childType, decls, false)
}

case gotype.ArrayType:
decls.AddAll(
NewTextEncoderDeclarer(),
NewArrayEncoderDeclarer(typ),
)
findInputDeclsHelper(typ.Elem, decls, hadCompositeParent)
decls.AddAll(NewTextEncoderDeclarer())
if !isFirst {
decls.AddAll(NewArrayAssignerDeclarer(typ))
}
findInputDeclsHelper(typ.Elem, decls, false)

default:
return
Expand Down
59 changes: 58 additions & 1 deletion internal/codegen/golang/declarer_array.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,25 @@ import (
"strings"
)

// NameArrayDecoderFunc returns the function name that creates a
// pgtype.ValueTranscoder for the array type that's used to decode rows returned
// by Postgres.
func NameArrayDecoderFunc(typ gotype.ArrayType) string {
return "new" + strings.TrimPrefix(typ.Name, "[]") + "ArrayDecoder"
}

// NameArrayEncoderFunc returns the function name that creates a textEncoder for
// the array type that's used to encode query parameters. This function is only
// necessary for top-level types. Descendant types use the assigner functions.
func NameArrayEncoderFunc(typ gotype.ArrayType) string {
return "encode" + typ.Name
return "new" + typ.Name + "ArrayEncoder"
}

// NameArrayAssignerFunc returns the function name that create the []interface{}
// array for the array type so that we can use it with a parent encoder
// function, like NameCompositeEncoderFunc, in the pgtype.Value Set call.
func NameArrayAssignerFunc(typ gotype.ArrayType) string {
return "assign" + typ.Name + "Array"
}

// ArrayDecoderDeclarer declares a new Go function that creates a pgx
Expand Down Expand Up @@ -124,3 +137,47 @@ func (c ArrayEncoderDeclarer) Declare(string) (string, error) {
sb.WriteString("}")
return sb.String(), nil
}

// ArrayAssignerDeclarer declares a new Go function that returns all fields
// as a generic array: []interface{}. Necessary because we can only set
// pgtype.ArrayType from a []interface{}.
type ArrayAssignerDeclarer struct {
typ gotype.ArrayType
}

func NewArrayAssignerDeclarer(typ gotype.ArrayType) ArrayAssignerDeclarer {
return ArrayAssignerDeclarer{typ}
}

func (c ArrayAssignerDeclarer) DedupeKey() string {
return "array_assigner::" + c.typ.Name
}

func (c ArrayAssignerDeclarer) Declare(string) (string, error) {
funcName := NameArrayAssignerFunc(c.typ)
sb := &strings.Builder{}
sb.Grow(256)

// Doc comment
sb.WriteString("// ")
sb.WriteString(funcName)
sb.WriteString(" returns all elements for the Postgres '")
sb.WriteString(c.typ.PgArray.Name)
sb.WriteString("' array type as a\n")
sb.WriteString("// slice of interface{} for use with the pgtype.Value Set method.\n")

// Function signature
sb.WriteString("func ")
sb.WriteString(funcName)
sb.WriteString("(ps ")
sb.WriteString(c.typ.Name)
sb.WriteString(") []interface{} {\n\t")

// Function body
sb.WriteString("elems := make([]interface{}, len(p))\n\t")
sb.WriteString("for i, p := range ps {\n\t\t")
sb.WriteString("elems[i] = p\n\t")
sb.WriteString("}\n")
sb.WriteString("}")
return sb.String(), nil
}
Loading

0 comments on commit 21dddee

Please sign in to comment.