Skip to content

Commit

Permalink
Use DefaultQueryExecMode in CopyFrom
Browse files Browse the repository at this point in the history
CopyFrom had to create a prepared statement to get the OIDs of the data
types that were going to be copied into the table. Every COPY operation
required an extra round trip to retrieve the type information. There
was no way to customize this behavior.

By leveraging the QueryExecMode feature, like in `Conn.Query`, users can
specify if they want to cache the prepared statements, execute
them on every request (like the old behavior), or bypass the prepared
statement relying on the pgtype.Map to get the type information.

The `QueryExecMode` behaves exactly like in `Conn.Query` in the way the
data type OIDs are fetched, meaning that:

- `QueryExecModeCacheStatement`: caches the statement.
- `QueryExecModeCacheDescribe`: caches the statement descriptions and
  assumes they do not change.
- `QueryExecModeDescribeExec`: gets the statement description on every
  execution. This is like the old behavior of `CopyFrom`.
- `QueryExecModeExec` and `QueryExecModeSimpleProtocol`: maintain the
  same behavior as before, which is the same as `QueryExecModeDescribeExec`.
  It will keep getting the statement description on every execution.

The `QueryExecMode` can only be set via
`ConnConfig.DefaultQueryExecMode`. Unlike `Conn.Query`, there's no
support for specifying the `QueryExecMode` via optional arguments
in the function signature.
  • Loading branch information
alejandrodnm authored and jackc committed Dec 23, 2022
1 parent 456a242 commit c4ac6d8
Show file tree
Hide file tree
Showing 6 changed files with 253 additions and 42 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@ _testmain.go

.envrc
/.testdb

.DS_Store
83 changes: 46 additions & 37 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -721,43 +721,10 @@ optionLoop:
sd, explicitPreparedStatement := c.preparedStatements[sql]
if sd != nil || mode == QueryExecModeCacheStatement || mode == QueryExecModeCacheDescribe || mode == QueryExecModeDescribeExec {
if sd == nil {
switch mode {
case QueryExecModeCacheStatement:
if c.statementCache == nil {
err = errDisabledStatementCache
rows.fatal(err)
return rows, err
}
sd = c.statementCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
if err != nil {
rows.fatal(err)
return rows, err
}
c.statementCache.Put(sd)
}
case QueryExecModeCacheDescribe:
if c.descriptionCache == nil {
err = errDisabledDescriptionCache
rows.fatal(err)
return rows, err
}
sd = c.descriptionCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, "", sql)
if err != nil {
rows.fatal(err)
return rows, err
}
c.descriptionCache.Put(sd)
}
case QueryExecModeDescribeExec:
sd, err = c.Prepare(ctx, "", sql)
if err != nil {
rows.fatal(err)
return rows, err
}
sd, err = c.getStatementDescription(ctx, mode, sql)
if err != nil {
rows.fatal(err)
return rows, err
}
}

Expand Down Expand Up @@ -827,6 +794,48 @@ optionLoop:
return rows, rows.err
}

// getStatementDescription resolves the statement description for sql using
// the strategy selected by mode:
//
//   - QueryExecModeCacheStatement: look the statement up in c.statementCache,
//     preparing it under a fresh generated name and storing it on a miss.
//   - QueryExecModeCacheDescribe: look the description up in
//     c.descriptionCache, preparing an unnamed statement and storing it on a
//     miss.
//   - QueryExecModeDescribeExec: prepare an unnamed statement on every call.
//
// For the two caching modes an error is returned when the corresponding cache
// is nil. For any other mode, (nil, nil) is returned.
func (c *Conn) getStatementDescription(
	ctx context.Context,
	mode QueryExecMode,
	sql string,
) (sd *pgconn.StatementDescription, err error) {
	switch mode {
	case QueryExecModeDescribeExec:
		return c.Prepare(ctx, "", sql)

	case QueryExecModeCacheStatement:
		if c.statementCache == nil {
			return nil, errDisabledStatementCache
		}
		if cached := c.statementCache.Get(sql); cached != nil {
			return cached, nil
		}
		prepared, prepErr := c.Prepare(ctx, stmtcache.NextStatementName(), sql)
		if prepErr != nil {
			return nil, prepErr
		}
		c.statementCache.Put(prepared)
		return prepared, nil

	case QueryExecModeCacheDescribe:
		if c.descriptionCache == nil {
			return nil, errDisabledDescriptionCache
		}
		if cached := c.descriptionCache.Get(sql); cached != nil {
			return cached, nil
		}
		described, prepErr := c.Prepare(ctx, "", sql)
		if prepErr != nil {
			return nil, prepErr
		}
		c.descriptionCache.Put(described)
		return described, nil
	}

	// Remaining modes do not need a statement description.
	return nil, nil
}

