Skip to content

Commit

Permalink
Merge pull request #2240 from bonnefoa/fix-watch-panic
Browse files Browse the repository at this point in the history
Unwatch and close connection on a batch write error
  • Loading branch information
jackc authored Jan 25, 2025
2 parents b5efc90 + 228cfff commit 1abf7d9
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 4 deletions.
10 changes: 6 additions & 4 deletions pgconn/pgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1773,19 +1773,21 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR

batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
if batch.err != nil {
pgConn.contextWatcher.Unwatch()
multiResult.err = normalizeTimeoutError(multiResult.ctx, batch.err)
multiResult.closed = true
multiResult.err = batch.err
pgConn.unlock()
pgConn.asyncClose()
return multiResult
}

pgConn.enterPotentialWriteReadDeadlock()
defer pgConn.exitPotentialWriteReadDeadlock()
_, err := pgConn.conn.Write(batch.buf)
if err != nil {
pgConn.contextWatcher.Unwatch()
multiResult.err = normalizeTimeoutError(multiResult.ctx, err)
multiResult.closed = true
multiResult.err = err
pgConn.unlock()
pgConn.asyncClose()
return multiResult
}

Expand Down
46 changes: 46 additions & 0 deletions pgconn/pgconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,52 @@ func TestConnExecBatch(t *testing.T) {
assert.Equal(t, "SELECT 1", results[2].CommandTag.String())
}

type mockConnection struct {
net.Conn
writeLatency *time.Duration
}

func (m mockConnection) Write(b []byte) (n int, err error) {
time.Sleep(*m.writeLatency)
return m.Conn.Write(b)
}

func TestConnExecBatchWriteError(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)

var mockConn mockConnection
writeLatency := 0 * time.Second
config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := net.Dial(network, address)
mockConn = mockConnection{conn, &writeLatency}
return mockConn, err
}

pgConn, err := pgconn.ConnectConfig(ctx, config)
require.NoError(t, err)
defer closeConn(t, pgConn)

batch := &pgconn.Batch{}
pgConn.Conn()

ctx2, cancel2 := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel2()

batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
writeLatency = 2 * time.Second
mrr := pgConn.ExecBatch(ctx2, batch)
err = mrr.Close()
require.Error(t, err)
assert.ErrorIs(t, err, context.DeadlineExceeded)
require.True(t, pgConn.IsClosed())
}

func TestConnExecBatchDeferredError(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit 1abf7d9

Please sign in to comment.