From 228cfffc20bdb9b572771a586ec7ff0f1dd568c1 Mon Sep 17 00:00:00 2001 From: Anthonin Bonnefoy Date: Thu, 23 Jan 2025 11:57:42 +0100 Subject: [PATCH] Unwatch and close connection on a batch write error Previously, a conn.Write error would simply unlock pgconn, leaving the connection as Idle and reusable while the multiResultReader would be closed. From this state, calling multiResultReader.Close won't try to receiveMessage and thus won't unwatch and close the connection, since the reader is already closed. This leaves the connection "open" and the next time it's used, a "Watch already in progress" panic could be triggered. This patch fixes the issue by unwatching and closing the connection on a batch write error. The same is done for a Sync.Encode error, even though that path is currently unreachable because Sync.Encode never returns an error. --- pgconn/pgconn.go | 10 ++++++---- pgconn/pgconn_test.go | 46 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 59b89cf7d..5ff9632c0 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1773,9 +1773,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf) if batch.err != nil { + pgConn.contextWatcher.Unwatch() + multiResult.err = normalizeTimeoutError(multiResult.ctx, batch.err) multiResult.closed = true - multiResult.err = batch.err - pgConn.unlock() + pgConn.asyncClose() return multiResult } @@ -1783,9 +1784,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR defer pgConn.exitPotentialWriteReadDeadlock() _, err := pgConn.conn.Write(batch.buf) if err != nil { + pgConn.contextWatcher.Unwatch() + multiResult.err = normalizeTimeoutError(multiResult.ctx, err) multiResult.closed = true - multiResult.err = err - pgConn.unlock() + pgConn.asyncClose() return multiResult } diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 
2b582e242..b2d2f7f79 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -1420,6 +1420,52 @@ func TestConnExecBatch(t *testing.T) { assert.Equal(t, "SELECT 1", results[2].CommandTag.String()) } +type mockConnection struct { + net.Conn + writeLatency *time.Duration +} + +func (m mockConnection) Write(b []byte) (n int, err error) { + time.Sleep(*m.writeLatency) + return m.Conn.Write(b) +} + +func TestConnExecBatchWriteError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + var mockConn mockConnection + writeLatency := 0 * time.Second + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + conn, err := net.Dial(network, address) + mockConn = mockConnection{conn, &writeLatency} + return mockConn, err + } + + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + batch := &pgconn.Batch{} + pgConn.Conn() + + ctx2, cancel2 := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel2() + + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + writeLatency = 2 * time.Second + mrr := pgConn.ExecBatch(ctx2, batch) + err = mrr.Close() + require.Error(t, err) + assert.ErrorIs(t, err, context.DeadlineExceeded) + require.True(t, pgConn.IsClosed()) +} + func TestConnExecBatchDeferredError(t *testing.T) { t.Parallel()