// QueryRow is a convenience wrapper over Query. Any error that occurs while
// querying is deferred until calling Scan on the returned Row. That Row will
// error with ErrNoRows if no rows are returned.
Expand Down
28 changes: 25 additions & 3 deletions copy_from.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ type copyFrom struct {
columnNames []string
rowSrc CopyFromSource
readerErrChan chan error
mode QueryExecMode
}

func (ct *copyFrom) run(ctx context.Context) (int64, error) {
Expand All @@ -105,9 +106,29 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
}
quotedColumnNames := cbuf.String()

sd, err := ct.conn.Prepare(ctx, "", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
if err != nil {
return 0, err
var sd *pgconn.StatementDescription
switch ct.mode {
case QueryExecModeExec, QueryExecModeSimpleProtocol:
// These modes don't support the binary format. Before the inclusion of the
// QueryExecModes, Conn.Prepare was called on every COPY operation to get
// the OIDs. These prepared statements were not cached.
//
// Since that's the same behavior provided by QueryExecModeDescribeExec,
// we'll default to that mode.
ct.mode = QueryExecModeDescribeExec
fallthrough
case QueryExecModeCacheStatement, QueryExecModeCacheDescribe, QueryExecModeDescribeExec:
var err error
sd, err = ct.conn.getStatementDescription(
ctx,
ct.mode,
fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName),
)
if err != nil {
return 0, fmt.Errorf("statement description failed: %w", err)
}
default:
return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode)
}

r, w := io.Pipe()
Expand Down Expand Up @@ -208,6 +229,7 @@ func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames [
columnNames: columnNames,
rowSrc: rowSrc,
readerErrChan: make(chan error),
mode: c.config.DefaultQueryExecMode,
}

return ct.run(ctx)
Expand Down
125 changes: 124 additions & 1 deletion copy_from_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,129 @@ import (
"github.com/stretchr/testify/require"
)

// TestConnCopyWithAllQueryExecModes checks that CopyFrom round-trips rows
// when the connection's DefaultQueryExecMode is set to each entry of
// pgxtest.AllQueryExecModes. Every mode runs as its own parallel subtest on a
// dedicated connection.
func TestConnCopyWithAllQueryExecModes(t *testing.T) {
	for _, mode := range pgxtest.AllQueryExecModes {
		// Shadow the range variable. The subtest calls t.Parallel, so its body
		// resumes only after this loop has finished; with pre-Go-1.22 loop
		// semantics every closure would otherwise observe the last mode.
		mode := mode

		t.Run(mode.String(), func(t *testing.T) {
			t.Parallel()

			cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
			cfg.DefaultQueryExecMode = mode
			conn := mustConnect(t, cfg)
			defer closeConn(t, conn)

			mustExec(t, conn, `create temporary table foo(
			a int2,
			b int4,
			c int8,
			d text,
			e timestamptz
		)`)

			tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)

			inputRows := [][]any{
				{int16(0), int32(1), int64(2), "abc", tzedTime},
				{nil, nil, nil, nil, nil},
			}

			copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e"}, pgx.CopyFromRows(inputRows))
			if err != nil {
				t.Errorf("Unexpected error for CopyFrom: %v", err)
			}
			if int(copyCount) != len(inputRows) {
				t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
			}

			rows, err := conn.Query(context.Background(), "select * from foo")
			if err != nil {
				t.Errorf("Unexpected error for Query: %v", err)
			}

			// Read everything back so it can be compared with the input.
			var outputRows [][]any
			for rows.Next() {
				row, err := rows.Values()
				if err != nil {
					t.Errorf("Unexpected error for rows.Values(): %v", err)
				}
				outputRows = append(outputRows, row)
			}

			if rows.Err() != nil {
				t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
			}

			if !reflect.DeepEqual(inputRows, outputRows) {
				t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
			}

			ensureConnValid(t, conn)
		})
	}
}

