From acf6e1863e9423a8663c9582e98ee246443c7244 Mon Sep 17 00:00:00 2001 From: pratap0007 Date: Fri, 12 Jan 2024 14:51:49 +0530 Subject: [PATCH] Bump gorm.io/driver/postgres from 1.5.2 to 1.5.4 Signed-off-by: Shiv Verma --- go.mod | 4 +- go.sum | 8 +- vendor/github.com/jackc/pgx/v5/CHANGELOG.md | 49 ++ vendor/github.com/jackc/pgx/v5/README.md | 14 +- vendor/github.com/jackc/pgx/v5/batch.go | 28 +- vendor/github.com/jackc/pgx/v5/conn.go | 46 +- vendor/github.com/jackc/pgx/v5/doc.go | 8 +- .../pgx/v5/internal/nbconn/bufferqueue.go | 70 --- .../jackc/pgx/v5/internal/nbconn/nbconn.go | 520 ------------------ .../internal/nbconn/nbconn_fake_non_block.go | 11 - .../internal/nbconn/nbconn_real_non_block.go | 81 --- .../jackc/pgx/v5/pgconn/auth_scram.go | 4 +- .../github.com/jackc/pgx/v5/pgconn/config.go | 2 +- .../v5/pgconn/internal/bgreader/bgreader.go | 139 +++++ vendor/github.com/jackc/pgx/v5/pgconn/krb5.go | 2 +- .../github.com/jackc/pgx/v5/pgconn/pgconn.go | 373 +++++++++---- .../jackc/pgx/v5/pgproto3/frontend.go | 4 + .../github.com/jackc/pgx/v5/pgproto3/trace.go | 260 ++++----- .../github.com/jackc/pgx/v5/pgtype/array.go | 29 +- vendor/github.com/jackc/pgx/v5/pgtype/bool.go | 38 +- .../github.com/jackc/pgx/v5/pgtype/convert.go | 368 ------------- .../github.com/jackc/pgx/v5/pgtype/hstore.go | 423 +++++++------- vendor/github.com/jackc/pgx/v5/pgtype/json.go | 39 +- .../github.com/jackc/pgx/v5/pgtype/numeric.go | 37 +- .../github.com/jackc/pgx/v5/pgtype/pgtype.go | 268 ++------- .../jackc/pgx/v5/pgtype/pgtype_default.go | 223 ++++++++ .../github.com/jackc/pgx/v5/pgtype/point.go | 2 +- .../jackc/pgx/v5/pgtype/timestamp.go | 50 ++ vendor/github.com/jackc/pgx/v5/pgtype/uuid.go | 2 +- vendor/github.com/jackc/pgx/v5/rows.go | 64 ++- vendor/github.com/jackc/pgx/v5/tx.go | 9 +- .../driver/postgres/error_translator.go | 24 +- vendor/gorm.io/driver/postgres/migrator.go | 58 +- vendor/gorm.io/driver/postgres/postgres.go | 10 +- vendor/modules.txt | 6 +- 35 files changed, 1378 insertions(+), 1895 deletions(-) delete mode 100644 vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go delete mode 100644 vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go delete mode 100644 vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go delete mode 100644 vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go create mode 100644 vendor/github.com/jackc/pgx/v5/pgconn/internal/bgreader/bgreader.go create mode 100644 vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go diff --git a/go.mod b/go.mod index 848a130ae1..187e6701c1 100644 --- a/go.mod +++ b/go.mod @@ -31,7 +31,7 @@ require ( golang.org/x/term v0.16.0 golang.org/x/text v0.14.0 gopkg.in/h2non/gock.v1 v1.1.2 - gorm.io/driver/postgres v1.5.2 + gorm.io/driver/postgres v1.5.4 gorm.io/gorm v1.25.5 gotest.tools/v3 v3.5.1 k8s.io/apimachinery v0.28.3 @@ -99,7 +99,7 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.3.1 // indirect + github.com/jackc/pgx/v5 v5.4.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/josharian/intern v1.0.0 // indirect diff --git a/go.sum b/go.sum index f7a1ae0827..edce65953a 100644 --- a/go.sum +++ b/go.sum @@ -728,8 +728,8 @@ github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/ github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgtype v1.14.0 h1:y+xUdabmyMkJLyApYuPj38mW+aAIqCe5uuBB51rH3Vw= github.com/jackc/pgx/v4 v4.18.1 h1:YP7G1KABtKpB5IHrO9vYwSrCOhs7p3uqhvhhQBptya0= -github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU= -github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8= +github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= +github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= github.com/jarcoal/httpmock v0.0.0-20180424175123-9c70cfe4a1da/go.mod h1:ks+b9deReOc7jgqp+e7LuFiCBH6Rm5hL32cLcEAArb4= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= @@ -1854,8 +1854,8 @@ gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0= -gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8= +gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo= +gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0= gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= diff --git a/vendor/github.com/jackc/pgx/v5/CHANGELOG.md b/vendor/github.com/jackc/pgx/v5/CHANGELOG.md index ec90631cd5..fb2304a2fb 100644 --- a/vendor/github.com/jackc/pgx/v5/CHANGELOG.md +++ b/vendor/github.com/jackc/pgx/v5/CHANGELOG.md @@ -1,3 +1,52 @@ +# 5.4.3 (August 5, 2023) + +* Fix: QCharArrayOID was defined with the wrong OID (Christoph Engelbert) +* Fix: connect_timeout for sslmode=allow|prefer (smaher-edb) +* Fix: pgxpool: background health check cannot overflow pool +* Fix: Check for nil in defer when sending batch (recover properly from panic) +* Fix: json scan of non-string pointer to pointer +* Fix: zeronull.Timestamptz should use pgtype.Timestamptz +* Fix: NewConnsCount was not correctly counting connections created by Acquire directly. (James Hartig) +* RowTo(AddrOf)StructByPos ignores fields with "-" db tag +* Optimization: improve text format numeric parsing (horpto) + +# 5.4.2 (July 11, 2023) + +* Fix: RowScanner errors are fatal to Rows +* Fix: Enable failover efforts when pg_hba.conf disallows non-ssl connections (Brandon Kauffman) +* Hstore text codec internal improvements (Evan Jones) +* Fix: Stop timers for background reader when not in use. Fixes memory leak when closing connections (Adrian-Stefan Mares) +* Fix: Stop background reader as soon as possible. +* Add PgConn.SyncConn(). This combined with the above fix makes it safe to directly use the underlying net.Conn. + +# 5.4.1 (June 18, 2023) + +* Fix: concurrency bug with pgtypeDefaultMap and simple protocol (Lev Zakharov) +* Add TxOptions.BeginQuery to allow overriding the default BEGIN query + +# 5.4.0 (June 14, 2023) + +* Replace platform specific syscalls for non-blocking IO with more traditional goroutines and deadlines. This returns to the v4 approach with some additional improvements and fixes. This restores the ability to use a pgx.Conn over an ssh.Conn as well as other non-TCP or Unix socket connections. In addition, it is a significantly simpler implementation that is less likely to have cross platform issues. +* Optimization: The default type registrations are now shared among all connections. This saves about 100KB of memory per connection. `pgtype.Type` and `pgtype.Codec` values are now required to be immutable after registration. This was already necessary in most cases but wasn't documented until now. (Lev Zakharov) +* Fix: Ensure pgxpool.Pool.QueryRow.Scan releases connection on panic +* CancelRequest: don't try to read the reply (Nicola Murino) +* Fix: correctly handle bool type aliases (Wichert Akkerman) +* Fix: pgconn.CancelRequest: Fix unix sockets: don't use RemoteAddr() +* Fix: pgx.Conn memory leak with prepared statement caching (Evan Jones) +* Add BeforeClose to pgxpool.Pool (Evan Cordell) +* Fix: various hstore fixes and optimizations (Evan Jones) +* Fix: RowToStructByPos with embedded unexported struct +* Support different bool string representations (Lev Zakharov) +* Fix: error when using BatchResults.Exec on a select that returns an error after some rows. +* Fix: pipelineBatchResults.Exec() not returning error from ResultReader +* Fix: pipeline batch results not closing pipeline when error occurs while reading directly from results instead of using + a callback. +* Fix: scanning a table type into a struct +* Fix: scan array of record to pointer to slice of struct +* Fix: handle null for json (Cemre Mengu) +* Batch Query callback is called even when there is an error +* Add RowTo(AddrOf)StructByNameLax (Audi P. Risa P) + # 5.3.1 (February 27, 2023) * Fix: Support v4 and v5 stdlib in same program (Tomáš Procházka) diff --git a/vendor/github.com/jackc/pgx/v5/README.md b/vendor/github.com/jackc/pgx/v5/README.md index 29d9521c69..522206f95a 100644 --- a/vendor/github.com/jackc/pgx/v5/README.md +++ b/vendor/github.com/jackc/pgx/v5/README.md @@ -1,5 +1,5 @@ [![Go Reference](https://pkg.go.dev/badge/github.com/jackc/pgx/v5.svg)](https://pkg.go.dev/github.com/jackc/pgx/v5) -![Build Status](https://github.com/jackc/pgx/actions/workflows/ci.yml/badge.svg) +[![Build Status](https://github.com/jackc/pgx/actions/workflows/ci.yml/badge.svg)](https://github.com/jackc/pgx/actions/workflows/ci.yml) # pgx - PostgreSQL Driver and Toolkit @@ -132,13 +132,25 @@ These adapters can be used with the tracelog package. * [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus) * [github.com/jackc/pgx-zap](https://github.com/jackc/pgx-zap) * [github.com/jackc/pgx-zerolog](https://github.com/jackc/pgx-zerolog) +* [github.com/mcosta74/pgx-slog](https://github.com/mcosta74/pgx-slog) +* [github.com/kataras/pgx-golog](https://github.com/kataras/pgx-golog) ## 3rd Party Libraries with PGX Support +### [github.com/pashagolub/pgxmock](https://github.com/pashagolub/pgxmock) + +pgxmock is a mock library implementing pgx interfaces. +pgxmock has one and only purpose - to simulate pgx behavior in tests, without needing a real database connection. + ### [github.com/georgysavva/scany](https://github.com/georgysavva/scany) Library for scanning data from a database into Go structs and more. +### [github.com/vingarcia/ksql](https://github.com/vingarcia/ksql) + +A carefully designed SQL client for making using SQL easier, +more productive, and less error-prone on Golang. + ### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5) Adds GSSAPI / Kerberos authentication support. diff --git a/vendor/github.com/jackc/pgx/v5/batch.go b/vendor/github.com/jackc/pgx/v5/batch.go index af62039f81..8f6ea4f0d5 100644 --- a/vendor/github.com/jackc/pgx/v5/batch.go +++ b/vendor/github.com/jackc/pgx/v5/batch.go @@ -21,13 +21,10 @@ type batchItemFunc func(br BatchResults) error // Query sets fn to be called when the response to qq is received. func (qq *QueuedQuery) Query(fn func(rows Rows) error) { qq.fn = func(br BatchResults) error { - rows, err := br.Query() - if err != nil { - return err - } + rows, _ := br.Query() defer rows.Close() - err = fn(rows) + err := fn(rows) if err != nil { return err } @@ -142,7 +139,10 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) { } commandTag, err := br.mrr.ResultReader().Close() - br.err = err + if err != nil { + br.err = err + br.mrr.Close() + } if br.conn.batchTracer != nil { br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ @@ -228,7 +228,7 @@ func (br *batchResults) Close() error { for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) { if br.b.queuedQueries[br.qqIdx].fn != nil { err := br.b.queuedQueries[br.qqIdx].fn(br) - if err != nil && br.err == nil { + if err != nil { br.err = err } } else { @@ -290,7 +290,7 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) { results, err := br.pipeline.GetResults() if err != nil { br.err = err - return pgconn.CommandTag{}, err + return pgconn.CommandTag{}, br.err } var commandTag pgconn.CommandTag switch results := results.(type) { @@ -309,7 +309,7 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) { }) } - return commandTag, err + return commandTag, br.err } // Query reads the results from the next query in the batch as if the query has been sent with Query. @@ -384,24 +384,20 @@ func (br *pipelineBatchResults) Close() error { } }() - if br.err != nil { - return br.err - } - - if br.lastRows != nil && br.lastRows.err != nil { + if br.err == nil && br.lastRows != nil && br.lastRows.err != nil { br.err = br.lastRows.err return br.err } if br.closed { - return nil + return br.err } // Read and run fn for all remaining items for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) { if br.b.queuedQueries[br.qqIdx].fn != nil { err := br.b.queuedQueries[br.qqIdx].fn(br) - if err != nil && br.err == nil { + if err != nil { br.err = err } } else { diff --git a/vendor/github.com/jackc/pgx/v5/conn.go b/vendor/github.com/jackc/pgx/v5/conn.go index 92b6f3e4a7..7c7081b487 100644 --- a/vendor/github.com/jackc/pgx/v5/conn.go +++ b/vendor/github.com/jackc/pgx/v5/conn.go @@ -178,7 +178,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con case "simple_protocol": defaultQueryExecMode = QueryExecModeSimpleProtocol default: - return nil, fmt.Errorf("invalid default_query_exec_mode: %v", err) + return nil, fmt.Errorf("invalid default_query_exec_mode: %s", s) } } @@ -194,7 +194,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con return connConfig, nil } -// ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that pgconn.ParseConfig +// ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that [pgconn.ParseConfig] // does. In addition, it accepts the following options: // // - default_query_exec_mode. @@ -382,11 +382,9 @@ func quoteIdentifier(s string) string { return `"` + strings.ReplaceAll(s, `"`, `""`) + `"` } -// Ping executes an empty sql statement against the *Conn -// If the sql returns without error, the database Ping is considered successful, otherwise, the error is returned. +// Ping delegates to the underlying *pgconn.PgConn.Ping. func (c *Conn) Ping(ctx context.Context) error { - _, err := c.Exec(ctx, ";") - return err + return c.pgConn.Ping(ctx) } // PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the @@ -509,7 +507,7 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []a mrr := c.pgConn.Exec(ctx, sql) for mrr.NextResult() { - commandTag, err = mrr.ResultReader().Close() + commandTag, _ = mrr.ResultReader().Close() } err = mrr.Close() return commandTag, err @@ -585,8 +583,10 @@ const ( QueryExecModeCacheDescribe // Get the statement description on every execution. This uses the extended protocol. Queries require two round trips - // to execute. It does not use prepared statements (allowing usage with most connection poolers) and is safe even - // when the the database schema is modified concurrently. + // to execute. It does not use named prepared statements. But it does use the unnamed prepared statement to get the + // statement description on the first round trip and then uses it to execute the query on the second round trip. This + // may cause problems with connection poolers that switch the underlying connection between round trips. It is safe + // even when the the database schema is modified concurrently. QueryExecModeDescribeExec // Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol @@ -648,6 +648,9 @@ type QueryRewriter interface { // returned Rows even if an error is returned. The error will be the available in rows.Err() after rows are closed. It // is allowed to ignore the error returned from Query and handle it in Rows. // +// It is possible for a call of FieldDescriptions on the returned Rows to return nil even if the Query call did not +// return an error. +// // It is possible for a query to return one or more rows before encountering an error. In most cases the rows should be // collected before processing rather than processed while receiving each row. This avoids the possibility of the // application processing rows from a query that the server rejected. The CollectRows function is useful here. @@ -975,7 +978,7 @@ func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchR func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { if c.statementCache == nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache} + return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache, closed: true} } distinctNewQueries := []*pgconn.StatementDescription{} @@ -1007,7 +1010,7 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { if c.descriptionCache == nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache} + return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache, closed: true} } distinctNewQueries := []*pgconn.StatementDescription{} @@ -1061,7 +1064,7 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) { pipeline := c.pgConn.StartPipeline(context.Background()) defer func() { - if pbr.err != nil { + if pbr != nil && pbr.err != nil { pipeline.Close() } }() @@ -1074,18 +1077,18 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d err := pipeline.Sync() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } for _, sd := range distinctNewQueries { results, err := pipeline.GetResults() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } resultSD, ok := results.(*pgconn.StatementDescription) if !ok { - return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)} + return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true} } // Fill in the previously empty / pending statement descriptions. @@ -1095,12 +1098,12 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d results, err := pipeline.GetResults() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } _, ok := results.(*pgconn.PipelineSync) if !ok { - return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)} + return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true} } } @@ -1117,7 +1120,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d if err != nil { // we wrap the error so we the user can understand which query failed inside the batch err = fmt.Errorf("error building query %s: %w", bi.query, err) - return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } if bi.sd.Name == "" { @@ -1129,7 +1132,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d err := pipeline.Sync() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } return &pipelineBatchResults{ @@ -1282,7 +1285,9 @@ func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.Com var fieldOID uint32 rows, _ := c.Query(ctx, `select attname, atttypid from pg_attribute -where attrelid=$1 and not attisdropped +where attrelid=$1 + and not attisdropped + and attnum > 0 order by attnum`, typrelid, ) @@ -1324,6 +1329,7 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error for _, sd := range invalidatedStatements { pipeline.SendDeallocate(sd.Name) + delete(c.preparedStatements, sd.Name) } err := pipeline.Sync() diff --git a/vendor/github.com/jackc/pgx/v5/doc.go b/vendor/github.com/jackc/pgx/v5/doc.go index 0db8cbb140..7486f42c5d 100644 --- a/vendor/github.com/jackc/pgx/v5/doc.go +++ b/vendor/github.com/jackc/pgx/v5/doc.go @@ -7,17 +7,17 @@ details. Establishing a Connection -The primary way of establishing a connection is with `pgx.Connect`. +The primary way of establishing a connection is with [pgx.Connect]: conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified -here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the connection with -`ConnectConfig` to configure settings such as tracing that cannot be configured with a connection string. +here. In addition, a config struct can be created by [ParseConfig] and modified before establishing the connection with +[ConnectConfig] to configure settings such as tracing that cannot be configured with a connection string. Connection Pool -`*pgx.Conn` represents a single connection to the database and is not concurrency safe. Use package +[*pgx.Conn] represents a single connection to the database and is not concurrency safe. Use package github.com/jackc/pgx/v5/pgxpool for a concurrency safe connection pool. Query Interface diff --git a/vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go b/vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go deleted file mode 100644 index 4bf25481c5..0000000000 --- a/vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go +++ /dev/null @@ -1,70 +0,0 @@ -package nbconn - -import ( - "sync" -) - -const minBufferQueueLen = 8 - -type bufferQueue struct { - lock sync.Mutex - queue []*[]byte - r, w int -} - -func (bq *bufferQueue) pushBack(buf *[]byte) { - bq.lock.Lock() - defer bq.lock.Unlock() - - if bq.w >= len(bq.queue) { - bq.growQueue() - } - bq.queue[bq.w] = buf - bq.w++ -} - -func (bq *bufferQueue) pushFront(buf *[]byte) { - bq.lock.Lock() - defer bq.lock.Unlock() - - if bq.w >= len(bq.queue) { - bq.growQueue() - } - copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w]) - bq.queue[bq.r] = buf - bq.w++ -} - -func (bq *bufferQueue) popFront() *[]byte { - bq.lock.Lock() - defer bq.lock.Unlock() - - if bq.r == bq.w { - return nil - } - - buf := bq.queue[bq.r] - bq.queue[bq.r] = nil // Clear reference so it can be garbage collected. - bq.r++ - - if bq.r == bq.w { - bq.r = 0 - bq.w = 0 - if len(bq.queue) > minBufferQueueLen { - bq.queue = make([]*[]byte, minBufferQueueLen) - } - } - - return buf -} - -func (bq *bufferQueue) growQueue() { - desiredLen := (len(bq.queue) + 1) * 3 / 2 - if desiredLen < minBufferQueueLen { - desiredLen = minBufferQueueLen - } - - newQueue := make([]*[]byte, desiredLen) - copy(newQueue, bq.queue) - bq.queue = newQueue -} diff --git a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go b/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go deleted file mode 100644 index 7a38383f0e..0000000000 --- a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go +++ /dev/null @@ -1,520 +0,0 @@ -// Package nbconn implements a non-blocking net.Conn wrapper. -// -// It is designed to solve three problems. -// -// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all -// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion. -// -// The second is the inability to use a write deadline with a TLS.Conn without killing the connection. -// -// The third is to efficiently check if a connection has been closed via a non-blocking read. -package nbconn - -import ( - "crypto/tls" - "errors" - "net" - "os" - "sync" - "sync/atomic" - "syscall" - "time" - - "github.com/jackc/pgx/v5/internal/iobufpool" -) - -var errClosed = errors.New("closed") -var ErrWouldBlock = new(wouldBlockError) - -const fakeNonblockingWriteWaitDuration = 100 * time.Millisecond -const minNonblockingReadWaitDuration = time.Microsecond -const maxNonblockingReadWaitDuration = 100 * time.Millisecond - -// NonBlockingDeadline is a magic value that when passed to Set[Read]Deadline places the connection in non-blocking read -// mode. -var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 608536336, time.UTC) - -// disableSetDeadlineDeadline is a magic value that when passed to Set[Read|Write]Deadline causes those methods to -// ignore all future calls. -var disableSetDeadlineDeadline = time.Date(1900, 1, 1, 0, 0, 0, 968549727, time.UTC) - -// wouldBlockError implements net.Error so tls.Conn will recognize ErrWouldBlock as a temporary error. -type wouldBlockError struct{} - -func (*wouldBlockError) Error() string { - return "would block" -} - -func (*wouldBlockError) Timeout() bool { return true } -func (*wouldBlockError) Temporary() bool { return true } - -// Conn is a net.Conn where Write never blocks and always succeeds. Flush or Read must be called to actually write to -// the underlying connection. -type Conn interface { - net.Conn - - // Flush flushes any buffered writes. - Flush() error - - // BufferReadUntilBlock reads and buffers any successfully read bytes until the read would block. - BufferReadUntilBlock() error -} - -// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn. -type NetConn struct { - // 64 bit fields accessed with atomics must be at beginning of struct to guarantee alignment for certain 32-bit - // architectures. See BUGS section of https://pkg.go.dev/sync/atomic and https://github.com/jackc/pgx/issues/1288 and - // https://github.com/jackc/pgx/issues/1307. Only access with atomics - closed int64 // 0 = not closed, 1 = closed - - conn net.Conn - rawConn syscall.RawConn - - readQueue bufferQueue - writeQueue bufferQueue - - readFlushLock sync.Mutex - // non-blocking writes with syscall.RawConn are done with a callback function. By using these fields instead of the - // callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations. - nonblockWriteFunc func(fd uintptr) (done bool) - nonblockWriteBuf []byte - nonblockWriteErr error - nonblockWriteN int - - // non-blocking reads with syscall.RawConn are done with a callback function. By using these fields instead of the - // callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations. - nonblockReadFunc func(fd uintptr) (done bool) - nonblockReadBuf []byte - nonblockReadErr error - nonblockReadN int - - readDeadlineLock sync.Mutex - readDeadline time.Time - readNonblocking bool - fakeNonBlockingShortReadCount int - fakeNonblockingReadWaitDuration time.Duration - - writeDeadlineLock sync.Mutex - writeDeadline time.Time -} - -func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn { - nc := &NetConn{ - conn: conn, - fakeNonblockingReadWaitDuration: maxNonblockingReadWaitDuration, - } - - if !fakeNonBlockingIO { - if sc, ok := conn.(syscall.Conn); ok { - if rawConn, err := sc.SyscallConn(); err == nil { - nc.rawConn = rawConn - } - } - } - - return nc -} - -// Read implements io.Reader. -func (c *NetConn) Read(b []byte) (n int, err error) { - if c.isClosed() { - return 0, errClosed - } - - c.readFlushLock.Lock() - defer c.readFlushLock.Unlock() - - err = c.flush() - if err != nil { - return 0, err - } - - for n < len(b) { - buf := c.readQueue.popFront() - if buf == nil { - break - } - copiedN := copy(b[n:], *buf) - if copiedN < len(*buf) { - *buf = (*buf)[copiedN:] - c.readQueue.pushFront(buf) - } else { - iobufpool.Put(buf) - } - n += copiedN - } - - // If any bytes were already buffered return them without trying to do a Read. Otherwise, when the caller is trying to - // Read up to len(b) bytes but all available bytes have already been buffered the underlying Read would block. - if n > 0 { - return n, nil - } - - var readNonblocking bool - c.readDeadlineLock.Lock() - readNonblocking = c.readNonblocking - c.readDeadlineLock.Unlock() - - var readN int - if readNonblocking { - readN, err = c.nonblockingRead(b[n:]) - } else { - readN, err = c.conn.Read(b[n:]) - } - n += readN - return n, err -} - -// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is -// closed. Call Flush to actually write to the underlying connection. -func (c *NetConn) Write(b []byte) (n int, err error) { - if c.isClosed() { - return 0, errClosed - } - - buf := iobufpool.Get(len(b)) - copy(*buf, b) - c.writeQueue.pushBack(buf) - return len(b), nil -} - -func (c *NetConn) Close() (err error) { - swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1) - if !swapped { - return errClosed - } - - defer func() { - closeErr := c.conn.Close() - if err == nil { - err = closeErr - } - }() - - c.readFlushLock.Lock() - defer c.readFlushLock.Unlock() - err = c.flush() - if err != nil { - return err - } - - return nil -} - -func (c *NetConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *NetConn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t). -func (c *NetConn) SetDeadline(t time.Time) error { - err := c.SetReadDeadline(t) - if err != nil { - return err - } - return c.SetWriteDeadline(t) -} - -// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking. -func (c *NetConn) SetReadDeadline(t time.Time) error { - if c.isClosed() { - return errClosed - } - - c.readDeadlineLock.Lock() - defer c.readDeadlineLock.Unlock() - if c.readDeadline == disableSetDeadlineDeadline { - return nil - } - if t == disableSetDeadlineDeadline { - c.readDeadline = t - return nil - } - - if t == NonBlockingDeadline { - c.readNonblocking = true - t = time.Time{} - } else { - c.readNonblocking = false - } - - c.readDeadline = t - - return c.conn.SetReadDeadline(t) -} - -func (c *NetConn) SetWriteDeadline(t time.Time) error { - if c.isClosed() { - return errClosed - } - - c.writeDeadlineLock.Lock() - defer c.writeDeadlineLock.Unlock() - if c.writeDeadline == disableSetDeadlineDeadline { - return nil - } - if t == disableSetDeadlineDeadline { - c.writeDeadline = t - return nil - } - - c.writeDeadline = t - - return c.conn.SetWriteDeadline(t) -} - -func (c *NetConn) Flush() error { - if c.isClosed() { - return errClosed - } - - c.readFlushLock.Lock() - defer c.readFlushLock.Unlock() - return c.flush() -} - -// flush does the actual work of flushing the writeQueue. readFlushLock must already be held. -func (c *NetConn) flush() error { - var stopChan chan struct{} - var errChan chan error - - defer func() { - if stopChan != nil { - select { - case stopChan <- struct{}{}: - case <-errChan: - } - } - }() - - for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() { - remainingBuf := *buf - for len(remainingBuf) > 0 { - n, err := c.nonblockingWrite(remainingBuf) - remainingBuf = remainingBuf[n:] - if err != nil { - if !errors.Is(err, ErrWouldBlock) { - *buf = (*buf)[:len(remainingBuf)] - copy(*buf, remainingBuf) - c.writeQueue.pushFront(buf) - return err - } - - // Writing was blocked. Reading might unblock it. - if stopChan == nil { - stopChan, errChan = c.bufferNonblockingRead() - } - - select { - case err := <-errChan: - stopChan = nil - return err - default: - } - - } - } - iobufpool.Put(buf) - } - - return nil -} - -func (c *NetConn) BufferReadUntilBlock() error { - for { - buf := iobufpool.Get(8 * 1024) - n, err := c.nonblockingRead(*buf) - if n > 0 { - *buf = (*buf)[:n] - c.readQueue.pushBack(buf) - } else if n == 0 { - iobufpool.Put(buf) - } - - if err != nil { - if errors.Is(err, ErrWouldBlock) { - return nil - } else { - return err - } - } - } -} - -func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) { - stopChan = make(chan struct{}) - errChan = make(chan error, 1) - - go func() { - for { - err := c.BufferReadUntilBlock() - if err != nil { - errChan <- err - return - } - - select { - case <-stopChan: - return - default: - } - } - }() - - return stopChan, errChan -} - -func (c *NetConn) isClosed() bool { - closed := atomic.LoadInt64(&c.closed) - return closed == 1 -} - -func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) { - if c.rawConn == nil { - return c.fakeNonblockingWrite(b) - } else { - return c.realNonblockingWrite(b) - } -} - -func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) { - c.writeDeadlineLock.Lock() - defer c.writeDeadlineLock.Unlock() - - deadline := time.Now().Add(fakeNonblockingWriteWaitDuration) - if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) { - err = c.conn.SetWriteDeadline(deadline) - if err != nil { - return 0, err - } - defer func() { - // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. - c.conn.SetWriteDeadline(c.writeDeadline) - - if err != nil { - if errors.Is(err, os.ErrDeadlineExceeded) { - err = ErrWouldBlock - } - } - }() - } - - return c.conn.Write(b) -} - -func (c *NetConn) nonblockingRead(b []byte) (n int, err error) { - if c.rawConn == nil { - return c.fakeNonblockingRead(b) - } else { - return c.realNonblockingRead(b) - } -} - -func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) { - c.readDeadlineLock.Lock() - defer c.readDeadlineLock.Unlock() - - // The first 5 reads only read 1 byte at a time. This should give us 4 chances to read when we are sure the bytes are - // already in Go or the OS's receive buffer. - if c.fakeNonBlockingShortReadCount < 5 && len(b) > 0 && c.fakeNonblockingReadWaitDuration < minNonblockingReadWaitDuration { - b = b[:1] - } - - startTime := time.Now() - deadline := startTime.Add(c.fakeNonblockingReadWaitDuration) - if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) { - err = c.conn.SetReadDeadline(deadline) - if err != nil { - return 0, err - } - defer func() { - // If the read was successful and the wait duration is not already the minimum - if err == nil && c.fakeNonblockingReadWaitDuration > minNonblockingReadWaitDuration { - endTime := time.Now() - - if n > 0 && c.fakeNonBlockingShortReadCount < 5 { - c.fakeNonBlockingShortReadCount++ - } - - // The wait duration should be 2x the fastest read that has occurred. This should give reasonable assurance that - // a Read deadline will not block a read before it has a chance to read data already in Go or the OS's receive - // buffer. - proposedWait := endTime.Sub(startTime) * 2 - if proposedWait < minNonblockingReadWaitDuration { - proposedWait = minNonblockingReadWaitDuration - } - if proposedWait < c.fakeNonblockingReadWaitDuration { - c.fakeNonblockingReadWaitDuration = proposedWait - } - } - - // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. - c.conn.SetReadDeadline(c.readDeadline) - - if err != nil { - if errors.Is(err, os.ErrDeadlineExceeded) { - err = ErrWouldBlock - } - } - }() - } - - return c.conn.Read(b) -} - -// syscall.Conn is interface - -// TLSClient establishes a TLS connection as a client over conn using config. -// -// To avoid the first Read on the returned *TLSConn also triggering a Write due to the TLS handshake and thereby -// potentially causing a read and write deadlines to behave unexpectedly, Handshake is called explicitly before the -// *TLSConn is returned. -func TLSClient(conn *NetConn, config *tls.Config) (*TLSConn, error) { - tc := tls.Client(conn, config) - err := tc.Handshake() - if err != nil { - return nil, err - } - - // Ensure last written part of Handshake is actually sent. - err = conn.Flush() - if err != nil { - return nil, err - } - - return &TLSConn{ - tlsConn: tc, - nbConn: conn, - }, nil -} - -// TLSConn is a TLS wrapper around a *Conn. It works around a temporary write error (such as a timeout) being fatal to a -// tls.Conn. -type TLSConn struct { - tlsConn *tls.Conn - nbConn *NetConn -} - -func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) } -func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) } -func (tc *TLSConn) BufferReadUntilBlock() error { return tc.nbConn.BufferReadUntilBlock() } -func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() } -func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() } -func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() } - -func (tc *TLSConn) Close() error { - // tls.Conn.closeNotify() sets a 5 second deadline to avoid blocking, sends a TLS alert close notification, and then - // sets the deadline to now. This causes NetConn's Close not to be able to flush the write buffer. Instead we set our - // own 5 second deadline then make all set deadlines no-op. - tc.tlsConn.SetDeadline(time.Now().Add(time.Second * 5)) - tc.tlsConn.SetDeadline(disableSetDeadlineDeadline) - - return tc.tlsConn.Close() -} - -func (tc *TLSConn) SetDeadline(t time.Time) error { return tc.tlsConn.SetDeadline(t) } -func (tc *TLSConn) SetReadDeadline(t time.Time) error { return tc.tlsConn.SetReadDeadline(t) } -func (tc *TLSConn) SetWriteDeadline(t time.Time) error { return tc.tlsConn.SetWriteDeadline(t) } diff --git a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go b/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go deleted file mode 100644 index 4915c62198..0000000000 --- a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go +++ /dev/null @@ -1,11 +0,0 @@ -//go:build !unix - -package nbconn - -func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { - return c.fakeNonblockingWrite(b) -} - -func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { - return c.fakeNonblockingRead(b) -} diff --git a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go b/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go deleted file mode 100644 index e93372f256..0000000000 --- a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go +++ /dev/null @@ -1,81 +0,0 @@ -//go:build unix - -package nbconn - -import ( - "errors" - "io" - "syscall" -) - -// realNonblockingWrite does a non-blocking write. readFlushLock must already be held. -func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { - if c.nonblockWriteFunc == nil { - c.nonblockWriteFunc = func(fd uintptr) (done bool) { - c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf) - return true - } - } - c.nonblockWriteBuf = b - c.nonblockWriteN = 0 - c.nonblockWriteErr = nil - - err = c.rawConn.Write(c.nonblockWriteFunc) - n = c.nonblockWriteN - c.nonblockWriteBuf = nil // ensure that no reference to b is kept. - if err == nil && c.nonblockWriteErr != nil { - if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) { - err = ErrWouldBlock - } else { - err = c.nonblockWriteErr - } - } - if err != nil { - // n may be -1 when an error occurs. - if n < 0 { - n = 0 - } - - return n, err - } - - return n, nil -} - -func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { - if c.nonblockReadFunc == nil { - c.nonblockReadFunc = func(fd uintptr) (done bool) { - c.nonblockReadN, c.nonblockReadErr = syscall.Read(int(fd), c.nonblockReadBuf) - return true - } - } - c.nonblockReadBuf = b - c.nonblockReadN = 0 - c.nonblockReadErr = nil - - err = c.rawConn.Read(c.nonblockReadFunc) - n = c.nonblockReadN - c.nonblockReadBuf = nil // ensure that no reference to b is kept. - if err == nil && c.nonblockReadErr != nil { - if errors.Is(c.nonblockReadErr, syscall.EWOULDBLOCK) { - err = ErrWouldBlock - } else { - err = c.nonblockReadErr - } - } - if err != nil { - // n may be -1 when an error occurs. - if n < 0 { - n = 0 - } - - return n, err - } - - // syscall read did not return an error and 0 bytes were read means EOF. - if n == 0 { - return 0, io.EOF - } - - return n, nil -} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go b/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go index 6ca9e33791..8c4b2de3cb 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go @@ -42,7 +42,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { Data: sc.clientFirstMessage(), } c.frontend.Send(saslInitialResponse) - err = c.frontend.Flush() + err = c.flushWithPotentialWriteReadDeadlock() if err != nil { return err } @@ -62,7 +62,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { Data: []byte(sc.clientFinalMessage()), } c.frontend.Send(saslResponse) - err = c.frontend.Flush() + err = c.flushWithPotentialWriteReadDeadlock() if err != nil { return err } diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/config.go b/vendor/github.com/jackc/pgx/v5/pgconn/config.go index 24bf837ce1..1c2c647d9f 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/config.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/config.go @@ -26,7 +26,7 @@ type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error type GetSSLPasswordFunc func(ctx context.Context) string -// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A +// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by [ParseConfig]. A // manually initialized Config will cause ConnectConfig to panic. type Config struct { Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/internal/bgreader/bgreader.go b/vendor/github.com/jackc/pgx/v5/pgconn/internal/bgreader/bgreader.go new file mode 100644 index 0000000000..e65c2c2bf2 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/pgconn/internal/bgreader/bgreader.go @@ -0,0 +1,139 @@ +// Package bgreader provides a io.Reader that can optionally buffer reads in the background. +package bgreader + +import ( + "io" + "sync" + + "github.com/jackc/pgx/v5/internal/iobufpool" +) + +const ( + StatusStopped = iota + StatusRunning + StatusStopping +) + +// BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use. +type BGReader struct { + r io.Reader + + cond *sync.Cond + status int32 + readResults []readResult +} + +type readResult struct { + buf *[]byte + err error +} + +// Start starts the backgrounder reader. If the background reader is already running this is a no-op. The background +// reader will stop automatically when the underlying reader returns an error. +func (r *BGReader) Start() { + r.cond.L.Lock() + defer r.cond.L.Unlock() + + switch r.status { + case StatusStopped: + r.status = StatusRunning + go r.bgRead() + case StatusRunning: + // no-op + case StatusStopping: + r.status = StatusRunning + } +} + +// Stop tells the background reader to stop after the in progress Read returns. It is safe to call Stop when the +// background reader is not running. +func (r *BGReader) Stop() { + r.cond.L.Lock() + defer r.cond.L.Unlock() + + switch r.status { + case StatusStopped: + // no-op + case StatusRunning: + r.status = StatusStopping + case StatusStopping: + // no-op + } +} + +// Status returns the current status of the background reader. +func (r *BGReader) Status() int32 { + r.cond.L.Lock() + defer r.cond.L.Unlock() + return r.status +} + +func (r *BGReader) bgRead() { + keepReading := true + for keepReading { + buf := iobufpool.Get(8192) + n, err := r.r.Read(*buf) + *buf = (*buf)[:n] + + r.cond.L.Lock() + r.readResults = append(r.readResults, readResult{buf: buf, err: err}) + if r.status == StatusStopping || err != nil { + r.status = StatusStopped + keepReading = false + } + r.cond.L.Unlock() + r.cond.Broadcast() + } +} + +// Read implements the io.Reader interface. +func (r *BGReader) Read(p []byte) (int, error) { + r.cond.L.Lock() + defer r.cond.L.Unlock() + + if len(r.readResults) > 0 { + return r.readFromReadResults(p) + } + + // There are no unread background read results and the background reader is stopped. + if r.status == StatusStopped { + return r.r.Read(p) + } + + // Wait for results from the background reader + for len(r.readResults) == 0 { + r.cond.Wait() + } + return r.readFromReadResults(p) +} + +// readBackgroundResults reads a result previously read by the background reader. r.cond.L must be held. +func (r *BGReader) readFromReadResults(p []byte) (int, error) { + buf := r.readResults[0].buf + var err error + + n := copy(p, *buf) + if n == len(*buf) { + err = r.readResults[0].err + iobufpool.Put(buf) + if len(r.readResults) == 1 { + r.readResults = nil + } else { + r.readResults = r.readResults[1:] + } + } else { + *buf = (*buf)[n:] + r.readResults[0].buf = buf + } + + return n, err +} + +func New(r io.Reader) *BGReader { + return &BGReader{ + r: r, + cond: &sync.Cond{ + L: &sync.Mutex{}, + }, + } +} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go b/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go index 969675fd27..3c1af34773 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go @@ -63,7 +63,7 @@ func (c *PgConn) gssAuth() error { Data: nextData, } c.frontend.Send(gssResponse) - err = c.frontend.Flush() + err = c.flushWithPotentialWriteReadDeadlock() if err != nil { return err } diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go index 8656ea5180..8f602e4090 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go @@ -13,11 +13,12 @@ import ( "net" "strconv" "strings" + "sync" "time" "github.com/jackc/pgx/v5/internal/iobufpool" - "github.com/jackc/pgx/v5/internal/nbconn" "github.com/jackc/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/pgconn/internal/bgreader" "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" "github.com/jackc/pgx/v5/pgproto3" ) @@ -65,17 +66,24 @@ type NotificationHandler func(*PgConn, *Notification) // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { - conn nbconn.Conn // the non-blocking wrapper for the underlying TCP or unix domain socket connection + conn net.Conn pid uint32 // backend pid secretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server txStatus byte frontend *pgproto3.Frontend + bgReader *bgreader.BGReader + slowWriteTimer *time.Timer config *Config status byte // One of connStatus* constants + bufferingReceive bool + bufferingReceiveMux sync.Mutex + bufferingReceiveMsg pgproto3.BackendMessage + bufferingReceiveErr error + peekedMsg pgproto3.BackendMessage // Reusable / preallocated resources @@ -89,7 +97,7 @@ type PgConn struct { } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) -// to provide configuration. See documentation for ParseConfig for details. ctx can be used to cancel a connect attempt. +// to provide configuration. See documentation for [ParseConfig] for details. ctx can be used to cancel a connect attempt. func Connect(ctx context.Context, connString string) (*PgConn, error) { config, err := ParseConfig(connString) if err != nil { @@ -100,7 +108,7 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) { } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) -// and ParseConfigOptions to provide additional configuration. See documentation for ParseConfig for details. ctx can be +// and ParseConfigOptions to provide additional configuration. See documentation for [ParseConfig] for details. ctx can be // used to cancel a connect attempt. func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) { config, err := ParseConfigWithOptions(connString, parseConfigOptions) @@ -112,7 +120,7 @@ func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptio } // Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with -// ParseConfig. ctx can be used to cancel a connect attempt. +// [ParseConfig]. ctx can be used to cancel a connect attempt. // // If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An // authentication error will terminate the chain of attempts (like libpq: @@ -146,12 +154,15 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er foundBestServer := false var fallbackConfig *FallbackConfig - for _, fc := range fallbackConfigs { + for i, fc := range fallbackConfigs { // ConnectTimeout restricts the whole connection process. if config.ConnectTimeout != 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout) - defer cancel() + // create new context first time or when previous host was different + if i == 0 || (fallbackConfigs[i].Host != fallbackConfigs[i-1].Host) { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout) + defer cancel() + } } else { ctx = octx } @@ -166,7 +177,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege if pgerr.Code == ERRCODE_INVALID_PASSWORD || - pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION || + pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && fc.TLSConfig != nil || pgerr.Code == ERRCODE_INVALID_CATALOG_NAME || pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { break @@ -255,7 +266,8 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba } func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig, - ignoreNotPreferredErr bool) (*PgConn, error) { + ignoreNotPreferredErr bool, +) (*PgConn, error) { pgConn := new(PgConn) pgConn.config = config pgConn.cleanupDone = make(chan struct{}) @@ -266,14 +278,13 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if err != nil { return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} } - nbNetConn := nbconn.NewNetConn(netConn, false) - pgConn.conn = nbNetConn - pgConn.contextWatcher = newContextWatcher(nbNetConn) + pgConn.conn = netConn + pgConn.contextWatcher = newContextWatcher(netConn) pgConn.contextWatcher.Watch(ctx) if fallbackConfig.TLSConfig != nil { - nbTLSConn, err := startTLS(nbNetConn, fallbackConfig.TLSConfig) + nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig) pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. if err != nil { netConn.Close() @@ -289,7 +300,10 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.parameterStatuses = make(map[string]string) pgConn.status = connStatusConnecting - pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) + pgConn.bgReader = bgreader.New(pgConn.conn) + pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start) + pgConn.slowWriteTimer.Stop() + pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn) startupMsg := pgproto3.StartupMessage{ ProtocolVersion: pgproto3.ProtocolVersionNumber, @@ -307,9 +321,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } pgConn.frontend.Send(&startupMsg) - if err := pgConn.frontend.Flush(); err != nil { + if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed to write startup message", err: err} + return nil, &connectError{config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)} } for { @@ -392,7 +406,7 @@ func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { ) } -func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, error) { +func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { return nil, err @@ -407,17 +421,12 @@ func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, err return nil, errors.New("server refused TLS connection") } - tlsConn, err := nbconn.TLSClient(conn, tlsConfig) - if err != nil { - return nil, err - } - - return tlsConn, nil + return tls.Client(conn, tlsConfig), nil } func (pgConn *PgConn) txPasswordMessage(password string) (err error) { pgConn.frontend.Send(&pgproto3.PasswordMessage{Password: password}) - return pgConn.frontend.Flush() + return pgConn.flushWithPotentialWriteReadDeadlock() } func hexMD5(s string) string { @@ -426,6 +435,24 @@ func hexMD5(s string) string { return hex.EncodeToString(hash.Sum(nil)) } +func (pgConn *PgConn) signalMessage() chan struct{} { + if pgConn.bufferingReceive { + panic("BUG: signalMessage when already in progress") + } + + pgConn.bufferingReceive = true + pgConn.bufferingReceiveMux.Lock() + + ch := make(chan struct{}) + go func() { + pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive() + pgConn.bufferingReceiveMux.Unlock() + close(ch) + }() + + return ch +} + // ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the // connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages // are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger @@ -454,7 +481,8 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa err = &pgconnError{ msg: "receive message failed", err: normalizeTimeoutError(ctx, err), - safeToRetry: true} + safeToRetry: true, + } } return msg, err } @@ -465,13 +493,25 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { return pgConn.peekedMsg, nil } - msg, err := pgConn.frontend.Receive() - - if err != nil { - if errors.Is(err, nbconn.ErrWouldBlock) { - return nil, err + var msg pgproto3.BackendMessage + var err error + if pgConn.bufferingReceive { + pgConn.bufferingReceiveMux.Lock() + msg = pgConn.bufferingReceiveMsg + err = pgConn.bufferingReceiveErr + pgConn.bufferingReceiveMux.Unlock() + pgConn.bufferingReceive = false + + // If a timeout error happened in the background try the read again. + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + msg, err = pgConn.frontend.Receive() } + } else { + msg, err = pgConn.frontend.Receive() + } + if err != nil { // Close on anything other than timeout error - everything else is fatal var netErr net.Error isNetErr := errors.As(err, &netErr) @@ -519,7 +559,8 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { return msg, nil } -// Conn returns the underlying net.Conn. This rarely necessary. +// Conn returns the underlying net.Conn. This rarely necessary. If the connection will be directly used for reading or +// writing then SyncConn should usually be called before Conn. func (pgConn *PgConn) Conn() net.Conn { return pgConn.conn } @@ -582,7 +623,7 @@ func (pgConn *PgConn) Close(ctx context.Context) error { // // See https://github.com/jackc/pgx/issues/637 pgConn.frontend.Send(&pgproto3.Terminate{}) - pgConn.frontend.Flush() + pgConn.flushWithPotentialWriteReadDeadlock() return pgConn.conn.Close() } @@ -609,7 +650,7 @@ func (pgConn *PgConn) asyncClose() { pgConn.conn.SetDeadline(deadline) pgConn.frontend.Send(&pgproto3.Terminate{}) - pgConn.frontend.Flush() + pgConn.flushWithPotentialWriteReadDeadlock() }() } @@ -784,7 +825,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ pgConn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) pgConn.frontend.SendSync(&pgproto3.Sync{}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() return nil, err @@ -857,9 +898,28 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. serverAddr := pgConn.conn.RemoteAddr() - cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) + var serverNetwork string + var serverAddress string + if serverAddr.Network() == "unix" { + // for unix sockets, RemoteAddr() calls getpeername() which returns the name the + // server passed to bind(). For Postgres, this is always a relative path "./.s.PGSQL.5432" + // so connecting to it will fail. Fall back to the config's value + serverNetwork, serverAddress = NetworkAddress(pgConn.config.Host, pgConn.config.Port) + } else { + serverNetwork, serverAddress = serverAddr.Network(), serverAddr.String() + } + cancelConn, err := pgConn.config.DialFunc(ctx, serverNetwork, serverAddress) if err != nil { - return err + // In case of unix sockets, RemoteAddr() returns only the file part of the path. If the + // first connect failed, try the config. + if serverAddr.Network() != "unix" { + return err + } + serverNetwork, serverAddr := NetworkAddress(pgConn.config.Host, pgConn.config.Port) + cancelConn, err = pgConn.config.DialFunc(ctx, serverNetwork, serverAddr) + if err != nil { + return err + } } defer cancelConn.Close() @@ -877,17 +937,11 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { binary.BigEndian.PutUint32(buf[4:8], 80877102) binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid)) binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) + // Postgres will process the request and close the connection + // so when don't need to read the reply + // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.6.7.10 _, err = cancelConn.Write(buf) - if err != nil { - return err - } - - _, err = cancelConn.Read(buf) - if err != io.EOF { - return err - } - - return nil + return err } // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not @@ -953,7 +1007,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { } pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() pgConn.contextWatcher.Unwatch() @@ -1064,7 +1118,7 @@ func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) { pgConn.frontend.SendExecute(&pgproto3.Execute{}) pgConn.frontend.SendSync(&pgproto3.Sync{}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() result.concludeCommand(CommandTag{}, err) @@ -1097,7 +1151,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // Send copy to command pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() pgConn.unlock() @@ -1153,85 +1207,91 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co defer pgConn.contextWatcher.Unwatch() } - // Send copy to command + // Send copy from query pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() return CommandTag{}, err } - err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline) - if err != nil { - pgConn.asyncClose() - return CommandTag{}, err - } - nonblocking := true - defer func() { - if nonblocking { - pgConn.conn.SetReadDeadline(time.Time{}) - } - }() + // Send copy data + abortCopyChan := make(chan struct{}) + copyErrChan := make(chan error, 1) + signalMessageChan := pgConn.signalMessage() + var wg sync.WaitGroup + wg.Add(1) - buf := iobufpool.Get(65536) - defer iobufpool.Put(buf) - (*buf)[0] = 'd' - - var readErr, pgErr error - for pgErr == nil { - // Read chunk from r. - var n int - n, readErr = r.Read((*buf)[5:cap(*buf)]) + go func() { + defer wg.Done() + buf := iobufpool.Get(65536) + defer iobufpool.Put(buf) + (*buf)[0] = 'd' - // Send chunk to PostgreSQL. - if n > 0 { - *buf = (*buf)[0 : n+5] - pgio.SetInt32((*buf)[1:], int32(n+4)) + for { + n, readErr := r.Read((*buf)[5:cap(*buf)]) + if n > 0 { + *buf = (*buf)[0 : n+5] + pgio.SetInt32((*buf)[1:], int32(n+4)) + + writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf) + if writeErr != nil { + // Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. Not + // setting pgConn.status or closing pgConn.cleanupDone for the same reason. + pgConn.conn.Close() - writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf) - if writeErr != nil { - pgConn.asyncClose() - return CommandTag{}, err + copyErrChan <- writeErr + return + } + } + if readErr != nil { + copyErrChan <- readErr + return } - } - // Abort loop if there was a read error. - if readErr != nil { - break + select { + case <-abortCopyChan: + return + default: + } } + }() - // Read messages until error or none available. - for pgErr == nil { - msg, err := pgConn.receiveMessage() - if err != nil { - if errors.Is(err, nbconn.ErrWouldBlock) { - break - } - pgConn.asyncClose() + var pgErr error + var copyErr error + for copyErr == nil && pgErr == nil { + select { + case copyErr = <-copyErrChan: + case <-signalMessageChan: + // If pgConn.receiveMessage encounters an error it will call pgConn.asyncClose. But that is a race condition with + // the goroutine. So instead check pgConn.bufferingReceiveErr which will have been set by the signalMessage. If an + // error is found then forcibly close the connection without sending the Terminate message. + if err := pgConn.bufferingReceiveErr; err != nil { + pgConn.status = connStatusClosed + pgConn.conn.Close() + close(pgConn.cleanupDone) return CommandTag{}, normalizeTimeoutError(ctx, err) } + msg, _ := pgConn.receiveMessage() switch msg := msg.(type) { case *pgproto3.ErrorResponse: pgErr = ErrorResponseToPgError(msg) - break + default: + signalMessageChan = pgConn.signalMessage() } } } + close(abortCopyChan) + // Make sure io goroutine finishes before writing. + wg.Wait() - err = pgConn.conn.SetReadDeadline(time.Time{}) - if err != nil { - pgConn.asyncClose() - return CommandTag{}, err - } - nonblocking = false - - if readErr == io.EOF || pgErr != nil { + if copyErr == io.EOF || pgErr != nil { pgConn.frontend.Send(&pgproto3.CopyDone{}) } else { - pgConn.frontend.Send(&pgproto3.CopyFail{Message: readErr.Error()}) + pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()}) } - err = pgConn.frontend.Flush() + err = pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() return CommandTag{}, err @@ -1283,7 +1343,6 @@ func (mrr *MultiResultReader) ReadAll() ([]*Result, error) { func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) { msg, err := mrr.pgConn.receiveMessage() - if err != nil { mrr.pgConn.contextWatcher.Unwatch() mrr.err = normalizeTimeoutError(mrr.ctx, err) @@ -1426,7 +1485,8 @@ func (rr *ResultReader) NextRow() bool { } // FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until -// the ResultReader is closed. +// the ResultReader is closed. It may return nil (for example, if the query did not return a result set or an error was +// encountered.) func (rr *ResultReader) FieldDescriptions() []FieldDescription { return rr.fieldDescriptions } @@ -1592,6 +1652,8 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) + pgConn.enterPotentialWriteReadDeadlock() + defer pgConn.exitPotentialWriteReadDeadlock() _, err := pgConn.conn.Write(batch.buf) if err != nil { multiResult.closed = true @@ -1620,29 +1682,99 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) { return strings.Replace(s, "'", "''", -1), nil } -// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by reading and -// buffering until the read would block or an error occurs. This can be used to check if the server has closed the -// connection. If this is done immediately before sending a query it reduces the chances a query will be sent that fails +// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by doing a read +// with a very short deadline. This can be useful because a TCP connection can be broken such that a write will appear +// to succeed even though it will never actually reach the server. Reading immediately before a write will detect this +// condition. If this is done immediately before sending a query it reduces the chances a query will be sent that fails // without the client knowing whether the server received it or not. +// +// Deprecated: CheckConn is deprecated in favor of Ping. CheckConn cannot detect all types of broken connections where +// the write would still appear to succeed. Prefer Ping unless on a high latency connection. func (pgConn *PgConn) CheckConn() error { - err := pgConn.conn.BufferReadUntilBlock() - if err != nil && !errors.Is(err, nbconn.ErrWouldBlock) { - return err + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + _, err := pgConn.ReceiveMessage(ctx) + if err != nil { + if !Timeout(err) { + return err + } } + return nil } +// Ping pings the server. This can be useful because a TCP connection can be broken such that a write will appear to +// succeed even though it will never actually reach the server. Pinging immediately before sending a query reduces the +// chances a query will be sent that fails without the client knowing whether the server received it or not. +func (pgConn *PgConn) Ping(ctx context.Context) error { + return pgConn.Exec(ctx, "-- ping").Close() +} + // makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory. func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag { return CommandTag{s: string(buf)} } +// enterPotentialWriteReadDeadlock must be called before a write that could deadlock if the server is simultaneously +// blocked writing to us. +func (pgConn *PgConn) enterPotentialWriteReadDeadlock() { + // The time to wait is somewhat arbitrary. A Write should only take as long as the syscall and memcpy to the OS + // outbound network buffer unless the buffer is full (which potentially is a block). It needs to be long enough for + // the normal case, but short enough not to kill performance if a block occurs. + // + // In addition, on Windows the default timer resolution is 15.6ms. So setting the timer to less than that is + // ineffective. + if pgConn.slowWriteTimer.Reset(15 * time.Millisecond) { + panic("BUG: slow write timer already active") + } +} + +// exitPotentialWriteReadDeadlock must be called after a call to enterPotentialWriteReadDeadlock. +func (pgConn *PgConn) exitPotentialWriteReadDeadlock() { + // The state of the timer is not relevant upon exiting the potential slow write. It may both + // fire (due to a slow write), or not fire (due to a fast write). + _ = pgConn.slowWriteTimer.Stop() + pgConn.bgReader.Stop() +} + +func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error { + pgConn.enterPotentialWriteReadDeadlock() + defer pgConn.exitPotentialWriteReadDeadlock() + err := pgConn.frontend.Flush() + return err +} + +// SyncConn prepares the underlying net.Conn for direct use. PgConn may internally buffer reads or use goroutines for +// background IO. This means that any direct use of the underlying net.Conn may be corrupted if a read is already +// buffered or a read is in progress. SyncConn drains read buffers and stops background IO. In some cases this may +// require sending a ping to the server. ctx can be used to cancel this operation. This should be called before any +// operation that will use the underlying net.Conn directly. e.g. Before Conn() or Hijack(). +// +// This should not be confused with the PostgreSQL protocol Sync message. +func (pgConn *PgConn) SyncConn(ctx context.Context) error { + for i := 0; i < 10; i++ { + if pgConn.bgReader.Status() == bgreader.StatusStopped && pgConn.frontend.ReadBufferLen() == 0 { + return nil + } + + err := pgConn.Ping(ctx) + if err != nil { + return fmt.Errorf("SyncConn: Ping failed while syncing conn: %w", err) + } + } + + // This should never happen. Only way I can imagine this occuring is if the server is constantly sending data such as + // LISTEN/NOTIFY or log notifications such that we never can get an empty buffer. + return errors.New("SyncConn: conn never synchronized") +} + // HijackedConn is the result of hijacking a connection. // // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning // compatibility. type HijackedConn struct { - Conn nbconn.Conn // the non-blocking wrapper of the underlying TCP or unix domain socket connection + Conn net.Conn PID uint32 // backend pid SecretKey uint32 // key to use to send a cancel query message to the server ParameterStatuses map[string]string // parameters that have been reported by the server @@ -1651,9 +1783,9 @@ type HijackedConn struct { Config *Config } -// Hijack extracts the internal connection data. pgConn must be in an idle state. pgConn is unusable after hijacking. -// Hijacking is typically only useful when using pgconn to establish a connection, but taking complete control of the -// raw connection after that (e.g. a load balancer or proxy). +// Hijack extracts the internal connection data. pgConn must be in an idle state. SyncConn should be called immediately +// before Hijack. pgConn is unusable after hijacking. Hijacking is typically only useful when using pgconn to establish +// a connection, but taking complete control of the raw connection after that (e.g. a load balancer or proxy). // // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning // compatibility. @@ -1677,6 +1809,8 @@ func (pgConn *PgConn) Hijack() (*HijackedConn, error) { // Construct created a PgConn from an already established connection to a PostgreSQL server. This is the inverse of // PgConn.Hijack. The connection must be in an idle state. // +// hc.Frontend is replaced by a new pgproto3.Frontend built by hc.Config.BuildFrontend. +// // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning // compatibility. func Construct(hc *HijackedConn) (*PgConn, error) { @@ -1695,6 +1829,10 @@ func Construct(hc *HijackedConn) (*PgConn, error) { } pgConn.contextWatcher = newContextWatcher(pgConn.conn) + pgConn.bgReader = bgreader.New(pgConn.conn) + pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start) + pgConn.slowWriteTimer.Stop() + pgConn.frontend = hc.Config.BuildFrontend(pgConn.bgReader, pgConn.conn) return pgConn, nil } @@ -1817,7 +1955,7 @@ func (p *Pipeline) Flush() error { return errors.New("pipeline closed") } - err := p.conn.frontend.Flush() + err := p.conn.flushWithPotentialWriteReadDeadlock() if err != nil { err = normalizeTimeoutError(p.ctx, err) @@ -1896,7 +2034,6 @@ func (p *Pipeline) GetResults() (results any, err error) { } } - } func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go b/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go index 83dea96383..33c3882a2c 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go @@ -361,3 +361,7 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er func (f *Frontend) GetAuthType() uint32 { return f.authType } + +func (f *Frontend) ReadBufferLen() int { + return f.cr.wp - f.cr.rp +} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/trace.go b/vendor/github.com/jackc/pgx/v5/pgproto3/trace.go index c09f68d1a6..6cc7d3e36c 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/trace.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/trace.go @@ -6,15 +6,18 @@ import ( "io" "strconv" "strings" + "sync" "time" ) // tracer traces the messages send to and from a Backend or Frontend. The format it produces roughly mimics the // format produced by the libpq C function PQtrace. type tracer struct { + TracerOptions + + mux sync.Mutex w io.Writer buf *bytes.Buffer - TracerOptions } // TracerOptions controls tracing behavior. It is roughly equivalent to the libpq function PQsetTraceFlags. @@ -119,278 +122,255 @@ func (t *tracer) traceMessage(sender byte, encodedLen int32, msg Message) { case *Terminate: t.traceTerminate(sender, encodedLen, msg) default: - t.beginTrace(sender, encodedLen, "Unknown") - t.finishTrace() + t.writeTrace(sender, encodedLen, "Unknown", nil) } } func (t *tracer) traceAuthenticationCleartextPassword(sender byte, encodedLen int32, msg *AuthenticationCleartextPassword) { - t.beginTrace(sender, encodedLen, "AuthenticationCleartextPassword") - t.finishTrace() + t.writeTrace(sender, encodedLen, "AuthenticationCleartextPassword", nil) } func (t *tracer) traceAuthenticationGSS(sender byte, encodedLen int32, msg *AuthenticationGSS) { - t.beginTrace(sender, encodedLen, "AuthenticationGSS") - t.finishTrace() + t.writeTrace(sender, encodedLen, "AuthenticationGSS", nil) } func (t *tracer) traceAuthenticationGSSContinue(sender byte, encodedLen int32, msg *AuthenticationGSSContinue) { - t.beginTrace(sender, encodedLen, "AuthenticationGSSContinue") - t.finishTrace() + t.writeTrace(sender, encodedLen, "AuthenticationGSSContinue", nil) } func (t *tracer) traceAuthenticationMD5Password(sender byte, encodedLen int32, msg *AuthenticationMD5Password) { - t.beginTrace(sender, encodedLen, "AuthenticationMD5Password") - t.finishTrace() + t.writeTrace(sender, encodedLen, "AuthenticationMD5Password", nil) } func (t *tracer) traceAuthenticationOk(sender byte, encodedLen int32, msg *AuthenticationOk) { - t.beginTrace(sender, encodedLen, "AuthenticationOk") - t.finishTrace() + t.writeTrace(sender, encodedLen, "AuthenticationOk", nil) } func (t *tracer) traceAuthenticationSASL(sender byte, encodedLen int32, msg *AuthenticationSASL) { - t.beginTrace(sender, encodedLen, "AuthenticationSASL") - t.finishTrace() + t.writeTrace(sender, encodedLen, "AuthenticationSASL", nil) } func (t *tracer) traceAuthenticationSASLContinue(sender byte, encodedLen int32, msg *AuthenticationSASLContinue) { - t.beginTrace(sender, encodedLen, "AuthenticationSASLContinue") - t.finishTrace() + t.writeTrace(sender, encodedLen, "AuthenticationSASLContinue", nil) } func (t *tracer) traceAuthenticationSASLFinal(sender byte, encodedLen int32, msg *AuthenticationSASLFinal) { - t.beginTrace(sender, encodedLen, "AuthenticationSASLFinal") - t.finishTrace() + t.writeTrace(sender, encodedLen, "AuthenticationSASLFinal", nil) } func (t *tracer) traceBackendKeyData(sender byte, encodedLen int32, msg *BackendKeyData) { - t.beginTrace(sender, encodedLen, "BackendKeyData") - if t.RegressMode { - t.buf.WriteString("\t NNNN NNNN") - } else { - fmt.Fprintf(t.buf, "\t %d %d", msg.ProcessID, msg.SecretKey) - } - t.finishTrace() + t.writeTrace(sender, encodedLen, "BackendKeyData", func() { + if t.RegressMode { + t.buf.WriteString("\t NNNN NNNN") + } else { + fmt.Fprintf(t.buf, "\t %d %d", msg.ProcessID, msg.SecretKey) + } + }) } func (t *tracer) traceBind(sender byte, encodedLen int32, msg *Bind) { - t.beginTrace(sender, encodedLen, "Bind") - fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.DestinationPortal)), traceDoubleQuotedString([]byte(msg.PreparedStatement)), len(msg.ParameterFormatCodes)) - for _, fc := range msg.ParameterFormatCodes { - fmt.Fprintf(t.buf, " %d", fc) - } - fmt.Fprintf(t.buf, " %d", len(msg.Parameters)) - for _, p := range msg.Parameters { - fmt.Fprintf(t.buf, " %s", traceSingleQuotedString(p)) - } - fmt.Fprintf(t.buf, " %d", len(msg.ResultFormatCodes)) - for _, fc := range msg.ResultFormatCodes { - fmt.Fprintf(t.buf, " %d", fc) - } - t.finishTrace() + t.writeTrace(sender, encodedLen, "Bind", func() { + fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.DestinationPortal)), traceDoubleQuotedString([]byte(msg.PreparedStatement)), len(msg.ParameterFormatCodes)) + for _, fc := range msg.ParameterFormatCodes { + fmt.Fprintf(t.buf, " %d", fc) + } + fmt.Fprintf(t.buf, " %d", len(msg.Parameters)) + for _, p := range msg.Parameters { + fmt.Fprintf(t.buf, " %s", traceSingleQuotedString(p)) + } + fmt.Fprintf(t.buf, " %d", len(msg.ResultFormatCodes)) + for _, fc := range msg.ResultFormatCodes { + fmt.Fprintf(t.buf, " %d", fc) + } + }) } func (t *tracer) traceBindComplete(sender byte, encodedLen int32, msg *BindComplete) { - t.beginTrace(sender, encodedLen, "BindComplete") - t.finishTrace() + t.writeTrace(sender, encodedLen, "BindComplete", nil) } func (t *tracer) traceCancelRequest(sender byte, encodedLen int32, msg *CancelRequest) { - t.beginTrace(sender, encodedLen, "CancelRequest") - t.finishTrace() + t.writeTrace(sender, encodedLen, "CancelRequest", nil) } func (t *tracer) traceClose(sender byte, encodedLen int32, msg *Close) { - t.beginTrace(sender, encodedLen, "Close") - t.finishTrace() + t.writeTrace(sender, encodedLen, "Close", nil) } func (t *tracer) traceCloseComplete(sender byte, encodedLen int32, msg *CloseComplete) { - t.beginTrace(sender, encodedLen, "CloseComplete") - t.finishTrace() + t.writeTrace(sender, encodedLen, "CloseComplete", nil) } func (t *tracer) traceCommandComplete(sender byte, encodedLen int32, msg *CommandComplete) { - t.beginTrace(sender, encodedLen, "CommandComplete") - fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString(msg.CommandTag)) - t.finishTrace() + t.writeTrace(sender, encodedLen, "CommandComplete", func() { + fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString(msg.CommandTag)) + }) } func (t *tracer) traceCopyBothResponse(sender byte, encodedLen int32, msg *CopyBothResponse) { - t.beginTrace(sender, encodedLen, "CopyBothResponse") - t.finishTrace() + t.writeTrace(sender, encodedLen, "CopyBothResponse", nil) } func (t *tracer) traceCopyData(sender byte, encodedLen int32, msg *CopyData) { - t.beginTrace(sender, encodedLen, "CopyData") - t.finishTrace() + t.writeTrace(sender, encodedLen, "CopyData", nil) } func (t *tracer) traceCopyDone(sender byte, encodedLen int32, msg *CopyDone) { - t.beginTrace(sender, encodedLen, "CopyDone") - t.finishTrace() + t.writeTrace(sender, encodedLen, "CopyDone", nil) } func (t *tracer) traceCopyFail(sender byte, encodedLen int32, msg *CopyFail) { - t.beginTrace(sender, encodedLen, "CopyFail") - fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.Message))) - t.finishTrace() + t.writeTrace(sender, encodedLen, "CopyFail", func() { + fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.Message))) + }) } func (t *tracer) traceCopyInResponse(sender byte, encodedLen int32, msg *CopyInResponse) { - t.beginTrace(sender, encodedLen, "CopyInResponse") - t.finishTrace() + t.writeTrace(sender, encodedLen, "CopyInResponse", nil) } func (t *tracer) traceCopyOutResponse(sender byte, encodedLen int32, msg *CopyOutResponse) { - t.beginTrace(sender, encodedLen, "CopyOutResponse") - t.finishTrace() + t.writeTrace(sender, encodedLen, "CopyOutResponse", nil) } func (t *tracer) traceDataRow(sender byte, encodedLen int32, msg *DataRow) { - t.beginTrace(sender, encodedLen, "DataRow") - fmt.Fprintf(t.buf, "\t %d", len(msg.Values)) - for _, v := range msg.Values { - if v == nil { - t.buf.WriteString(" -1") - } else { - fmt.Fprintf(t.buf, " %d %s", len(v), traceSingleQuotedString(v)) + t.writeTrace(sender, encodedLen, "DataRow", func() { + fmt.Fprintf(t.buf, "\t %d", len(msg.Values)) + for _, v := range msg.Values { + if v == nil { + t.buf.WriteString(" -1") + } else { + fmt.Fprintf(t.buf, " %d %s", len(v), traceSingleQuotedString(v)) + } } - } - t.finishTrace() + }) } func (t *tracer) traceDescribe(sender byte, encodedLen int32, msg *Describe) { - t.beginTrace(sender, encodedLen, "Describe") - fmt.Fprintf(t.buf, "\t %c %s", msg.ObjectType, traceDoubleQuotedString([]byte(msg.Name))) - t.finishTrace() + t.writeTrace(sender, encodedLen, "Describe", func() { + fmt.Fprintf(t.buf, "\t %c %s", msg.ObjectType, traceDoubleQuotedString([]byte(msg.Name))) + }) } func (t *tracer) traceEmptyQueryResponse(sender byte, encodedLen int32, msg *EmptyQueryResponse) { - t.beginTrace(sender, encodedLen, "EmptyQueryResponse") - t.finishTrace() + t.writeTrace(sender, encodedLen, "EmptyQueryResponse", nil) } func (t *tracer) traceErrorResponse(sender byte, encodedLen int32, msg *ErrorResponse) { - t.beginTrace(sender, encodedLen, "ErrorResponse") - t.finishTrace() + t.writeTrace(sender, encodedLen, "ErrorResponse", nil) } func (t *tracer) TraceQueryute(sender byte, encodedLen int32, msg *Execute) { - t.beginTrace(sender, encodedLen, "Execute") - fmt.Fprintf(t.buf, "\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows) - t.finishTrace() + t.writeTrace(sender, encodedLen, "Execute", func() { + fmt.Fprintf(t.buf, "\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows) + }) } func (t *tracer) traceFlush(sender byte, encodedLen int32, msg *Flush) { - t.beginTrace(sender, encodedLen, "Flush") - t.finishTrace() + t.writeTrace(sender, encodedLen, "Flush", nil) } func (t *tracer) traceFunctionCall(sender byte, encodedLen int32, msg *FunctionCall) { - t.beginTrace(sender, encodedLen, "FunctionCall") - t.finishTrace() + t.writeTrace(sender, encodedLen, "FunctionCall", nil) } func (t *tracer) traceFunctionCallResponse(sender byte, encodedLen int32, msg *FunctionCallResponse) { - t.beginTrace(sender, encodedLen, "FunctionCallResponse") - t.finishTrace() + t.writeTrace(sender, encodedLen, "FunctionCallResponse", nil) } func (t *tracer) traceGSSEncRequest(sender byte, encodedLen int32, msg *GSSEncRequest) { - t.beginTrace(sender, encodedLen, "GSSEncRequest") - t.finishTrace() + t.writeTrace(sender, encodedLen, "GSSEncRequest", nil) } func (t *tracer) traceNoData(sender byte, encodedLen int32, msg *NoData) { - t.beginTrace(sender, encodedLen, "NoData") - t.finishTrace() + t.writeTrace(sender, encodedLen, "NoData", nil) } func (t *tracer) traceNoticeResponse(sender byte, encodedLen int32, msg *NoticeResponse) { - t.beginTrace(sender, encodedLen, "NoticeResponse") - t.finishTrace() + t.writeTrace(sender, encodedLen, "NoticeResponse", nil) } func (t *tracer) traceNotificationResponse(sender byte, encodedLen int32, msg *NotificationResponse) { - t.beginTrace(sender, encodedLen, "NotificationResponse") - fmt.Fprintf(t.buf, "\t %d %s %s", msg.PID, traceDoubleQuotedString([]byte(msg.Channel)), traceDoubleQuotedString([]byte(msg.Payload))) - t.finishTrace() + t.writeTrace(sender, encodedLen, "NotificationResponse", func() { + fmt.Fprintf(t.buf, "\t %d %s %s", msg.PID, traceDoubleQuotedString([]byte(msg.Channel)), traceDoubleQuotedString([]byte(msg.Payload))) + }) } func (t *tracer) traceParameterDescription(sender byte, encodedLen int32, msg *ParameterDescription) { - t.beginTrace(sender, encodedLen, "ParameterDescription") - t.finishTrace() + t.writeTrace(sender, encodedLen, "ParameterDescription", nil) } func (t *tracer) traceParameterStatus(sender byte, encodedLen int32, msg *ParameterStatus) { - t.beginTrace(sender, encodedLen, "ParameterStatus") - fmt.Fprintf(t.buf, "\t %s %s", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Value))) - t.finishTrace() + t.writeTrace(sender, encodedLen, "ParameterStatus", func() { + fmt.Fprintf(t.buf, "\t %s %s", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Value))) + }) } func (t *tracer) traceParse(sender byte, encodedLen int32, msg *Parse) { - t.beginTrace(sender, encodedLen, "Parse") - fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Query)), len(msg.ParameterOIDs)) - for _, oid := range msg.ParameterOIDs { - fmt.Fprintf(t.buf, " %d", oid) - } - t.finishTrace() + t.writeTrace(sender, encodedLen, "Parse", func() { + fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Query)), len(msg.ParameterOIDs)) + for _, oid := range msg.ParameterOIDs { + fmt.Fprintf(t.buf, " %d", oid) + } + }) } func (t *tracer) traceParseComplete(sender byte, encodedLen int32, msg *ParseComplete) { - t.beginTrace(sender, encodedLen, "ParseComplete") - t.finishTrace() + t.writeTrace(sender, encodedLen, "ParseComplete", nil) } func (t *tracer) tracePortalSuspended(sender byte, encodedLen int32, msg *PortalSuspended) { - t.beginTrace(sender, encodedLen, "PortalSuspended") - t.finishTrace() + t.writeTrace(sender, encodedLen, "PortalSuspended", nil) } func (t *tracer) traceQuery(sender byte, encodedLen int32, msg *Query) { - t.beginTrace(sender, encodedLen, "Query") - fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.String))) - t.finishTrace() + t.writeTrace(sender, encodedLen, "Query", func() { + fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.String))) + }) } func (t *tracer) traceReadyForQuery(sender byte, encodedLen int32, msg *ReadyForQuery) { - t.beginTrace(sender, encodedLen, "ReadyForQuery") - fmt.Fprintf(t.buf, "\t %c", msg.TxStatus) - t.finishTrace() + t.writeTrace(sender, encodedLen, "ReadyForQuery", func() { + fmt.Fprintf(t.buf, "\t %c", msg.TxStatus) + }) } func (t *tracer) traceRowDescription(sender byte, encodedLen int32, msg *RowDescription) { - t.beginTrace(sender, encodedLen, "RowDescription") - fmt.Fprintf(t.buf, "\t %d", len(msg.Fields)) - for _, fd := range msg.Fields { - fmt.Fprintf(t.buf, ` %s %d %d %d %d %d %d`, traceDoubleQuotedString(fd.Name), fd.TableOID, fd.TableAttributeNumber, fd.DataTypeOID, fd.DataTypeSize, fd.TypeModifier, fd.Format) - } - t.finishTrace() + t.writeTrace(sender, encodedLen, "RowDescription", func() { + fmt.Fprintf(t.buf, "\t %d", len(msg.Fields)) + for _, fd := range msg.Fields { + fmt.Fprintf(t.buf, ` %s %d %d %d %d %d %d`, traceDoubleQuotedString(fd.Name), fd.TableOID, fd.TableAttributeNumber, fd.DataTypeOID, fd.DataTypeSize, fd.TypeModifier, fd.Format) + } + }) } func (t *tracer) traceSSLRequest(sender byte, encodedLen int32, msg *SSLRequest) { - t.beginTrace(sender, encodedLen, "SSLRequest") - t.finishTrace() + t.writeTrace(sender, encodedLen, "SSLRequest", nil) } func (t *tracer) traceStartupMessage(sender byte, encodedLen int32, msg *StartupMessage) { - t.beginTrace(sender, encodedLen, "StartupMessage") - t.finishTrace() + t.writeTrace(sender, encodedLen, "StartupMessage", nil) } func (t *tracer) traceSync(sender byte, encodedLen int32, msg *Sync) { - t.beginTrace(sender, encodedLen, "Sync") - t.finishTrace() + t.writeTrace(sender, encodedLen, "Sync", nil) } func (t *tracer) traceTerminate(sender byte, encodedLen int32, msg *Terminate) { - t.beginTrace(sender, encodedLen, "Terminate") - t.finishTrace() + t.writeTrace(sender, encodedLen, "Terminate", nil) } -func (t *tracer) beginTrace(sender byte, encodedLen int32, msgType string) { +func (t *tracer) writeTrace(sender byte, encodedLen int32, msgType string, writeDetails func()) { + t.mux.Lock() + defer t.mux.Unlock() + defer func() { + if t.buf.Cap() > 1024 { + t.buf = &bytes.Buffer{} + } else { + t.buf.Reset() + } + }() + if !t.SuppressTimestamps { now := time.Now() t.buf.WriteString(now.Format("2006-01-02 15:04:05.000000")) @@ -402,17 +382,13 @@ func (t *tracer) beginTrace(sender byte, encodedLen int32, msgType string) { t.buf.WriteString(msgType) t.buf.WriteByte('\t') t.buf.WriteString(strconv.FormatInt(int64(encodedLen), 10)) -} -func (t *tracer) finishTrace() { + if writeDetails != nil { + writeDetails() + } + t.buf.WriteByte('\n') t.buf.WriteTo(t.w) - - if t.buf.Cap() > 1024 { - t.buf = &bytes.Buffer{} - } else { - t.buf.Reset() - } } // traceDoubleQuotedString returns t.buf as a double-quoted string without any escaping. It is roughly equivalent to diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/array.go b/vendor/github.com/jackc/pgx/v5/pgtype/array.go index 0fa4c129b3..7376195689 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/array.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/array.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "fmt" "io" - "reflect" "strconv" "strings" "unicode" @@ -363,38 +362,18 @@ func quoteArrayElement(src string) string { } func isSpace(ch byte) bool { - // see https://github.com/postgres/postgres/blob/REL_12_STABLE/src/backend/parser/scansup.c#L224 - return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\f' + // see array_isspace: + // https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/arrayfuncs.c + return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\v' || ch == '\f' } func quoteArrayElementIfNeeded(src string) string { - if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || isSpace(src[0]) || isSpace(src[len(src)-1]) || strings.ContainsAny(src, `{},"\`) { + if src == "" || (len(src) == 4 && strings.EqualFold(src, "null")) || isSpace(src[0]) || isSpace(src[len(src)-1]) || strings.ContainsAny(src, `{},"\`) { return quoteArrayElement(src) } return src } -func findDimensionsFromValue(value reflect.Value, dimensions []ArrayDimension, elementsLength int) ([]ArrayDimension, int, bool) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - length := value.Len() - if 0 == elementsLength { - elementsLength = length - } else { - elementsLength *= length - } - dimensions = append(dimensions, ArrayDimension{Length: int32(length), LowerBound: 1}) - for i := 0; i < length; i++ { - if d, l, ok := findDimensionsFromValue(value.Index(i), dimensions, elementsLength); ok { - return d, l, true - } - } - } - return dimensions, elementsLength, true -} - // Array represents a PostgreSQL array for T. It implements the ArrayGetter and ArraySetter interfaces. It preserves // PostgreSQL dimensions and custom lower bounds. Use FlatArray if these are not needed. type Array[T any] struct { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/bool.go b/vendor/github.com/jackc/pgx/v5/pgtype/bool.go index e7be27e2d6..71caffa74e 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/bool.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/bool.go @@ -1,10 +1,12 @@ package pgtype import ( + "bytes" "database/sql/driver" "encoding/json" "fmt" "strconv" + "strings" ) type BoolScanner interface { @@ -264,8 +266,8 @@ func (scanPlanTextAnyToBool) Scan(src []byte, dst any) error { return fmt.Errorf("cannot scan NULL into %T", dst) } - if len(src) != 1 { - return fmt.Errorf("invalid length for bool: %v", len(src)) + if len(src) == 0 { + return fmt.Errorf("cannot scan empty string into %T", dst) } p, ok := (dst).(*bool) @@ -273,7 +275,12 @@ func (scanPlanTextAnyToBool) Scan(src []byte, dst any) error { return ErrScanTargetTypeChanged } - *p = src[0] == 't' + v, err := planTextToBool(src) + if err != nil { + return err + } + + *p = v return nil } @@ -309,9 +316,28 @@ func (scanPlanTextAnyToBoolScanner) Scan(src []byte, dst any) error { return s.ScanBool(Bool{}) } - if len(src) != 1 { - return fmt.Errorf("invalid length for bool: %v", len(src)) + if len(src) == 0 { + return fmt.Errorf("cannot scan empty string into %T", dst) } - return s.ScanBool(Bool{Bool: src[0] == 't', Valid: true}) + v, err := planTextToBool(src) + if err != nil { + return err + } + + return s.ScanBool(Bool{Bool: v, Valid: true}) +} + +// https://www.postgresql.org/docs/11/datatype-boolean.html +func planTextToBool(src []byte) (bool, error) { + s := string(bytes.ToLower(bytes.TrimSpace(src))) + + switch { + case strings.HasPrefix("true", s), strings.HasPrefix("yes", s), s == "on", s == "1": + return true, nil + case strings.HasPrefix("false", s), strings.HasPrefix("no", s), strings.HasPrefix("off", s), s == "0": + return false, nil + default: + return false, fmt.Errorf("unknown boolean string representation %q", src) + } } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/convert.go b/vendor/github.com/jackc/pgx/v5/pgtype/convert.go index 8a2afbe1e2..8a9cee9c3e 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/convert.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/convert.go @@ -1,377 +1,9 @@ package pgtype import ( - "database/sql" - "fmt" - "math" "reflect" - "time" ) -const ( - maxUint = ^uint(0) - maxInt = int(maxUint >> 1) - minInt = -maxInt - 1 -) - -// underlyingNumberType gets the underlying type that can be converted to Int2, Int4, Int8, Float4, or Float8 -func underlyingNumberType(val any) (any, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return nil, false - } - convVal := refVal.Elem().Interface() - return convVal, true - case reflect.Int: - convVal := int(refVal.Int()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Int8: - convVal := int8(refVal.Int()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Int16: - convVal := int16(refVal.Int()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Int32: - convVal := int32(refVal.Int()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Int64: - convVal := int64(refVal.Int()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Uint: - convVal := uint(refVal.Uint()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Uint8: - convVal := uint8(refVal.Uint()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Uint16: - convVal := uint16(refVal.Uint()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Uint32: - convVal := uint32(refVal.Uint()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Uint64: - convVal := uint64(refVal.Uint()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Float32: - convVal := float32(refVal.Float()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Float64: - convVal := refVal.Float() - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.String: - convVal := refVal.String() - return convVal, reflect.TypeOf(convVal) != refVal.Type() - } - - return nil, false -} - -// underlyingBoolType gets the underlying type that can be converted to Bool -func underlyingBoolType(val any) (any, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return nil, false - } - convVal := refVal.Elem().Interface() - return convVal, true - case reflect.Bool: - convVal := refVal.Bool() - return convVal, reflect.TypeOf(convVal) != refVal.Type() - } - - return nil, false -} - -// underlyingBytesType gets the underlying type that can be converted to []byte -func underlyingBytesType(val any) (any, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return nil, false - } - convVal := refVal.Elem().Interface() - return convVal, true - case reflect.Slice: - if refVal.Type().Elem().Kind() == reflect.Uint8 { - convVal := refVal.Bytes() - return convVal, reflect.TypeOf(convVal) != refVal.Type() - } - } - - return nil, false -} - -// underlyingStringType gets the underlying type that can be converted to String -func underlyingStringType(val any) (any, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return nil, false - } - convVal := refVal.Elem().Interface() - return convVal, true - case reflect.String: - convVal := refVal.String() - return convVal, reflect.TypeOf(convVal) != refVal.Type() - } - - return nil, false -} - -// underlyingPtrType dereferences a pointer -func underlyingPtrType(val any) (any, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return nil, false - } - convVal := refVal.Elem().Interface() - return convVal, true - } - - return nil, false -} - -// underlyingTimeType gets the underlying type that can be converted to time.Time -func underlyingTimeType(val any) (any, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return nil, false - } - convVal := refVal.Elem().Interface() - return convVal, true - } - - timeType := reflect.TypeOf(time.Time{}) - if refVal.Type().ConvertibleTo(timeType) { - return refVal.Convert(timeType).Interface(), true - } - - return nil, false -} - -// underlyingUUIDType gets the underlying type that can be converted to [16]byte -func underlyingUUIDType(val any) (any, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return time.Time{}, false - } - convVal := refVal.Elem().Interface() - return convVal, true - } - - uuidType := reflect.TypeOf([16]byte{}) - if refVal.Type().ConvertibleTo(uuidType) { - return refVal.Convert(uuidType).Interface(), true - } - - return nil, false -} - -// underlyingSliceType gets the underlying slice type -func underlyingSliceType(val any) (any, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return nil, false - } - convVal := refVal.Elem().Interface() - return convVal, true - case reflect.Slice: - baseSliceType := reflect.SliceOf(refVal.Type().Elem()) - if refVal.Type().ConvertibleTo(baseSliceType) { - convVal := refVal.Convert(baseSliceType) - return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type() - } - } - - return nil, false -} - -func int64AssignTo(srcVal int64, srcValid bool, dst any) error { - if srcValid { - switch v := dst.(type) { - case *int: - if srcVal < int64(minInt) { - return fmt.Errorf("%d is less than minimum value for int", srcVal) - } else if srcVal > int64(maxInt) { - return fmt.Errorf("%d is greater than maximum value for int", srcVal) - } - *v = int(srcVal) - case *int8: - if srcVal < math.MinInt8 { - return fmt.Errorf("%d is less than minimum value for int8", srcVal) - } else if srcVal > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for int8", srcVal) - } - *v = int8(srcVal) - case *int16: - if srcVal < math.MinInt16 { - return fmt.Errorf("%d is less than minimum value for int16", srcVal) - } else if srcVal > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for int16", srcVal) - } - *v = int16(srcVal) - case *int32: - if srcVal < math.MinInt32 { - return fmt.Errorf("%d is less than minimum value for int32", srcVal) - } else if srcVal > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for int32", srcVal) - } - *v = int32(srcVal) - case *int64: - if srcVal < math.MinInt64 { - return fmt.Errorf("%d is less than minimum value for int64", srcVal) - } else if srcVal > math.MaxInt64 { - return fmt.Errorf("%d is greater than maximum value for int64", srcVal) - } - *v = int64(srcVal) - case *uint: - if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint", srcVal) - } else if uint64(srcVal) > uint64(maxUint) { - return fmt.Errorf("%d is greater than maximum value for uint", srcVal) - } - *v = uint(srcVal) - case *uint8: - if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint8", srcVal) - } else if srcVal > math.MaxUint8 { - return fmt.Errorf("%d is greater than maximum value for uint8", srcVal) - } - *v = uint8(srcVal) - case *uint16: - if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint32", srcVal) - } else if srcVal > math.MaxUint16 { - return fmt.Errorf("%d is greater than maximum value for uint16", srcVal) - } - *v = uint16(srcVal) - case *uint32: - if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint32", srcVal) - } else if srcVal > math.MaxUint32 { - return fmt.Errorf("%d is greater than maximum value for uint32", srcVal) - } - *v = uint32(srcVal) - case *uint64: - if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint64", srcVal) - } - *v = uint64(srcVal) - case sql.Scanner: - return v.Scan(srcVal) - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return int64AssignTo(srcVal, srcValid, el.Interface()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if el.OverflowInt(int64(srcVal)) { - return fmt.Errorf("cannot put %d into %T", srcVal, dst) - } - el.SetInt(int64(srcVal)) - return nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if srcVal < 0 { - return fmt.Errorf("%d is less than zero for %T", srcVal, dst) - } - if el.OverflowUint(uint64(srcVal)) { - return fmt.Errorf("cannot put %d into %T", srcVal, dst) - } - el.SetUint(uint64(srcVal)) - return nil - } - } - return fmt.Errorf("cannot assign %v into %T", srcVal, dst) - } - return nil - } - - // if dst is a pointer to pointer and srcStatus is not Valid, nil it out - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - if el.Kind() == reflect.Ptr { - el.Set(reflect.Zero(el.Type())) - return nil - } - } - - return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcValid, dst) -} - -func float64AssignTo(srcVal float64, srcValid bool, dst any) error { - if srcValid { - switch v := dst.(type) { - case *float32: - *v = float32(srcVal) - case *float64: - *v = srcVal - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a type alias of a float32 or 64, set dst val - case reflect.Float32, reflect.Float64: - el.SetFloat(srcVal) - return nil - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return float64AssignTo(srcVal, srcValid, el.Interface()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - i64 := int64(srcVal) - if float64(i64) == srcVal { - return int64AssignTo(i64, srcValid, dst) - } - } - } - return fmt.Errorf("cannot assign %v into %T", srcVal, dst) - } - return nil - } - - // if dst is a pointer to pointer and srcStatus is not Valid, nil it out - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - if el.Kind() == reflect.Ptr { - el.Set(reflect.Zero(el.Type())) - return nil - } - } - - return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcValid, dst) -} - func NullAssignTo(dst any) error { dstPtr := reflect.ValueOf(dst) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go b/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go index 4743643e5a..2f34f4c9e2 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go @@ -1,14 +1,11 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "errors" "fmt" "strings" - "unicode" - "unicode/utf8" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -43,7 +40,7 @@ func (h *Hstore) Scan(src any) error { switch src := src.(type) { case string: - return scanPlanTextAnyToHstoreScanner{}.Scan([]byte(src), h) + return scanPlanTextAnyToHstoreScanner{}.scanString(src, h) } return fmt.Errorf("cannot scan %T", src) @@ -124,8 +121,15 @@ func (encodePlanHstoreCodecText) Encode(value any, buf []byte) (newBuf []byte, e return nil, err } - if hstore == nil { - return nil, nil + if len(hstore) == 0 { + // distinguish between empty and nil: Not strictly required by Postgres, since its protocol + // explicitly marks NULL column values separately. However, the Binary codec does this, and + // this means we can "round trip" Encode and Scan without data loss. + // nil: []byte(nil); empty: []byte{} + if hstore == nil { + return nil, nil + } + return []byte{}, nil } firstPair := true @@ -134,16 +138,23 @@ func (encodePlanHstoreCodecText) Encode(value any, buf []byte) (newBuf []byte, e if firstPair { firstPair = false } else { - buf = append(buf, ',') + buf = append(buf, ',', ' ') } - buf = append(buf, quoteHstoreElementIfNeeded(k)...) + // unconditionally quote hstore keys/values like Postgres does + // this avoids a Mac OS X Postgres hstore parsing bug: + // https://www.postgresql.org/message-id/CA%2BHWA9awUW0%2BRV_gO9r1ABZwGoZxPztcJxPy8vMFSTbTfi4jig%40mail.gmail.com + buf = append(buf, '"') + buf = append(buf, quoteArrayReplacer.Replace(k)...) + buf = append(buf, '"') buf = append(buf, "=>"...) if v == nil { buf = append(buf, "NULL"...) } else { - buf = append(buf, quoteHstoreElementIfNeeded(*v)...) + buf = append(buf, '"') + buf = append(buf, quoteArrayReplacer.Replace(*v)...) + buf = append(buf, '"') } } @@ -174,25 +185,28 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error { scanner := (dst).(HstoreScanner) if src == nil { - return scanner.ScanHstore(Hstore{}) + return scanner.ScanHstore(Hstore(nil)) } rp := 0 - if len(src[rp:]) < 4 { + const uint32Len = 4 + if len(src[rp:]) < uint32Len { return fmt.Errorf("hstore incomplete %v", src) } pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 + rp += uint32Len hstore := make(Hstore, pairCount) + // one allocation for all *string, rather than one per string, just like text parsing + valueStrings := make([]string, pairCount) for i := 0; i < pairCount; i++ { - if len(src[rp:]) < 4 { + if len(src[rp:]) < uint32Len { return fmt.Errorf("hstore incomplete %v", src) } keyLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 + rp += uint32Len if len(src[rp:]) < keyLen { return fmt.Errorf("hstore incomplete %v", src) @@ -200,26 +214,17 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error { key := string(src[rp : rp+keyLen]) rp += keyLen - if len(src[rp:]) < 4 { + if len(src[rp:]) < uint32Len { return fmt.Errorf("hstore incomplete %v", src) } valueLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 - var valueBuf []byte if valueLen >= 0 { - valueBuf = src[rp : rp+valueLen] + valueStrings[i] = string(src[rp : rp+valueLen]) rp += valueLen - } - var value Text - err := scanPlanTextAnyToTextScanner{}.Scan(valueBuf, &value) - if err != nil { - return err - } - - if value.Valid { - hstore[key] = &value.String + hstore[key] = &valueStrings[i] } else { hstore[key] = nil } @@ -230,28 +235,22 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error { type scanPlanTextAnyToHstoreScanner struct{} -func (scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst any) error { +func (s scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst any) error { scanner := (dst).(HstoreScanner) if src == nil { - return scanner.ScanHstore(Hstore{}) + return scanner.ScanHstore(Hstore(nil)) } + return s.scanString(string(src), scanner) +} - keys, values, err := parseHstore(string(src)) +// scanString does not return nil hstore values because string cannot be nil. +func (scanPlanTextAnyToHstoreScanner) scanString(src string, scanner HstoreScanner) error { + hstore, err := parseHstore(src) if err != nil { return err } - - m := make(Hstore, len(keys)) - for i := range keys { - if values[i].Valid { - m[keys[i]] = &values[i].String - } else { - m[keys[i]] = nil - } - } - - return scanner.ScanHstore(m) + return scanner.ScanHstore(hstore) } func (c HstoreCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { @@ -271,191 +270,217 @@ func (c HstoreCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) ( return hstore, nil } -var quoteHstoreReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) +type hstoreParser struct { + str string + pos int + nextBackslash int +} + +func newHSP(in string) *hstoreParser { + return &hstoreParser{ + pos: 0, + str: in, + nextBackslash: strings.IndexByte(in, '\\'), + } +} -func quoteHstoreElement(src string) string { - return `"` + quoteArrayReplacer.Replace(src) + `"` +func (p *hstoreParser) atEnd() bool { + return p.pos >= len(p.str) } -func quoteHstoreElementIfNeeded(src string) string { - if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || strings.ContainsAny(src, ` {},"\=>`) { - return quoteArrayElement(src) +// consume returns the next byte of the string, or end if the string is done. +func (p *hstoreParser) consume() (b byte, end bool) { + if p.pos >= len(p.str) { + return 0, true } - return src + b = p.str[p.pos] + p.pos++ + return b, false } -const ( - hsPre = iota - hsKey - hsSep - hsVal - hsNul - hsNext -) +func unexpectedByteErr(actualB byte, expectedB byte) error { + return fmt.Errorf("expected '%c' ('%#v'); found '%c' ('%#v')", expectedB, expectedB, actualB, actualB) +} -type hstoreParser struct { - str string - pos int +// consumeExpectedByte consumes expectedB from the string, or returns an error. +func (p *hstoreParser) consumeExpectedByte(expectedB byte) error { + nextB, end := p.consume() + if end { + return fmt.Errorf("expected '%c' ('%#v'); found end", expectedB, expectedB) + } + if nextB != expectedB { + return unexpectedByteErr(nextB, expectedB) + } + return nil } -func newHSP(in string) *hstoreParser { - return &hstoreParser{ - pos: 0, - str: in, +// consumeExpected2 consumes two expected bytes or returns an error. +// This was a bit faster than using a string argument (better inlining? Not sure). +func (p *hstoreParser) consumeExpected2(one byte, two byte) error { + if p.pos+2 > len(p.str) { + return errors.New("unexpected end of string") + } + if p.str[p.pos] != one { + return unexpectedByteErr(p.str[p.pos], one) } + if p.str[p.pos+1] != two { + return unexpectedByteErr(p.str[p.pos+1], two) + } + p.pos += 2 + return nil } -func (p *hstoreParser) Consume() (r rune, end bool) { - if p.pos >= len(p.str) { - end = true - return +var errEOSInQuoted = errors.New(`found end before closing double-quote ('"')`) + +// consumeDoubleQuoted consumes a double-quoted string from p. The double quote must have been +// parsed already. This copies the string from the backing string so it can be garbage collected. +func (p *hstoreParser) consumeDoubleQuoted() (string, error) { + // fast path: assume most keys/values do not contain escapes + nextDoubleQuote := strings.IndexByte(p.str[p.pos:], '"') + if nextDoubleQuote == -1 { + return "", errEOSInQuoted + } + nextDoubleQuote += p.pos + if p.nextBackslash == -1 || p.nextBackslash > nextDoubleQuote { + // clone the string from the source string to ensure it can be garbage collected separately + // TODO: use strings.Clone on Go 1.20; this could get optimized away + s := strings.Clone(p.str[p.pos:nextDoubleQuote]) + p.pos = nextDoubleQuote + 1 + return s, nil + } + + // slow path: string contains escapes + s, err := p.consumeDoubleQuotedWithEscapes(p.nextBackslash) + p.nextBackslash = strings.IndexByte(p.str[p.pos:], '\\') + if p.nextBackslash != -1 { + p.nextBackslash += p.pos } - r, w := utf8.DecodeRuneInString(p.str[p.pos:]) - p.pos += w - return + return s, err } -func (p *hstoreParser) Peek() (r rune, end bool) { - if p.pos >= len(p.str) { - end = true - return +// consumeDoubleQuotedWithEscapes consumes a double-quoted string containing escapes, starting +// at p.pos, and with the first backslash at firstBackslash. This copies the string so it can be +// garbage collected separately. +func (p *hstoreParser) consumeDoubleQuotedWithEscapes(firstBackslash int) (string, error) { + // copy the prefix that does not contain backslashes + var builder strings.Builder + builder.WriteString(p.str[p.pos:firstBackslash]) + + // skip to the backslash + p.pos = firstBackslash + + // copy bytes until the end, unescaping backslashes + for { + nextB, end := p.consume() + if end { + return "", errEOSInQuoted + } else if nextB == '"' { + break + } else if nextB == '\\' { + // escape: skip the backslash and copy the char + nextB, end = p.consume() + if end { + return "", errEOSInQuoted + } + if !(nextB == '\\' || nextB == '"') { + return "", fmt.Errorf("unexpected escape in quoted string: found '%#v'", nextB) + } + builder.WriteByte(nextB) + } else { + // normal byte: copy it + builder.WriteByte(nextB) + } } - r, _ = utf8.DecodeRuneInString(p.str[p.pos:]) - return + return builder.String(), nil } -// parseHstore parses the string representation of an hstore column (the same -// you would get from an ordinary SELECT) into two slices of keys and values. it -// is used internally in the default parsing of hstores. -func parseHstore(s string) (k []string, v []Text, err error) { - if s == "" { - return +// consumePairSeparator consumes the Hstore pair separator ", " or returns an error. +func (p *hstoreParser) consumePairSeparator() error { + return p.consumeExpected2(',', ' ') +} + +// consumeKVSeparator consumes the Hstore key/value separator "=>" or returns an error. +func (p *hstoreParser) consumeKVSeparator() error { + return p.consumeExpected2('=', '>') +} + +// consumeDoubleQuotedOrNull consumes the Hstore key/value separator "=>" or returns an error. +func (p *hstoreParser) consumeDoubleQuotedOrNull() (Text, error) { + // peek at the next byte + if p.atEnd() { + return Text{}, errors.New("found end instead of value") + } + next := p.str[p.pos] + if next == 'N' { + // must be the exact string NULL: use consumeExpected2 twice + err := p.consumeExpected2('N', 'U') + if err != nil { + return Text{}, err + } + err = p.consumeExpected2('L', 'L') + if err != nil { + return Text{}, err + } + return Text{String: "", Valid: false}, nil + } else if next != '"' { + return Text{}, unexpectedByteErr(next, '"') + } + + // skip the double quote + p.pos += 1 + s, err := p.consumeDoubleQuoted() + if err != nil { + return Text{}, err } + return Text{String: s, Valid: true}, nil +} - buf := bytes.Buffer{} - keys := []string{} - values := []Text{} +func parseHstore(s string) (Hstore, error) { p := newHSP(s) - r, end := p.Consume() - state := hsPre - - for !end { - switch state { - case hsPre: - if r == '"' { - state = hsKey - } else { - err = errors.New("String does not begin with \"") - } - case hsKey: - switch r { - case '"': //End of the key - keys = append(keys, buf.String()) - buf = bytes.Buffer{} - state = hsSep - case '\\': //Potential escaped character - n, end := p.Consume() - switch { - case end: - err = errors.New("Found EOS in key, expecting character or \"") - case n == '"', n == '\\': - buf.WriteRune(n) - default: - buf.WriteRune(r) - buf.WriteRune(n) - } - default: //Any other character - buf.WriteRune(r) - } - case hsSep: - if r == '=' { - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after '=', expecting '>'") - case r == '>': - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'") - case r == '"': - state = hsVal - case r == 'N': - state = hsNul - default: - err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) - } - default: - err = fmt.Errorf("Invalid character after '=', expecting '>'") - } - } else { - err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r) - } - case hsVal: - switch r { - case '"': //End of the value - values = append(values, Text{String: buf.String(), Valid: true}) - buf = bytes.Buffer{} - state = hsNext - case '\\': //Potential escaped character - n, end := p.Consume() - switch { - case end: - err = errors.New("Found EOS in key, expecting character or \"") - case n == '"', n == '\\': - buf.WriteRune(n) - default: - buf.WriteRune(r) - buf.WriteRune(n) - } - default: //Any other character - buf.WriteRune(r) - } - case hsNul: - nulBuf := make([]rune, 3) - nulBuf[0] = r - for i := 1; i < 3; i++ { - r, end = p.Consume() - if end { - err = errors.New("Found EOS in NULL value") - return - } - nulBuf[i] = r - } - if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' { - values = append(values, Text{}) - state = hsNext - } else { - err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) - } - case hsNext: - if r == ',' { - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after ',', expcting space") - case (unicode.IsSpace(r)): - r, end = p.Consume() - state = hsKey - default: - err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r) - } - } else { - err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r) + // This is an over-estimate of the number of key/value pairs. Use '>' because I am guessing it + // is less likely to occur in keys/values than '=' or ','. + numPairsEstimate := strings.Count(s, ">") + // makes one allocation of strings for the entire Hstore, rather than one allocation per value. + valueStrings := make([]string, 0, numPairsEstimate) + result := make(Hstore, numPairsEstimate) + first := true + for !p.atEnd() { + if !first { + err := p.consumePairSeparator() + if err != nil { + return nil, err } + } else { + first = false } + err := p.consumeExpectedByte('"') if err != nil { - return + return nil, err + } + + key, err := p.consumeDoubleQuoted() + if err != nil { + return nil, err + } + + err = p.consumeKVSeparator() + if err != nil { + return nil, err + } + + value, err := p.consumeDoubleQuotedOrNull() + if err != nil { + return nil, err + } + if value.Valid { + valueStrings = append(valueStrings, value.String) + result[key] = &valueStrings[len(valueStrings)-1] + } else { + result[key] = nil } - r, end = p.Consume() - } - if state != hsNext { - err = errors.New("Improperly formatted hstore") - return } - k = keys - v = values - return + + return result, nil } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/json.go b/vendor/github.com/jackc/pgx/v5/pgtype/json.go index 69861bf88d..d332dd0db1 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/json.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/json.go @@ -25,6 +25,13 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod case []byte: return encodePlanJSONCodecEitherFormatByteSlice{} + // Must come before trying wrap encode plans because a pointer to a struct may be unwrapped to a struct that can be + // marshalled. + // + // https://github.com/jackc/pgx/issues/1681 + case json.Marshaler: + return encodePlanJSONCodecEitherFormatMarshal{} + // Cannot rely on driver.Valuer being handled later because anything can be marshalled. // // https://github.com/jackc/pgx/issues/1430 @@ -85,6 +92,23 @@ func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan switch target.(type) { case *string: return scanPlanAnyToString{} + + case **string: + // This is to fix **string scanning. It seems wrong to special case **string, but it's not clear what a better + // solution would be. + // + // https://github.com/jackc/pgx/issues/1470 -- **string + // https://github.com/jackc/pgx/issues/1691 -- ** anything else + + if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok { + if nextPlan := m.planScan(oid, format, nextDst); nextPlan != nil { + if _, failed := nextPlan.(*scanPlanFail); !failed { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } + } + } + case *[]byte: return scanPlanJSONToByteSlice{} case BytesScanner: @@ -97,19 +121,6 @@ func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan return &scanPlanSQLScanner{formatCode: format} } - // This is to fix **string scanning. It seems wrong to special case sql.Scanner and pointer to pointer, but it's not - // clear what a better solution would be. - // - // https://github.com/jackc/pgx/issues/1470 - if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok { - if nextPlan := m.planScan(oid, format, nextDst); nextPlan != nil { - if _, failed := nextPlan.(*scanPlanFail); !failed { - wrapperPlan.SetNext(nextPlan) - return wrapperPlan - } - } - } - return scanPlanJSONToJSONUnmarshal{} } @@ -150,7 +161,7 @@ func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { if dstValue.Kind() == reflect.Ptr { el := dstValue.Elem() switch el.Kind() { - case reflect.Ptr, reflect.Slice, reflect.Map: + case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Interface: el.Set(reflect.Zero(el.Type())) return nil } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go b/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go index 376c03fe85..0e58fd0765 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go @@ -33,23 +33,6 @@ var big10 *big.Int = big.NewInt(10) var big100 *big.Int = big.NewInt(100) var big1000 *big.Int = big.NewInt(1000) -var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8) -var bigMinInt8 *big.Int = big.NewInt(math.MinInt8) -var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16) -var bigMinInt16 *big.Int = big.NewInt(math.MinInt16) -var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32) -var bigMinInt32 *big.Int = big.NewInt(math.MinInt32) -var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64) -var bigMinInt64 *big.Int = big.NewInt(math.MinInt64) -var bigMaxInt *big.Int = big.NewInt(int64(maxInt)) -var bigMinInt *big.Int = big.NewInt(int64(minInt)) - -var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8) -var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16) -var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32) -var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64)) -var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint)) - var bigNBase *big.Int = big.NewInt(nbase) var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) @@ -161,20 +144,20 @@ func (n *Numeric) toBigInt() (*big.Int, error) { } func parseNumericString(str string) (n *big.Int, exp int32, err error) { - parts := strings.SplitN(str, ".", 2) - digits := strings.Join(parts, "") + idx := strings.IndexByte(str, '.') - if len(parts) > 1 { - exp = int32(-len(parts[1])) - } else { - for len(digits) > 1 && digits[len(digits)-1] == '0' && digits[len(digits)-2] != '-' { - digits = digits[:len(digits)-1] + if idx == -1 { + for len(str) > 1 && str[len(str)-1] == '0' && str[len(str)-2] != '-' { + str = str[:len(str)-1] exp++ } + } else { + exp = int32(-(len(str) - idx - 1)) + str = str[:idx] + str[idx+1:] } accum := &big.Int{} - if _, ok := accum.SetString(digits, 10); !ok { + if _, ok := accum.SetString(str, 10); !ok { return nil, 0, fmt.Errorf("%s is not a number", str) } @@ -241,11 +224,11 @@ func (n Numeric) MarshalJSON() ([]byte, error) { } func (n *Numeric) UnmarshalJSON(src []byte) error { - if bytes.Compare(src, []byte(`null`)) == 0 { + if bytes.Equal(src, []byte(`null`)) { *n = Numeric{} return nil } - if bytes.Compare(src, []byte(`"NaN"`)) == 0 { + if bytes.Equal(src, []byte(`"NaN"`)) { *n = Numeric{NaN: true, Valid: true} return nil } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go index 83b349cee7..59d833a19e 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go @@ -44,7 +44,7 @@ const ( MacaddrOID = 829 InetOID = 869 BoolArrayOID = 1000 - QCharArrayOID = 1003 + QCharArrayOID = 1002 NameArrayOID = 1003 Int2ArrayOID = 1005 Int4ArrayOID = 1007 @@ -147,7 +147,7 @@ const ( BinaryFormatCode = 1 ) -// A Codec converts between Go and PostgreSQL values. +// A Codec converts between Go and PostgreSQL values. A Codec must not be mutated after it is registered with a Map. type Codec interface { // FormatSupported returns true if the format is supported. FormatSupported(int16) bool @@ -178,6 +178,7 @@ func (e *nullAssignmentError) Error() string { return fmt.Sprintf("cannot assign NULL to %T", e.dst) } +// Type represents a PostgreSQL data type. It must not be mutated after it is registered with a Map. type Type struct { Codec Codec Name string @@ -211,7 +212,9 @@ type Map struct { } func NewMap() *Map { - m := &Map{ + defaultMapInitOnce.Do(initDefaultMap) + + return &Map{ oidToType: make(map[uint32]*Type), nameToType: make(map[string]*Type), reflectTypeToName: make(map[reflect.Type]string), @@ -240,184 +243,9 @@ func NewMap() *Map { TryWrapPtrArrayScanPlan, }, } - - // Base types - m.RegisterType(&Type{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) - m.RegisterType(&Type{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) - m.RegisterType(&Type{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) - m.RegisterType(&Type{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) - m.RegisterType(&Type{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) - m.RegisterType(&Type{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) - m.RegisterType(&Type{Name: "char", OID: QCharOID, Codec: QCharCodec{}}) - m.RegisterType(&Type{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) - m.RegisterType(&Type{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) - m.RegisterType(&Type{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) - m.RegisterType(&Type{Name: "date", OID: DateOID, Codec: DateCodec{}}) - m.RegisterType(&Type{Name: "float4", OID: Float4OID, Codec: Float4Codec{}}) - m.RegisterType(&Type{Name: "float8", OID: Float8OID, Codec: Float8Codec{}}) - m.RegisterType(&Type{Name: "inet", OID: InetOID, Codec: InetCodec{}}) - m.RegisterType(&Type{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) - m.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) - m.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) - m.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) - m.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) - m.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) - m.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) - m.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}}) - m.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) - m.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) - m.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}}) - m.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) - m.RegisterType(&Type{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) - m.RegisterType(&Type{Name: "path", OID: PathOID, Codec: PathCodec{}}) - m.RegisterType(&Type{Name: "point", OID: PointOID, Codec: PointCodec{}}) - m.RegisterType(&Type{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) - m.RegisterType(&Type{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) - m.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) - m.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) - m.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) - m.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) - m.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) - m.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) - m.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) - m.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) - m.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) - m.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) - - // Range types - m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: m.oidToType[DateOID]}}) - m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int4OID]}}) - m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int8OID]}}) - m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[NumericOID]}}) - m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestampOID]}}) - m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestamptzOID]}}) - - // Multirange types - m.RegisterType(&Type{Name: "datemultirange", OID: DatemultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[DaterangeOID]}}) - m.RegisterType(&Type{Name: "int4multirange", OID: Int4multirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[Int4rangeOID]}}) - m.RegisterType(&Type{Name: "int8multirange", OID: Int8multirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[Int8rangeOID]}}) - m.RegisterType(&Type{Name: "nummultirange", OID: NummultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[NumrangeOID]}}) - m.RegisterType(&Type{Name: "tsmultirange", OID: TsmultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[TsrangeOID]}}) - m.RegisterType(&Type{Name: "tstzmultirange", OID: TstzmultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[TstzrangeOID]}}) - - // Array types - m.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ACLItemOID]}}) - m.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BitOID]}}) - m.RegisterType(&Type{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BoolOID]}}) - m.RegisterType(&Type{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BoxOID]}}) - m.RegisterType(&Type{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BPCharOID]}}) - m.RegisterType(&Type{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ByteaOID]}}) - m.RegisterType(&Type{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[QCharOID]}}) - m.RegisterType(&Type{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CIDOID]}}) - m.RegisterType(&Type{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CIDROID]}}) - m.RegisterType(&Type{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CircleOID]}}) - m.RegisterType(&Type{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[DateOID]}}) - m.RegisterType(&Type{Name: "_daterange", OID: DaterangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[DaterangeOID]}}) - m.RegisterType(&Type{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Float4OID]}}) - m.RegisterType(&Type{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Float8OID]}}) - m.RegisterType(&Type{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[InetOID]}}) - m.RegisterType(&Type{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int2OID]}}) - m.RegisterType(&Type{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int4OID]}}) - m.RegisterType(&Type{Name: "_int4range", OID: Int4rangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int4rangeOID]}}) - m.RegisterType(&Type{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int8OID]}}) - m.RegisterType(&Type{Name: "_int8range", OID: Int8rangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int8rangeOID]}}) - m.RegisterType(&Type{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[IntervalOID]}}) - m.RegisterType(&Type{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[JSONOID]}}) - m.RegisterType(&Type{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[JSONBOID]}}) - m.RegisterType(&Type{Name: "_jsonpath", OID: JSONPathArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[JSONPathOID]}}) - m.RegisterType(&Type{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[LineOID]}}) - m.RegisterType(&Type{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[LsegOID]}}) - m.RegisterType(&Type{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[MacaddrOID]}}) - m.RegisterType(&Type{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NameOID]}}) - m.RegisterType(&Type{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NumericOID]}}) - m.RegisterType(&Type{Name: "_numrange", OID: NumrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NumrangeOID]}}) - m.RegisterType(&Type{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[OIDOID]}}) - m.RegisterType(&Type{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PathOID]}}) - m.RegisterType(&Type{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PointOID]}}) - m.RegisterType(&Type{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PolygonOID]}}) - m.RegisterType(&Type{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[RecordOID]}}) - m.RegisterType(&Type{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TextOID]}}) - m.RegisterType(&Type{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TIDOID]}}) - m.RegisterType(&Type{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimeOID]}}) - m.RegisterType(&Type{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimestampOID]}}) - m.RegisterType(&Type{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimestamptzOID]}}) - m.RegisterType(&Type{Name: "_tsrange", OID: TsrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TsrangeOID]}}) - m.RegisterType(&Type{Name: "_tstzrange", OID: TstzrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TstzrangeOID]}}) - m.RegisterType(&Type{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[UUIDOID]}}) - m.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[VarbitOID]}}) - m.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[VarcharOID]}}) - m.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[XIDOID]}}) - - // Integer types that directly map to a PostgreSQL type - registerDefaultPgTypeVariants[int16](m, "int2") - registerDefaultPgTypeVariants[int32](m, "int4") - registerDefaultPgTypeVariants[int64](m, "int8") - - // Integer types that do not have a direct match to a PostgreSQL type - registerDefaultPgTypeVariants[int8](m, "int8") - registerDefaultPgTypeVariants[int](m, "int8") - registerDefaultPgTypeVariants[uint8](m, "int8") - registerDefaultPgTypeVariants[uint16](m, "int8") - registerDefaultPgTypeVariants[uint32](m, "int8") - registerDefaultPgTypeVariants[uint64](m, "numeric") - registerDefaultPgTypeVariants[uint](m, "numeric") - - registerDefaultPgTypeVariants[float32](m, "float4") - registerDefaultPgTypeVariants[float64](m, "float8") - - registerDefaultPgTypeVariants[bool](m, "bool") - registerDefaultPgTypeVariants[time.Time](m, "timestamptz") - registerDefaultPgTypeVariants[time.Duration](m, "interval") - registerDefaultPgTypeVariants[string](m, "text") - registerDefaultPgTypeVariants[[]byte](m, "bytea") - - registerDefaultPgTypeVariants[net.IP](m, "inet") - registerDefaultPgTypeVariants[net.IPNet](m, "cidr") - registerDefaultPgTypeVariants[netip.Addr](m, "inet") - registerDefaultPgTypeVariants[netip.Prefix](m, "cidr") - - // pgtype provided structs - registerDefaultPgTypeVariants[Bits](m, "varbit") - registerDefaultPgTypeVariants[Bool](m, "bool") - registerDefaultPgTypeVariants[Box](m, "box") - registerDefaultPgTypeVariants[Circle](m, "circle") - registerDefaultPgTypeVariants[Date](m, "date") - registerDefaultPgTypeVariants[Range[Date]](m, "daterange") - registerDefaultPgTypeVariants[Multirange[Range[Date]]](m, "datemultirange") - registerDefaultPgTypeVariants[Float4](m, "float4") - registerDefaultPgTypeVariants[Float8](m, "float8") - registerDefaultPgTypeVariants[Range[Float8]](m, "numrange") // There is no PostgreSQL builtin float8range so map it to numrange. - registerDefaultPgTypeVariants[Multirange[Range[Float8]]](m, "nummultirange") // There is no PostgreSQL builtin float8multirange so map it to nummultirange. - registerDefaultPgTypeVariants[Int2](m, "int2") - registerDefaultPgTypeVariants[Int4](m, "int4") - registerDefaultPgTypeVariants[Range[Int4]](m, "int4range") - registerDefaultPgTypeVariants[Multirange[Range[Int4]]](m, "int4multirange") - registerDefaultPgTypeVariants[Int8](m, "int8") - registerDefaultPgTypeVariants[Range[Int8]](m, "int8range") - registerDefaultPgTypeVariants[Multirange[Range[Int8]]](m, "int8multirange") - registerDefaultPgTypeVariants[Interval](m, "interval") - registerDefaultPgTypeVariants[Line](m, "line") - registerDefaultPgTypeVariants[Lseg](m, "lseg") - registerDefaultPgTypeVariants[Numeric](m, "numeric") - registerDefaultPgTypeVariants[Range[Numeric]](m, "numrange") - registerDefaultPgTypeVariants[Multirange[Range[Numeric]]](m, "nummultirange") - registerDefaultPgTypeVariants[Path](m, "path") - registerDefaultPgTypeVariants[Point](m, "point") - registerDefaultPgTypeVariants[Polygon](m, "polygon") - registerDefaultPgTypeVariants[TID](m, "tid") - registerDefaultPgTypeVariants[Text](m, "text") - registerDefaultPgTypeVariants[Time](m, "time") - registerDefaultPgTypeVariants[Timestamp](m, "timestamp") - registerDefaultPgTypeVariants[Timestamptz](m, "timestamptz") - registerDefaultPgTypeVariants[Range[Timestamp]](m, "tsrange") - registerDefaultPgTypeVariants[Multirange[Range[Timestamp]]](m, "tsmultirange") - registerDefaultPgTypeVariants[Range[Timestamptz]](m, "tstzrange") - registerDefaultPgTypeVariants[Multirange[Range[Timestamptz]]](m, "tstzmultirange") - registerDefaultPgTypeVariants[UUID](m, "uuid") - - return m } +// RegisterType registers a data type with the Map. t must not be mutated after it is registered. func (m *Map) RegisterType(t *Type) { m.oidToType[t.OID] = t m.nameToType[t.Name] = t @@ -449,13 +277,22 @@ func (m *Map) RegisterDefaultPgType(value any, name string) { } } +// TypeForOID returns the Type registered for the given OID. The returned Type must not be mutated. func (m *Map) TypeForOID(oid uint32) (*Type, bool) { - dt, ok := m.oidToType[oid] + if dt, ok := m.oidToType[oid]; ok { + return dt, true + } + + dt, ok := defaultMap.oidToType[oid] return dt, ok } +// TypeForName returns the Type registered for the given name. The returned Type must not be mutated. func (m *Map) TypeForName(name string) (*Type, bool) { - dt, ok := m.nameToType[name] + if dt, ok := m.nameToType[name]; ok { + return dt, true + } + dt, ok := defaultMap.nameToType[name] return dt, ok } @@ -463,30 +300,39 @@ func (m *Map) buildReflectTypeToType() { m.reflectTypeToType = make(map[reflect.Type]*Type) for reflectType, name := range m.reflectTypeToName { - if dt, ok := m.nameToType[name]; ok { + if dt, ok := m.TypeForName(name); ok { m.reflectTypeToType[reflectType] = dt } } } // TypeForValue finds a data type suitable for v. Use RegisterType to register types that can encode and decode -// themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type. +// themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type. The returned Type +// must not be mutated. func (m *Map) TypeForValue(v any) (*Type, bool) { if m.reflectTypeToType == nil { m.buildReflectTypeToType() } - dt, ok := m.reflectTypeToType[reflect.TypeOf(v)] + if dt, ok := m.reflectTypeToType[reflect.TypeOf(v)]; ok { + return dt, true + } + + dt, ok := defaultMap.reflectTypeToType[reflect.TypeOf(v)] return dt, ok } // FormatCodeForOID returns the preferred format code for type oid. If the type is not registered it returns the text // format code. func (m *Map) FormatCodeForOID(oid uint32) int16 { - fc, ok := m.oidToFormatCode[oid] - if ok { + if fc, ok := m.oidToFormatCode[oid]; ok { return fc } + + if fc, ok := defaultMap.oidToFormatCode[oid]; ok { + return fc + } + return TextFormatCode } @@ -587,6 +433,14 @@ func (plan *scanPlanFail) Scan(src []byte, dst any) error { return plan.Scan(src, dst) } } + for oid := range defaultMap.oidToType { + if _, ok := plan.m.oidToType[oid]; !ok { + plan := plan.m.planScan(oid, plan.formatCode, dst) + if _, ok := plan.(*scanPlanFail); !ok { + return plan.Scan(src, dst) + } + } + } } var format string @@ -600,7 +454,7 @@ func (plan *scanPlanFail) Scan(src []byte, dst any) error { } var dataTypeName string - if t, ok := plan.m.oidToType[plan.oid]; ok { + if t, ok := plan.m.TypeForOID(plan.oid); ok { dataTypeName = t.Name } else { dataTypeName = "unknown type" @@ -666,6 +520,7 @@ var elemKindToPointerTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]refl reflect.Float32: reflect.TypeOf(new(float32)), reflect.Float64: reflect.TypeOf(new(float64)), reflect.String: reflect.TypeOf(new(string)), + reflect.Bool: reflect.TypeOf(new(bool)), } type underlyingTypeScanPlan struct { @@ -1089,15 +944,16 @@ func TryWrapPtrSliceScanPlan(target any) (plan WrappedScanPlanNextSetter, nextVa return &wrapPtrSliceScanPlan[time.Time]{}, (*FlatArray[time.Time])(target), true } - targetValue := reflect.ValueOf(target) - if targetValue.Kind() != reflect.Ptr { + targetType := reflect.TypeOf(target) + if targetType.Kind() != reflect.Ptr { return nil, nil, false } - targetElemValue := targetValue.Elem() + targetElemType := targetType.Elem() - if targetElemValue.Kind() == reflect.Slice { - return &wrapPtrSliceReflectScanPlan{}, &anySliceArrayReflect{slice: targetElemValue}, true + if targetElemType.Kind() == reflect.Slice { + slice := reflect.New(targetElemType).Elem() + return &wrapPtrSliceReflectScanPlan{}, &anySliceArrayReflect{slice: slice}, true } return nil, nil, false } @@ -1198,6 +1054,10 @@ func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan { } func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { + if target == nil { + return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} + } + if _, ok := target.(*UndecodedBytes); ok { return scanPlanAnyToUndecodedBytes{} } @@ -1280,25 +1140,6 @@ func (m *Map) Scan(oid uint32, formatCode int16, src []byte, dst any) error { return plan.Scan(src, dst) } -func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest any) error { - switch dest := dest.(type) { - case *string: - if formatCode == BinaryFormatCode { - return fmt.Errorf("unknown oid %d in binary format cannot be scanned into %T", oid, dest) - } - *dest = string(buf) - return nil - case *[]byte: - *dest = buf - return nil - default: - if nextDst, retry := GetAssignToDstType(dest); retry { - return scanUnknownType(oid, formatCode, buf, nextDst) - } - return fmt.Errorf("unknown oid %d cannot be scanned into %T", oid, dest) - } -} - var ErrScanTargetTypeChanged = errors.New("scan target type changed") func codecScan(codec Codec, m *Map, oid uint32, format int16, src []byte, dst any) error { @@ -1514,6 +1355,7 @@ var kindToTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ reflect.Float32: reflect.TypeOf(float32(0)), reflect.Float64: reflect.TypeOf(float64(0)), reflect.String: reflect.TypeOf(""), + reflect.Bool: reflect.TypeOf(false), } type underlyingTypeEncodePlan struct { @@ -2039,13 +1881,13 @@ func newEncodeError(value any, m *Map, oid uint32, formatCode int16, err error) } var dataTypeName string - if t, ok := m.oidToType[oid]; ok { + if t, ok := m.TypeForOID(oid); ok { dataTypeName = t.Name } else { dataTypeName = "unknown type" } - return fmt.Errorf("unable to encode %#v into %s format for %s (OID %d): %s", value, format, dataTypeName, oid, err) + return fmt.Errorf("unable to encode %#v into %s format for %s (OID %d): %w", value, format, dataTypeName, oid, err) } // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go new file mode 100644 index 0000000000..58f4b92c7b --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go @@ -0,0 +1,223 @@ +package pgtype + +import ( + "net" + "net/netip" + "reflect" + "sync" + "time" +) + +var ( + // defaultMap contains default mappings between PostgreSQL server types and Go type handling logic. + defaultMap *Map + defaultMapInitOnce = sync.Once{} +) + +func initDefaultMap() { + defaultMap = &Map{ + oidToType: make(map[uint32]*Type), + nameToType: make(map[string]*Type), + reflectTypeToName: make(map[reflect.Type]string), + oidToFormatCode: make(map[uint32]int16), + + memoizedScanPlans: make(map[uint32]map[reflect.Type][2]ScanPlan), + memoizedEncodePlans: make(map[uint32]map[reflect.Type][2]EncodePlan), + + TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{ + TryWrapDerefPointerEncodePlan, + TryWrapBuiltinTypeEncodePlan, + TryWrapFindUnderlyingTypeEncodePlan, + TryWrapStructEncodePlan, + TryWrapSliceEncodePlan, + TryWrapMultiDimSliceEncodePlan, + TryWrapArrayEncodePlan, + }, + + TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{ + TryPointerPointerScanPlan, + TryWrapBuiltinTypeScanPlan, + TryFindUnderlyingTypeScanPlan, + TryWrapStructScanPlan, + TryWrapPtrSliceScanPlan, + TryWrapPtrMultiDimSliceScanPlan, + TryWrapPtrArrayScanPlan, + }, + } + + // Base types + defaultMap.RegisterType(&Type{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) + defaultMap.RegisterType(&Type{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) + defaultMap.RegisterType(&Type{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) + defaultMap.RegisterType(&Type{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) + defaultMap.RegisterType(&Type{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) + defaultMap.RegisterType(&Type{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) + defaultMap.RegisterType(&Type{Name: "char", OID: QCharOID, Codec: QCharCodec{}}) + defaultMap.RegisterType(&Type{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) + defaultMap.RegisterType(&Type{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) + defaultMap.RegisterType(&Type{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) + defaultMap.RegisterType(&Type{Name: "date", OID: DateOID, Codec: DateCodec{}}) + defaultMap.RegisterType(&Type{Name: "float4", OID: Float4OID, Codec: Float4Codec{}}) + defaultMap.RegisterType(&Type{Name: "float8", OID: Float8OID, Codec: Float8Codec{}}) + defaultMap.RegisterType(&Type{Name: "inet", OID: InetOID, Codec: InetCodec{}}) + defaultMap.RegisterType(&Type{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) + defaultMap.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) + defaultMap.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) + defaultMap.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) + defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) + defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) + defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) + defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}}) + defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) + defaultMap.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) + defaultMap.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}}) + defaultMap.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) + defaultMap.RegisterType(&Type{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) + defaultMap.RegisterType(&Type{Name: "path", OID: PathOID, Codec: PathCodec{}}) + defaultMap.RegisterType(&Type{Name: "point", OID: PointOID, Codec: PointCodec{}}) + defaultMap.RegisterType(&Type{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) + defaultMap.RegisterType(&Type{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) + defaultMap.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) + defaultMap.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) + defaultMap.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) + defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) + defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) + defaultMap.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) + defaultMap.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) + defaultMap.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) + defaultMap.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) + defaultMap.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) + + // Range types + defaultMap.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[DateOID]}}) + defaultMap.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[Int4OID]}}) + defaultMap.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[Int8OID]}}) + defaultMap.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[NumericOID]}}) + defaultMap.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[TimestampOID]}}) + defaultMap.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[TimestamptzOID]}}) + + // Multirange types + defaultMap.RegisterType(&Type{Name: "datemultirange", OID: DatemultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[DaterangeOID]}}) + defaultMap.RegisterType(&Type{Name: "int4multirange", OID: Int4multirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[Int4rangeOID]}}) + defaultMap.RegisterType(&Type{Name: "int8multirange", OID: Int8multirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[Int8rangeOID]}}) + defaultMap.RegisterType(&Type{Name: "nummultirange", OID: NummultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[NumrangeOID]}}) + defaultMap.RegisterType(&Type{Name: "tsmultirange", OID: TsmultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[TsrangeOID]}}) + defaultMap.RegisterType(&Type{Name: "tstzmultirange", OID: TstzmultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[TstzrangeOID]}}) + + // Array types + defaultMap.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[ACLItemOID]}}) + defaultMap.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BitOID]}}) + defaultMap.RegisterType(&Type{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BoolOID]}}) + defaultMap.RegisterType(&Type{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BoxOID]}}) + defaultMap.RegisterType(&Type{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BPCharOID]}}) + defaultMap.RegisterType(&Type{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[ByteaOID]}}) + defaultMap.RegisterType(&Type{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[QCharOID]}}) + defaultMap.RegisterType(&Type{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[CIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[CIDROID]}}) + defaultMap.RegisterType(&Type{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[CircleOID]}}) + defaultMap.RegisterType(&Type{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[DateOID]}}) + defaultMap.RegisterType(&Type{Name: "_daterange", OID: DaterangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[DaterangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Float4OID]}}) + defaultMap.RegisterType(&Type{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Float8OID]}}) + defaultMap.RegisterType(&Type{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[InetOID]}}) + defaultMap.RegisterType(&Type{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int2OID]}}) + defaultMap.RegisterType(&Type{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int4OID]}}) + defaultMap.RegisterType(&Type{Name: "_int4range", OID: Int4rangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int4rangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int8OID]}}) + defaultMap.RegisterType(&Type{Name: "_int8range", OID: Int8rangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int8rangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[IntervalOID]}}) + defaultMap.RegisterType(&Type{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[JSONOID]}}) + defaultMap.RegisterType(&Type{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[JSONBOID]}}) + defaultMap.RegisterType(&Type{Name: "_jsonpath", OID: JSONPathArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[JSONPathOID]}}) + defaultMap.RegisterType(&Type{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[LineOID]}}) + defaultMap.RegisterType(&Type{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[LsegOID]}}) + defaultMap.RegisterType(&Type{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[MacaddrOID]}}) + defaultMap.RegisterType(&Type{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[NameOID]}}) + defaultMap.RegisterType(&Type{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[NumericOID]}}) + defaultMap.RegisterType(&Type{Name: "_numrange", OID: NumrangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[NumrangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[OIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[PathOID]}}) + defaultMap.RegisterType(&Type{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[PointOID]}}) + defaultMap.RegisterType(&Type{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[PolygonOID]}}) + defaultMap.RegisterType(&Type{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[RecordOID]}}) + defaultMap.RegisterType(&Type{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TextOID]}}) + defaultMap.RegisterType(&Type{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimeOID]}}) + defaultMap.RegisterType(&Type{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestampOID]}}) + defaultMap.RegisterType(&Type{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestamptzOID]}}) + defaultMap.RegisterType(&Type{Name: "_tsrange", OID: TsrangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TsrangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_tstzrange", OID: TstzrangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TstzrangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[UUIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarbitOID]}}) + defaultMap.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarcharOID]}}) + defaultMap.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XIDOID]}}) + + // Integer types that directly map to a PostgreSQL type + registerDefaultPgTypeVariants[int16](defaultMap, "int2") + registerDefaultPgTypeVariants[int32](defaultMap, "int4") + registerDefaultPgTypeVariants[int64](defaultMap, "int8") + + // Integer types that do not have a direct match to a PostgreSQL type + registerDefaultPgTypeVariants[int8](defaultMap, "int8") + registerDefaultPgTypeVariants[int](defaultMap, "int8") + registerDefaultPgTypeVariants[uint8](defaultMap, "int8") + registerDefaultPgTypeVariants[uint16](defaultMap, "int8") + registerDefaultPgTypeVariants[uint32](defaultMap, "int8") + registerDefaultPgTypeVariants[uint64](defaultMap, "numeric") + registerDefaultPgTypeVariants[uint](defaultMap, "numeric") + + registerDefaultPgTypeVariants[float32](defaultMap, "float4") + registerDefaultPgTypeVariants[float64](defaultMap, "float8") + + registerDefaultPgTypeVariants[bool](defaultMap, "bool") + registerDefaultPgTypeVariants[time.Time](defaultMap, "timestamptz") + registerDefaultPgTypeVariants[time.Duration](defaultMap, "interval") + registerDefaultPgTypeVariants[string](defaultMap, "text") + registerDefaultPgTypeVariants[[]byte](defaultMap, "bytea") + + registerDefaultPgTypeVariants[net.IP](defaultMap, "inet") + registerDefaultPgTypeVariants[net.IPNet](defaultMap, "cidr") + registerDefaultPgTypeVariants[netip.Addr](defaultMap, "inet") + registerDefaultPgTypeVariants[netip.Prefix](defaultMap, "cidr") + + // pgtype provided structs + registerDefaultPgTypeVariants[Bits](defaultMap, "varbit") + registerDefaultPgTypeVariants[Bool](defaultMap, "bool") + registerDefaultPgTypeVariants[Box](defaultMap, "box") + registerDefaultPgTypeVariants[Circle](defaultMap, "circle") + registerDefaultPgTypeVariants[Date](defaultMap, "date") + registerDefaultPgTypeVariants[Range[Date]](defaultMap, "daterange") + registerDefaultPgTypeVariants[Multirange[Range[Date]]](defaultMap, "datemultirange") + registerDefaultPgTypeVariants[Float4](defaultMap, "float4") + registerDefaultPgTypeVariants[Float8](defaultMap, "float8") + registerDefaultPgTypeVariants[Range[Float8]](defaultMap, "numrange") // There is no PostgreSQL builtin float8range so map it to numrange. + registerDefaultPgTypeVariants[Multirange[Range[Float8]]](defaultMap, "nummultirange") // There is no PostgreSQL builtin float8multirange so map it to nummultirange. + registerDefaultPgTypeVariants[Int2](defaultMap, "int2") + registerDefaultPgTypeVariants[Int4](defaultMap, "int4") + registerDefaultPgTypeVariants[Range[Int4]](defaultMap, "int4range") + registerDefaultPgTypeVariants[Multirange[Range[Int4]]](defaultMap, "int4multirange") + registerDefaultPgTypeVariants[Int8](defaultMap, "int8") + registerDefaultPgTypeVariants[Range[Int8]](defaultMap, "int8range") + registerDefaultPgTypeVariants[Multirange[Range[Int8]]](defaultMap, "int8multirange") + registerDefaultPgTypeVariants[Interval](defaultMap, "interval") + registerDefaultPgTypeVariants[Line](defaultMap, "line") + registerDefaultPgTypeVariants[Lseg](defaultMap, "lseg") + registerDefaultPgTypeVariants[Numeric](defaultMap, "numeric") + registerDefaultPgTypeVariants[Range[Numeric]](defaultMap, "numrange") + registerDefaultPgTypeVariants[Multirange[Range[Numeric]]](defaultMap, "nummultirange") + registerDefaultPgTypeVariants[Path](defaultMap, "path") + registerDefaultPgTypeVariants[Point](defaultMap, "point") + registerDefaultPgTypeVariants[Polygon](defaultMap, "polygon") + registerDefaultPgTypeVariants[TID](defaultMap, "tid") + registerDefaultPgTypeVariants[Text](defaultMap, "text") + registerDefaultPgTypeVariants[Time](defaultMap, "time") + registerDefaultPgTypeVariants[Timestamp](defaultMap, "timestamp") + registerDefaultPgTypeVariants[Timestamptz](defaultMap, "timestamptz") + registerDefaultPgTypeVariants[Range[Timestamp]](defaultMap, "tsrange") + registerDefaultPgTypeVariants[Multirange[Range[Timestamp]]](defaultMap, "tsmultirange") + registerDefaultPgTypeVariants[Range[Timestamptz]](defaultMap, "tstzrange") + registerDefaultPgTypeVariants[Multirange[Range[Timestamptz]]](defaultMap, "tstzmultirange") + registerDefaultPgTypeVariants[UUID](defaultMap, "uuid") + + defaultMap.buildReflectTypeToType() +} diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/point.go b/vendor/github.com/jackc/pgx/v5/pgtype/point.go index cfa5a9f1a2..b5a4320b6a 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/point.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/point.go @@ -40,7 +40,7 @@ func (p Point) PointValue() (Point, error) { } func parsePoint(src []byte) (*Point, error) { - if src == nil || bytes.Compare(src, []byte("null")) == 0 { + if src == nil || bytes.Equal(src, []byte("null")) { return &Point{}, nil } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go b/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go index 9f3de2c592..35d7395660 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go @@ -3,6 +3,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" "fmt" "strings" "time" @@ -66,6 +67,55 @@ func (ts Timestamp) Value() (driver.Value, error) { return ts.Time, nil } +func (ts Timestamp) MarshalJSON() ([]byte, error) { + if !ts.Valid { + return []byte("null"), nil + } + + var s string + + switch ts.InfinityModifier { + case Finite: + s = ts.Time.Format(time.RFC3339Nano) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return json.Marshal(s) +} + +func (ts *Timestamp) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *ts = Timestamp{} + return nil + } + + switch *s { + case "infinity": + *ts = Timestamp{Valid: true, InfinityModifier: Infinity} + case "-infinity": + *ts = Timestamp{Valid: true, InfinityModifier: -Infinity} + default: + // PostgreSQL uses ISO 8601 for to_json function and casting from a string to timestamptz + tim, err := time.Parse(time.RFC3339Nano, *s) + if err != nil { + return err + } + + *ts = Timestamp{Time: tim, Valid: true} + } + + return nil +} + type TimestampCodec struct{} func (TimestampCodec) FormatSupported(format int16) bool { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/uuid.go b/vendor/github.com/jackc/pgx/v5/pgtype/uuid.go index 96a4c32fd1..b59d6e766b 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/uuid.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/uuid.go @@ -97,7 +97,7 @@ func (src UUID) MarshalJSON() ([]byte, error) { } func (dst *UUID) UnmarshalJSON(src []byte) error { - if bytes.Compare(src, []byte("null")) == 0 { + if bytes.Equal(src, []byte("null")) { *dst = UUID{} return nil } diff --git a/vendor/github.com/jackc/pgx/v5/rows.go b/vendor/github.com/jackc/pgx/v5/rows.go index ffe739b025..1b1c8ac9a3 100644 --- a/vendor/github.com/jackc/pgx/v5/rows.go +++ b/vendor/github.com/jackc/pgx/v5/rows.go @@ -28,12 +28,16 @@ type Rows interface { // to call Close after rows is already closed. Close() - // Err returns any error that occurred while reading. + // Err returns any error that occurred while reading. Err must only be called after the Rows is closed (either by + // calling Close or by Next returning false). If it is called early it may return nil even if there was an error + // executing the query. Err() error // CommandTag returns the command tag from this query. It is only available after Rows is closed. CommandTag() pgconn.CommandTag + // FieldDescriptions returns the field descriptions of the columns. It may return nil. In particular this can occur + // when there was an error executing the query. FieldDescriptions() []pgconn.FieldDescription // Next prepares the next row for reading. It returns true if there is another @@ -227,7 +231,11 @@ func (rows *baseRows) Scan(dest ...any) error { if len(dest) == 1 { if rc, ok := dest[0].(RowScanner); ok { - return rc.ScanRow(rows) + err := rc.ScanRow(rows) + if err != nil { + rows.fatal(err) + } + return err } } @@ -298,7 +306,7 @@ func (rows *baseRows) Values() ([]any, error) { copy(newBuf, buf) values = append(values, newBuf) default: - rows.fatal(errors.New("Unknown format code")) + rows.fatal(errors.New("unknown format code")) } } @@ -488,7 +496,8 @@ func (rs *mapRowScanner) ScanRow(rows Rows) error { } // RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row -// has fields. The row and T fields will by matched by position. +// has fields. The row and T fields will by matched by position. If the "db" struct tag is "-" then the field will be +// ignored. func RowToStructByPos[T any](row CollectableRow) (T, error) { var value T err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) @@ -496,7 +505,8 @@ func RowToStructByPos[T any](row CollectableRow) (T, error) { } // RowToAddrOfStructByPos returns the address of a T scanned from row. T must be a struct. T must have the same number a -// public fields as row has fields. The row and T fields will by matched by position. +// public fields as row has fields. The row and T fields will by matched by position. If the "db" struct tag is "-" then +// the field will be ignored. func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) { var value T err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) @@ -533,13 +543,16 @@ func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Val for i := 0; i < dstElemType.NumField(); i++ { sf := dstElemType.Field(i) - if sf.PkgPath == "" { - // Handle anonymous struct embedding, but do not try to handle embedded pointers. - if sf.Anonymous && sf.Type.Kind() == reflect.Struct { - scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets) - } else { - scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface()) + // Handle anonymous struct embedding, but do not try to handle embedded pointers. + if sf.Anonymous && sf.Type.Kind() == reflect.Struct { + scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets) + } else if sf.PkgPath == "" { + dbTag, _ := sf.Tag.Lookup(structTagKey) + if dbTag == "-" { + // Field is ignored, skip it. + continue } + scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface()) } } @@ -565,8 +578,28 @@ func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) { return &value, err } +// RowToStructByNameLax returns a T scanned from row. T must be a struct. T must have greater than or equal number of named public +// fields as row has fields. The row and T fields will by matched by name. The match is case-insensitive. The database +// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored. +func RowToStructByNameLax[T any](row CollectableRow) (T, error) { + var value T + err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true}) + return value, err +} + +// RowToAddrOfStructByNameLax returns the address of a T scanned from row. T must be a struct. T must have greater than or +// equal number of named public fields as row has fields. The row and T fields will by matched by name. The match is +// case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" +// then the field will be ignored. +func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) { + var value T + err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true}) + return &value, err +} + type namedStructRowScanner struct { ptrToStruct any + lax bool } func (rs *namedStructRowScanner) ScanRow(rows Rows) error { @@ -578,7 +611,6 @@ func (rs *namedStructRowScanner) ScanRow(rows Rows) error { dstElemValue := dstValue.Elem() scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions()) - if err != nil { return err } @@ -638,7 +670,13 @@ func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, s colName = sf.Name } fpos := fieldPosByName(fldDescs, colName) - if fpos == -1 || fpos >= len(scanTargets) { + if fpos == -1 { + if rs.lax { + continue + } + return nil, fmt.Errorf("cannot find field %s in returned row", colName) + } + if fpos >= len(scanTargets) && !rs.lax { return nil, fmt.Errorf("cannot find field %s in returned row", colName) } scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface() diff --git a/vendor/github.com/jackc/pgx/v5/tx.go b/vendor/github.com/jackc/pgx/v5/tx.go index e57142a619..8feeb51233 100644 --- a/vendor/github.com/jackc/pgx/v5/tx.go +++ b/vendor/github.com/jackc/pgx/v5/tx.go @@ -44,6 +44,10 @@ type TxOptions struct { IsoLevel TxIsoLevel AccessMode TxAccessMode DeferrableMode TxDeferrableMode + + // BeginQuery is the SQL query that will be executed to begin the transaction. This allows using non-standard syntax + // such as BEGIN PRIORITY HIGH with CockroachDB. If set this will override the other settings. + BeginQuery string } var emptyTxOptions TxOptions @@ -53,6 +57,10 @@ func (txOptions TxOptions) beginSQL() string { return "begin" } + if txOptions.BeginQuery != "" { + return txOptions.BeginQuery + } + var buf strings.Builder buf.Grow(64) // 64 - maximum length of string with available options buf.WriteString("begin") @@ -144,7 +152,6 @@ type Tx interface { // called on the dbTx. type dbTx struct { conn *Conn - err error savepointNum int64 closed bool } diff --git a/vendor/gorm.io/driver/postgres/error_translator.go b/vendor/gorm.io/driver/postgres/error_translator.go index 285494c2de..9c0ef25342 100644 --- a/vendor/gorm.io/driver/postgres/error_translator.go +++ b/vendor/gorm.io/driver/postgres/error_translator.go @@ -2,26 +2,30 @@ package postgres import ( "encoding/json" - "github.com/jackc/pgx/v5/pgconn" + "gorm.io/gorm" + + "github.com/jackc/pgx/v5/pgconn" ) -var errCodes = map[string]string{ - "uniqueConstraint": "23505", +var errCodes = map[string]error{ + "23505": gorm.ErrDuplicatedKey, + "23503": gorm.ErrForeignKeyViolated, + "42703": gorm.ErrInvalidField, } type ErrMessage struct { - Code string `json:"Code"` - Severity string `json:"Severity"` - Message string `json:"Message"` + Code string + Severity string + Message string } // Translate it will translate the error to native gorm errors. // Since currently gorm supporting both pgx and pg drivers, only checking for pgx PgError types is not enough for translating errors, so we have additional error json marshal fallback. func (dialector Dialector) Translate(err error) error { if pgErr, ok := err.(*pgconn.PgError); ok { - if pgErr.Code == errCodes["uniqueConstraint"] { - return gorm.ErrDuplicatedKey + if translatedErr, found := errCodes[pgErr.Code]; found { + return translatedErr } return err } @@ -37,8 +41,8 @@ func (dialector Dialector) Translate(err error) error { return err } - if errMsg.Code == errCodes["uniqueConstraint"] { - return gorm.ErrDuplicatedKey + if translatedErr, found := errCodes[errMsg.Code]; found { + return translatedErr } return err } diff --git a/vendor/gorm.io/driver/postgres/migrator.go b/vendor/gorm.io/driver/postgres/migrator.go index e4d8e9260b..c085a70ee0 100644 --- a/vendor/gorm.io/driver/postgres/migrator.go +++ b/vendor/gorm.io/driver/postgres/migrator.go @@ -35,14 +35,16 @@ where ` var typeAliasMap = map[string][]string{ - "int2": {"smallint"}, - "int4": {"integer"}, - "int8": {"bigint"}, - "smallint": {"int2"}, - "integer": {"int4"}, - "bigint": {"int8"}, - "decimal": {"numeric"}, - "numeric": {"decimal"}, + "int2": {"smallint"}, + "int4": {"integer"}, + "int8": {"bigint"}, + "smallint": {"int2"}, + "integer": {"int4"}, + "bigint": {"int8"}, + "decimal": {"numeric"}, + "numeric": {"decimal"}, + "timestamptz": {"timestamp with time zone"}, + "timestamp with time zone": {"timestamptz"}, } type Migrator struct { @@ -160,7 +162,8 @@ func (m Migrator) CreateTable(values ...interface{}) (err error) { for _, value := range m.ReorderModels(values, false) { if err = m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema != nil { - for _, field := range stmt.Schema.FieldsByDBName { + for _, fieldName := range stmt.Schema.DBNames { + field := stmt.Schema.FieldsByDBName[fieldName] if field.Comment != "" { if err := m.DB.Exec( "COMMENT ON COLUMN ?.? IS ?", @@ -326,8 +329,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { return err } } else { - if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?"+m.genUsingExpression(fileType.SQL, fieldColumnType.DatabaseTypeName()), - m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, clause.Column{Name: field.DBName}, fileType).Error; err != nil { + if err := m.modifyColumn(stmt, field, fileType, fieldColumnType); err != nil { return err } } @@ -387,14 +389,27 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { return nil } -func (m Migrator) genUsingExpression(targetType, sourceType string) string { - if targetType == "boolean" { - switch sourceType { +func (m Migrator) modifyColumn(stmt *gorm.Statement, field *schema.Field, targetType clause.Expr, existingColumn *migrator.ColumnType) error { + alterSQL := "ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::?" + isUncastableDefaultValue := false + + if targetType.SQL == "boolean" { + switch existingColumn.DatabaseTypeName() { case "int2", "int8", "numeric": - return " USING ?::INT::?" + alterSQL = "ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::int::?" + } + isUncastableDefaultValue = true + } + + if dv, _ := existingColumn.DefaultValue(); dv != "" && isUncastableDefaultValue { + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { + return err } } - return " USING ?::?" + if err := m.DB.Exec(alterSQL, m.CurrentTable(stmt), clause.Column{Name: field.DBName}, targetType, clause.Column{Name: field.DBName}, targetType).Error; err != nil { + return err + } + return nil } func (m Migrator) HasConstraint(value interface{}, name string) bool { @@ -463,7 +478,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, } if column.DefaultValueValue.Valid { - column.DefaultValueValue.String = regexp.MustCompile(`'?(.*)\b'?:+[\w\s]+$`).ReplaceAllString(column.DefaultValueValue.String, "$1") + column.DefaultValueValue.String = parseDefaultValueValue(column.DefaultValueValue.String) } if datetimePrecision.Valid { @@ -497,7 +512,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, // check primary, unique field { - columnTypeRows, err := m.DB.Raw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows() + columnTypeRows, err := m.DB.Raw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows() if err != nil { return err } @@ -509,7 +524,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, } columnTypeRows.Close() - columnTypeRows, err = m.DB.Raw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows() + columnTypeRows, err = m.DB.Raw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows() if err != nil { return err } @@ -769,3 +784,8 @@ func (m Migrator) RenameColumn(dst interface{}, oldName, field string) error { m.resetPreparedStmts() return nil } + +func parseDefaultValueValue(defaultValue string) string { + value := regexp.MustCompile(`^(.*?)(?:::.*)?$`).ReplaceAllString(defaultValue, "$1") + return strings.Trim(value, "'") +} diff --git a/vendor/gorm.io/driver/postgres/postgres.go b/vendor/gorm.io/driver/postgres/postgres.go index dbeabf561e..eb93b40e91 100644 --- a/vendor/gorm.io/driver/postgres/postgres.go +++ b/vendor/gorm.io/driver/postgres/postgres.go @@ -3,11 +3,11 @@ package postgres import ( "database/sql" "fmt" - "github.com/jackc/pgx/v5" "regexp" "strconv" "strings" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/stdlib" "gorm.io/gorm" "gorm.io/gorm/callbacks" @@ -24,6 +24,7 @@ type Dialector struct { type Config struct { DriverName string DSN string + WithoutQuotingCheck bool PreferSimpleProtocol bool WithoutReturning bool Conn gorm.ConnPool @@ -46,7 +47,7 @@ var timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )") func (dialector Dialector) Initialize(db *gorm.DB) (err error) { callbackConfig := &callbacks.Config{ CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT"}, - UpdateClauses: []string{"UPDATE", "SET", "WHERE"}, + UpdateClauses: []string{"UPDATE", "SET", "FROM", "WHERE"}, DeleteClauses: []string{"DELETE", "FROM", "WHERE"}, } // register callbacks @@ -98,6 +99,11 @@ func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, } func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { + if dialector.WithoutQuotingCheck { + writer.WriteString(str) + return + } + var ( underQuoted, selfQuoted bool continuousBacktick int8 diff --git a/vendor/modules.txt b/vendor/modules.txt index dd89ea7343..805e99f66a 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -317,16 +317,16 @@ github.com/jackc/pgpassfile # github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a ## explicit; go 1.14 github.com/jackc/pgservicefile -# github.com/jackc/pgx/v5 v5.3.1 +# github.com/jackc/pgx/v5 v5.4.3 ## explicit; go 1.19 github.com/jackc/pgx/v5 github.com/jackc/pgx/v5/internal/anynil github.com/jackc/pgx/v5/internal/iobufpool -github.com/jackc/pgx/v5/internal/nbconn github.com/jackc/pgx/v5/internal/pgio github.com/jackc/pgx/v5/internal/sanitize github.com/jackc/pgx/v5/internal/stmtcache github.com/jackc/pgx/v5/pgconn +github.com/jackc/pgx/v5/pgconn/internal/bgreader github.com/jackc/pgx/v5/pgconn/internal/ctxwatch github.com/jackc/pgx/v5/pgproto3 github.com/jackc/pgx/v5/pgtype @@ -986,7 +986,7 @@ gopkg.in/yaml.v2 # gopkg.in/yaml.v3 v3.0.1 ## explicit gopkg.in/yaml.v3 -# gorm.io/driver/postgres v1.5.2 +# gorm.io/driver/postgres v1.5.4 ## explicit; go 1.18 gorm.io/driver/postgres # gorm.io/gorm v1.25.5