// TestConnCopyWithKnownOIDQueryExecModes checks that CopyFrom round-trips
// rows, including varchar and date columns, when the connection's
// DefaultQueryExecMode is set to each entry of
// pgxtest.KnownOIDQueryExecModes. Every mode runs as its own parallel subtest
// on a dedicated connection.
func TestConnCopyWithKnownOIDQueryExecModes(t *testing.T) {
	for _, mode := range pgxtest.KnownOIDQueryExecModes {
		// Shadow the range variable. The subtest calls t.Parallel, so its body
		// resumes only after this loop has finished; with pre-Go-1.22 loop
		// semantics every closure would otherwise observe the last mode.
		mode := mode

		t.Run(mode.String(), func(t *testing.T) {
			t.Parallel()

			cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
			cfg.DefaultQueryExecMode = mode
			conn := mustConnect(t, cfg)
			defer closeConn(t, conn)

			mustExec(t, conn, `create temporary table foo(
			a int2,
			b int4,
			c int8,
			d varchar,
			e text,
			f date,
			g timestamptz
		)`)

			tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)

			inputRows := [][]any{
				{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
				{nil, nil, nil, nil, nil, nil, nil},
			}

			copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
			if err != nil {
				t.Errorf("Unexpected error for CopyFrom: %v", err)
			}
			if int(copyCount) != len(inputRows) {
				t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
			}

			rows, err := conn.Query(context.Background(), "select * from foo")
			if err != nil {
				t.Errorf("Unexpected error for Query: %v", err)
			}

			// Read everything back so it can be compared with the input.
			var outputRows [][]any
			for rows.Next() {
				row, err := rows.Values()
				if err != nil {
					t.Errorf("Unexpected error for rows.Values(): %v", err)
				}
				outputRows = append(outputRows, row)
			}

			if rows.Err() != nil {
				t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
			}

			if !reflect.DeepEqual(inputRows, outputRows) {
				t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
			}

			ensureConnValid(t, conn)
		})
	}
}

func TestConnCopyFromSmall(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -220,7 +343,7 @@ func TestConnCopyFromEnum(t *testing.T) {
conn.TypeMap().RegisterType(typ)
}

_, err = tx.Exec(ctx, `create table foo(
_, err = tx.Exec(ctx, `create temporary table foo(
a text,
b color,
c fruit,
Expand Down
55 changes: 55 additions & 0 deletions pgconn/pgconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"testing"
"time"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/internal/pgio"
"github.com/jackc/pgx/v5/internal/pgmock"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgproto3"
Expand Down Expand Up @@ -1666,6 +1668,59 @@ func TestConnCopyFrom(t *testing.T) {
ensureConnValid(t, pgConn)
}

// TestConnCopyFromBinary builds a binary-format COPY payload by hand for a
// two-column table (int4, varchar), streams it through PgConn.CopyFrom and
// verifies that the rows read back match the values that were encoded.
func TestConnCopyFromBinary(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	_, err = pgConn.Exec(context.Background(), `create temporary table foo(
		a int4,
		b varchar
	)`).ReadAll()
	require.NoError(t, err)

	const rowCount = 1000

	// Header: the signature followed by two zero int32 words.
	// NOTE(review): per the PostgreSQL COPY BINARY docs these are presumably
	// the flags field and the header extension length — confirm.
	buf := []byte("PGCOPY\n\377\r\n\000")
	buf = pgio.AppendInt32(buf, 0)
	buf = pgio.AppendInt32(buf, 0)

	// One type map is enough for every value; building a new map per encoded
	// field is loop-invariant work.
	typeMap := pgtype.NewMap()

	inputRows := make([][][]byte, 0, rowCount)
	for i := 0; i < rowCount; i++ {
		// Number of elements in the tuple
		buf = pgio.AppendInt16(buf, int16(2))
		a := i

		// Length of element for column `a int4`
		buf = pgio.AppendInt32(buf, 4)
		buf, err = typeMap.Encode(pgtype.Int4OID, pgx.BinaryFormatCode, a, buf)
		require.NoError(t, err)

		b := "foo " + strconv.Itoa(a) + " bar"
		// Length of element for column `b varchar`
		buf = pgio.AppendInt32(buf, int32(len(b)))
		buf, err = typeMap.Encode(pgtype.VarcharOID, pgx.BinaryFormatCode, b, buf)
		require.NoError(t, err)

		// Expected values are the textual representations of each column.
		inputRows = append(inputRows, [][]byte{[]byte(strconv.Itoa(a)), []byte(b)})
	}

	srcBuf := &bytes.Buffer{}
	srcBuf.Write(buf)
	ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo (a, b) FROM STDIN BINARY;")
	require.NoError(t, err)
	assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())

	result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read()
	require.NoError(t, result.Err)

	assert.Equal(t, inputRows, result.Rows)

	ensureConnValid(t, pgConn)
}

func TestConnCopyFromCanceled(t *testing.T) {
t.Parallel()

Expand Down
2 changes: 1 addition & 1 deletion tracelog/tracelog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ func TestLogCopyFrom(t *testing.T) {
return config
}

pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, pgxtest.KnownOIDQueryExecModes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(context.Background(), `create temporary table foo(a int4)`)
require.NoError(t, err)

Expand Down

0 comments on commit c4ac6d8

Please sign in to comment.