diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 7a8427ccee..e00363a548 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -14,6 +14,7 @@ import ( "fmt" "io" "net" + "os" "strings" "sync" "sync/atomic" @@ -55,7 +56,7 @@ type connection struct { nc net.Conn // When nil, the connection is closed. addr address.Address idleTimeout time.Duration - idleDeadline atomic.Value // Stores a time.Time + idleStart atomic.Value // Stores a time.Time readTimeout time.Duration writeTimeout time.Duration desc description.Server @@ -561,25 +562,65 @@ func (c *connection) close() error { return err } +// closed returns true if the connection has been closed by the driver. func (c *connection) closed() bool { return atomic.LoadInt64(&c.state) == connDisconnected } +// isAlive returns true if the connection is alive and ready to be used for an +// operation. +// +// Note that the liveness check can be slow (at least 1ms), so isAlive only +// checks the liveness of the connection if it's been idle for at least 10 +// seconds. For frequently in-use connections, a network error during an +// operation will be the first indication of a dead connection. +func (c *connection) isAlive() bool { + if c.nc == nil { + return false + } + + // If the connection has been idle for less than 10 seconds, skip the + // liveness check. + // + // The 10-seconds idle bypass is based on the liveness check implementation + // in the Python Driver. That implementation uses 1 second as the idle + // threshold, but we chose to be more conservative in the Go Driver because + // this is new behavior with unknown side-effects. See + // https://github.com/mongodb/mongo-python-driver/blob/e6b95f65953e01e435004af069a6976473eaf841/pymongo/synchronous/pool.py#L983-L985 + idleStart, ok := c.idleStart.Load().(time.Time) + if !ok || idleStart.Add(10*time.Second).After(time.Now()) { + return true + } + + // Set a 1ms read deadline and attempt to read 1 byte from the connection. + // Expect it to block for 1ms then return a deadline exceeded error. If it + // returns any other error, the connection is not usable, so return false. + // If it doesn't return an error and actually reads data, the connection is + // also not usable, so return false. + // + // Note that we don't need to un-set the read deadline because the "read" + // and "write" methods always reset the deadlines. + err := c.nc.SetReadDeadline(time.Now().Add(1 * time.Millisecond)) + if err != nil { + return false + } + var b [1]byte + _, err = c.nc.Read(b[:]) + return errors.Is(err, os.ErrDeadlineExceeded) +} + func (c *connection) idleTimeoutExpired() bool { - now := time.Now() - if c.idleTimeout > 0 { - idleDeadline, ok := c.idleDeadline.Load().(time.Time) - if ok && now.After(idleDeadline) { - return true - } + if c.idleTimeout == 0 { + return false } - return false + idleStart, ok := c.idleStart.Load().(time.Time) + return ok && idleStart.Add(c.idleTimeout).Before(time.Now()) } -func (c *connection) bumpIdleDeadline() { +func (c *connection) bumpIdleStart() { if c.idleTimeout > 0 { - c.idleDeadline.Store(time.Now().Add(c.idleTimeout)) + c.idleStart.Store(time.Now()) } } diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go index ff0f3d0498..07dd9ff0ec 100644 --- a/x/mongo/driver/topology/connection_test.go +++ b/x/mongo/driver/topology/connection_test.go @@ -19,6 +19,7 @@ import ( "github.com/google/go-cmp/cmp" "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/mongo/driver" @@ -427,7 +428,7 @@ func TestConnection(t *testing.T) { want := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A} err := conn.writeWireMessage(context.Background(), want) - noerr(t, err) + require.NoError(t, err) got := tnc.buf if !cmp.Equal(got, want) { t.Errorf("writeWireMessage did not write the proper bytes. got %v; want %v", got, want) @@ -624,7 +625,7 @@ func TestConnection(t *testing.T) { conn.cancellationListener = listener got, err := conn.readWireMessage(context.Background()) - noerr(t, err) + require.NoError(t, err) if !cmp.Equal(got, want) { t.Errorf("did not read full wire message. got %v; want %v", got, want) } @@ -1251,3 +1252,85 @@ func (tcl *testCancellationListener) assertCalledOnce(t *testing.T) { assert.Equal(t, 1, tcl.numListen, "expected Listen to be called once, got %d", tcl.numListen) assert.Equal(t, 1, tcl.numStopListening, "expected StopListening to be called once, got %d", tcl.numListen) } + +func TestConnection_IsAlive(t *testing.T) { + t.Parallel() + + t.Run("uninitialized", func(t *testing.T) { + t.Parallel() + + conn := newConnection("") + assert.False(t, + conn.isAlive(), + "expected isAlive for an uninitialized connection to always return false") + }) + + t.Run("connection open", func(t *testing.T) { + t.Parallel() + + cleanup := make(chan struct{}) + defer close(cleanup) + addr := bootstrapConnections(t, 1, func(nc net.Conn) { + // Keep the connection open until the end of the test. + <-cleanup + _ = nc.Close() + }) + + conn := newConnection(address.Address(addr.String())) + err := conn.connect(context.Background()) + require.NoError(t, err) + + conn.idleStart.Store(time.Now().Add(-11 * time.Second)) + assert.True(t, + conn.isAlive(), + "expected isAlive for an open connection to return true") + }) + + t.Run("connection closed", func(t *testing.T) { + t.Parallel() + + conns := make(chan net.Conn) + addr := bootstrapConnections(t, 1, func(nc net.Conn) { + conns <- nc + }) + + conn := newConnection(address.Address(addr.String())) + err := conn.connect(context.Background()) + require.NoError(t, err) + + // Close the connection before calling isAlive. + nc := <-conns + err = nc.Close() + require.NoError(t, err) + + conn.idleStart.Store(time.Now().Add(-11 * time.Second)) + assert.False(t, + conn.isAlive(), + "expected isAlive for a closed connection to return false") + }) + + t.Run("connection reads data", func(t *testing.T) { + t.Parallel() + + cleanup := make(chan struct{}) + defer close(cleanup) + addr := bootstrapConnections(t, 1, func(nc net.Conn) { + // Write some data to the connection before calling isAlive. + _, err := nc.Write([]byte{5, 0, 0, 0, 0}) + require.NoError(t, err) + + // Keep the connection open until the end of the test. + <-cleanup + _ = nc.Close() + }) + + conn := newConnection(address.Address(addr.String())) + err := conn.connect(context.Background()) + require.NoError(t, err) + + conn.idleStart.Store(time.Now().Add(-11 * time.Second)) + assert.False(t, + conn.isAlive(), + "expected isAlive for an open connection that reads data to return false") + }) +} diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go index ddb69ada76..e9565425d9 100644 --- a/x/mongo/driver/topology/pool.go +++ b/x/mongo/driver/topology/pool.go @@ -167,8 +167,11 @@ type reason struct { // connectionPerished checks if a given connection is perished and should be removed from the pool. func connectionPerished(conn *connection) (reason, bool) { switch { - case conn.closed(): - // A connection would only be closed if it encountered a network error during an operation and closed itself. + case conn.closed() || !conn.isAlive(): + // A connection would only be closed if it encountered a network error + // during an operation and closed itself. If a connection is not alive + // (e.g. the connection was closed by the server-side), it's also + // considered a network error. return reason{ loggerConn: logger.ReasonConnClosedError, event: event.ReasonError, @@ -898,13 +901,15 @@ func (p *pool) checkInNoEvent(conn *connection) error { return nil } - // Bump the connection idle deadline here because we're about to make the connection "available". - // The idle deadline is used to determine when a connection has reached its max idle time and - // should be closed. A connection reaches its max idle time when it has been "available" in the - // idle connections stack for more than the configured duration (maxIdleTimeMS). Set it before - // we call connectionPerished(), which checks the idle deadline, because a newly "available" - // connection should never be perished due to max idle time. - conn.bumpIdleDeadline() + // Bump the connection idle start time here because we're about to make the + // connection "available". The idle start time is used to determine how long + // a connection has been idle and when it has reached its max idle time and + // should be closed. A connection reaches its max idle time when it has been + // "available" in the idle connections stack for more than the configured + // duration (maxIdleTimeMS). Set it before we call connectionPerished(), + // which checks the idle deadline, because a newly "available" connection + // should never be perished due to max idle time. + conn.bumpIdleStart() r, perished := connectionPerished(conn) if !perished && conn.pool.getState() == poolClosed { diff --git a/x/mongo/driver/topology/pool_test.go b/x/mongo/driver/topology/pool_test.go index e0265ae4c6..0f8a5a0570 100644 --- a/x/mongo/driver/topology/pool_test.go +++ b/x/mongo/driver/topology/pool_test.go @@ -70,14 +70,14 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }) err := p1.ready() - noerr(t, err) + require.NoError(t, err) c, err := p1.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) p2 := newPool(poolConfig{}) err = p2.ready() - noerr(t, err) + require.NoError(t, err) err = p2.closeConnection(c) assert.Equalf(t, ErrWrongPool, err, "expected ErrWrongPool error") @@ -94,7 +94,7 @@ func TestPool(t *testing.T) { p := newPool(poolConfig{}) err := p.ready() - noerr(t, err) + require.NoError(t, err) for i := 0; i < 5; i++ { p.close(context.Background()) @@ -115,16 +115,16 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) conns := make([]*connection, 3) for i := range conns { conns[i], err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) } for i := range conns { err = p.checkIn(conns[i]) - noerr(t, err) + require.NoError(t, err) } assert.Equalf(t, 3, d.lenopened(), "should have opened 3 connections") assert.Equalf(t, 0, d.lenclosed(), "should have closed 0 connections") @@ -151,16 +151,16 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) conns := make([]*connection, 3) for i := range conns { conns[i], err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) } for i := 0; i < 2; i++ { err = p.checkIn(conns[i]) - noerr(t, err) + require.NoError(t, err) } assert.Equalf(t, 3, d.lenopened(), "should have opened 3 connections") assert.Equalf(t, 0, d.lenclosed(), "should have closed 0 connections") @@ -186,10 +186,10 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }) err := p.ready() - noerr(t, err) + require.NoError(t, err) _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) closed := make(chan struct{}) started := make(chan struct{}) @@ -212,7 +212,7 @@ func TestPool(t *testing.T) { // connection pool. <-started _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) p.close(context.Background()) @@ -232,13 +232,13 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }) err := p.ready() - noerr(t, err) + require.NoError(t, err) // Check out 2 connections from the pool and add them to a conns slice. conns := make([]*connection, 2) for i := 0; i < 2; i++ { c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) conns[i] = c } @@ -246,10 +246,10 @@ func TestPool(t *testing.T) { // Check out a 3rd connection from the pool and immediately check it back in so there is // a mixture of in-use and idle connections. c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) // Start a goroutine that waits for the pool to start closing, then checks in the // 2 in-use connections. Assert that both connections are still connected during @@ -262,7 +262,7 @@ func TestPool(t *testing.T) { assert.Equalf(t, connConnected, c.state, "expected conn to still be connected") err := p.checkIn(c) - noerr(t, err) + require.NoError(t, err) } }() @@ -287,16 +287,16 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }) err := p.ready() - noerr(t, err) + require.NoError(t, err) c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) p.close(context.Background()) c1 := &Connection{connection: c} err = c1.Close() - noerr(t, err) + require.NoError(t, err) }) }) t.Run("ready", func(t *testing.T) { @@ -316,12 +316,12 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }) err := p.ready() - noerr(t, err) + require.NoError(t, err) conns := make([]*connection, 3) for i := range conns { conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) conns[i] = conn } assert.Equalf(t, 0, p.availableConnectionCount(), "should have 0 available connections") @@ -330,17 +330,17 @@ func TestPool(t *testing.T) { p.clear(nil, nil) for _, conn := range conns { err = p.checkIn(conn) - noerr(t, err) + require.NoError(t, err) } assert.Equalf(t, 0, p.availableConnectionCount(), "should have 0 available connections") assert.Equalf(t, 0, p.totalConnectionCount(), "should have 0 total connections") err = p.ready() - noerr(t, err) + require.NoError(t, err) for i := 0; i < 3; i++ { _, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) } assert.Equalf(t, 0, p.availableConnectionCount(), "should have 0 available connections") assert.Equalf(t, 3, p.totalConnectionCount(), "should have 3 total connections") @@ -353,7 +353,7 @@ func TestPool(t *testing.T) { p := newPool(poolConfig{}) for i := 0; i < 5; i++ { err := p.ready() - noerr(t, err) + require.NoError(t, err) } p.close(context.Background()) @@ -372,27 +372,27 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }) err := p.ready() - noerr(t, err) + require.NoError(t, err) c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) for i := 0; i < 100; i++ { err = p.ready() - noerr(t, err) + require.NoError(t, err) p.clear(nil, nil) } err = p.ready() - noerr(t, err) + require.NoError(t, err) c, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) p.close(context.Background()) }) @@ -410,12 +410,12 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }) err := p.ready() - noerr(t, err) + require.NoError(t, err) c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) var wg sync.WaitGroup for i := 0; i < 10; i++ { @@ -424,7 +424,7 @@ func TestPool(t *testing.T) { defer wg.Done() for i := 0; i < 1000; i++ { err := p.ready() - noerr(t, err) + require.NoError(t, err) } }() @@ -439,12 +439,12 @@ func TestPool(t *testing.T) { wg.Wait() err = p.ready() - noerr(t, err) + require.NoError(t, err) c, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) p.close(context.Background()) }) @@ -462,7 +462,7 @@ func TestPool(t *testing.T) { }) })) err := p.ready() - noerr(t, err) + require.NoError(t, err) _, err = p.checkOut(context.Background()) var want error = ConnectionError{Wrapped: dialErr, init: true} @@ -499,25 +499,25 @@ func TestPool(t *testing.T) { WithDialer(func(Dialer) Dialer { return d }), ) err := p.ready() - noerr(t, err) + require.NoError(t, err) // Check out a connection and assert that the idle timeout is properly set then check it // back into the pool. c1, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) assert.Equalf(t, 1, d.lenopened(), "should have opened 1 connection") assert.Equalf(t, 1, p.totalConnectionCount(), "pool should have 1 total connection") assert.Equalf(t, time.Millisecond, c1.idleTimeout, "connection should have a 1ms idle timeout") err = p.checkIn(c1) - noerr(t, err) + require.NoError(t, err) // Sleep for more than the 1ms idle timeout and then try to check out a connection. // Expect that the previously checked-out connection is closed because it's idle and a // new connection is created. time.Sleep(50 * time.Millisecond) c2, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) // Assert that the connection pointers are not equal. Don't use "assert.NotEqual" because it asserts // non-equality of fields, possibly accessing some fields non-atomically and causing a race condition. assert.True(t, c1 != c2, "expected a new connection on 2nd check out after idle timeout expires") @@ -541,14 +541,14 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) for i := 0; i < 100; i++ { c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) } assert.Equalf(t, 1, d.lenopened(), "should have opened 1 connection") @@ -568,7 +568,7 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }) err := p.ready() - noerr(t, err) + require.NoError(t, err) p.close(context.Background()) @@ -594,7 +594,7 @@ func TestPool(t *testing.T) { }), ) err := p.ready() - noerr(t, err) + require.NoError(t, err) _, err = p.checkOut(context.Background()) assert.IsTypef(t, ConnectionError{}, err, "expected a ConnectionError") @@ -636,11 +636,11 @@ func TestPool(t *testing.T) { MaxPoolSize: 1, }) err := p.ready() - noerr(t, err) + require.NoError(t, err) // check out first connection. _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) // Set a short timeout and check out again. ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -676,11 +676,11 @@ func TestPool(t *testing.T) { MaxPoolSize: 1, }) err := p.ready() - noerr(t, err) + require.NoError(t, err) // Check out the 1 connection that the pool will create. c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) // Start a goroutine that tries to check out another connection with no timeout. Expect // this goroutine to block (wait in the wait queue) until the checked-out connection is @@ -691,7 +691,7 @@ func TestPool(t *testing.T) { defer wg.Done() _, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) }() // Run lots of check-out attempts with a low timeout and assert that each one fails with @@ -707,7 +707,7 @@ func TestPool(t *testing.T) { // Check-in the connection we checked out earlier and wait for the checkOut() goroutine // to resume. err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) wg.Wait() p.close(context.Background()) @@ -733,14 +733,14 @@ func TestPool(t *testing.T) { WithDialer(func(Dialer) Dialer { return d }), ) err := p.ready() - noerr(t, err) + require.NoError(t, err) // Check out two connections (MaxPoolSize) so that subsequent checkOut() calls should // block until a connection is checked back in or removed from the pool. c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) assert.Equalf(t, 2, d.lenopened(), "should have opened 2 connection") assert.Equalf(t, 2, p.totalConnectionCount(), "pool should have 2 total connection") assert.Equalf(t, 0, p.availableConnectionCount(), "pool should have 0 idle connection") @@ -765,10 +765,10 @@ func TestPool(t *testing.T) { c.close() start = time.Now() err := p.checkIn(c) - noerr(t, err) + require.NoError(t, err) }() _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) assert.WithinDurationf( t, time.Now(), @@ -798,11 +798,11 @@ func TestPool(t *testing.T) { MaxPoolSize: 1, }) err := p.ready() - noerr(t, err) + require.NoError(t, err) // Check out first connection. _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) // Use a canceled context to check out another connection. cancelCtx, cancel := context.WithCancel(context.Background()) @@ -817,6 +817,79 @@ func TestPool(t *testing.T) { assert.Containsf(t, err.Error(), "canceled", `expected error message to contain "canceled"`) } + p.close(context.Background()) + }) + t.Run("discards connections closed by the server side", func(t *testing.T) { + t.Parallel() + + cleanup := make(chan struct{}) + defer close(cleanup) + + ncs := make(chan net.Conn, 2) + addr := bootstrapConnections(t, 2, func(nc net.Conn) { + // Send all "server-side" connections to a channel so we can + // interact with them during the test. + ncs <- nc + + <-cleanup + _ = nc.Close() + }) + + d := newdialer(&net.Dialer{}) + p := newPool(poolConfig{ + Address: address.Address(addr.String()), + }, WithDialer(func(Dialer) Dialer { return d })) + err := p.ready() + require.NoError(t, err) + + // Add 1 idle connection to the pool by checking-out and checking-in + // a connection. + conn, err := p.checkOut(context.Background()) + require.NoError(t, err) + err = p.checkIn(conn) + require.NoError(t, err) + assertConnectionsOpened(t, d, 1) + assert.Equalf(t, 1, p.availableConnectionCount(), "should be 1 idle connections in pool") + assert.Equalf(t, 1, p.totalConnectionCount(), "should be 1 total connection in pool") + + // Make that connection appear as if it's been idle for a minute. + conn.idleStart.Store(time.Now().Add(-1 * time.Minute)) + + // Close the "server-side" of the connection we just created. The idle + // connection in the pool is now unusable because the "server-side" + // closed it. + nc := <-ncs + err = nc.Close() + require.NoError(t, err) + + // In a separate goroutine, write a valid wire message to the 2nd + // connection that's about to be created. Stop waiting for a 2nd + // connection after 100ms to prevent leaking a goroutine. + go func() { + select { + case nc := <-ncs: + _, err = nc.Write([]byte{5, 0, 0, 0, 0}) + require.NoError(t, err, "Write error") + case <-time.After(100 * time.Millisecond): + } + }() + + // Check out a connection and try to read from it. Expect the pool to + // discard the connection that was closed by the "server-side" and + // return a newly created connection instead. + conn, err = p.checkOut(context.Background()) + require.NoError(t, err) + msg, err := conn.readWireMessage(context.Background()) + require.NoError(t, err) + assert.Equal(t, []byte{5, 0, 0, 0, 0}, msg) + + err = p.checkIn(conn) + require.NoError(t, err) + + assertConnectionsOpened(t, d, 2) + assert.Equalf(t, 1, p.availableConnectionCount(), "should be 1 idle connections in pool") + assert.Equalf(t, 1, p.totalConnectionCount(), "should be 1 total connection in pool") + p.close(context.Background()) }) }) @@ -837,15 +910,15 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }) err := p.ready() - noerr(t, err) + require.NoError(t, err) c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) assert.Equalf(t, 0, p.availableConnectionCount(), "should be no idle connections in pool") assert.Equalf(t, 1, p.totalConnectionCount(), "should be 1 total connection in pool") err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) assert.NotNilf(t, err, "expected an error trying to return the same conn to the pool twice") @@ -870,10 +943,10 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) assert.Equalf(t, 0, d.lenclosed(), "should have closed 0 connections") assert.Equalf(t, 0, p.availableConnectionCount(), "should have 0 idle connections in pool") assert.Equalf(t, 1, p.totalConnectionCount(), "should have 1 total connection in pool") @@ -881,7 +954,7 @@ func TestPool(t *testing.T) { p.close(context.Background()) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) assert.Equalf(t, 1, d.lenclosed(), "should have closed 1 connection") assert.Equalf(t, 0, p.availableConnectionCount(), "should have 0 idle connections in pool") assert.Equalf(t, 0, p.totalConnectionCount(), "should have 0 total connection in pool") @@ -900,14 +973,14 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }) err := p1.ready() - noerr(t, err) + require.NoError(t, err) c, err := p1.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) p2 := newPool(poolConfig{}) err = p2.ready() - noerr(t, err) + require.NoError(t, err) err = p2.checkIn(c) assert.Equalf(t, ErrWrongPool, err, "expected ErrWrongPool error") @@ -931,18 +1004,18 @@ func TestPool(t *testing.T) { MaxIdleTime: 100 * time.Millisecond, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) defer p.close(context.Background()) c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) // Sleep for 110ms, which will exceed the 100ms connection idle timeout. Then check the // connection back in and expect that it is not closed because checkIn() should bump the // connection idle deadline. time.Sleep(110 * time.Millisecond) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) assert.Equalf(t, 0, d.lenclosed(), "should have closed 0 connections") assert.Equalf(t, 1, p.availableConnectionCount(), "should have 1 idle connections in pool") @@ -965,7 +1038,7 @@ func TestPool(t *testing.T) { MaxIdleTime: 10 * time.Millisecond, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) defer p.close(context.Background()) // Wait for maintain() to open 3 connections. @@ -977,7 +1050,7 @@ func TestPool(t *testing.T) { // and tries to create a new connection. time.Sleep(100 * time.Millisecond) _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) assertConnectionsClosed(t, d, 3) assert.Equalf(t, 4, d.lenopened(), "should have opened 4 connections") @@ -1004,7 +1077,7 @@ func TestPool(t *testing.T) { MinPoolSize: 3, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) assertConnectionsOpened(t, d, 3) assert.Equalf(t, 3, p.availableConnectionCount(), "should be 3 idle connections in pool") @@ -1029,7 +1102,7 @@ func TestPool(t *testing.T) { MaxPoolSize: 2, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) assertConnectionsOpened(t, d, 2) assert.Equalf(t, 2, p.availableConnectionCount(), "should be 2 idle connections in pool") @@ -1054,18 +1127,18 @@ func TestPool(t *testing.T) { MaintainInterval: 10 * time.Millisecond, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) // Check out and check in 3 connections. Assert that there are 3 total and 3 idle // connections in the pool. conns := make([]*connection, 3) for i := range conns { conns[i], err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) } for _, c := range conns { err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) } assert.Equalf(t, 3, d.lenopened(), "should have opened 3 connections") assert.Equalf(t, 3, p.availableConnectionCount(), "should be 3 idle connections in pool") @@ -1077,7 +1150,7 @@ func TestPool(t *testing.T) { p.idleMu.Lock() for i := 0; i < 2; i++ { p.idleConns[i].idleTimeout = time.Millisecond - p.idleConns[i].idleDeadline.Store(time.Now().Add(-1 * time.Hour)) + p.idleConns[i].idleStart.Store(time.Now().Add(-1 * time.Hour)) } p.idleMu.Unlock() assertConnectionsClosed(t, d, 2) @@ -1104,7 +1177,7 @@ func TestPool(t *testing.T) { MaintainInterval: 10 * time.Millisecond, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) assertConnectionsOpened(t, d, 3) assert.Equalf(t, 3, p.availableConnectionCount(), "should be 3 idle connections in pool") assert.Equalf(t, 3, p.totalConnectionCount(), "should be 3 total connection in pool") @@ -1112,7 +1185,7 @@ func TestPool(t *testing.T) { p.idleMu.Lock() for i := 0; i < 2; i++ { p.idleConns[i].idleTimeout = time.Millisecond - p.idleConns[i].idleDeadline.Store(time.Now().Add(-1 * time.Hour)) + p.idleConns[i].idleStart.Store(time.Now().Add(-1 * time.Hour)) } p.idleMu.Unlock() assertConnectionsClosed(t, d, 2) @@ -1154,7 +1227,7 @@ func TestBackgroundRead(t *testing.T) { }() _, err := nc.Write([]byte{10, 0, 0}) - noerr(t, err) + require.NoError(t, err) }) p := newPool( @@ -1162,10 +1235,10 @@ func TestBackgroundRead(t *testing.T) { ) defer p.close(context.Background()) err := p.ready() - noerr(t, err) + require.NoError(t, err) conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) defer cancel() _, err = conn.readWireMessage(ctx) @@ -1194,7 +1267,7 @@ func TestBackgroundRead(t *testing.T) { // Wait until the operation times out, then write an full message. time.Sleep(timeout * 2) _, err := nc.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0, 0, 0}) - noerr(t, err) + require.NoError(t, err) }) p := newPool( @@ -1202,10 +1275,10 @@ func TestBackgroundRead(t *testing.T) { ) defer p.close(context.Background()) err := p.ready() - noerr(t, err) + require.NoError(t, err) conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) defer cancel() _, err = conn.readWireMessage(ctx) @@ -1214,7 +1287,7 @@ func TestBackgroundRead(t *testing.T) { ) assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) - noerr(t, err) + require.NoError(t, err) var bgErrs []error select { case bgErrs = <-errsCh: @@ -1241,7 +1314,7 @@ func TestBackgroundRead(t *testing.T) { // Wait until the operation times out, then write an incomplete head. time.Sleep(timeout * 2) _, err := nc.Write([]byte{10, 0, 0}) - noerr(t, err) + require.NoError(t, err) }) p := newPool( @@ -1249,10 +1322,10 @@ func TestBackgroundRead(t *testing.T) { ) defer p.close(context.Background()) err := p.ready() - noerr(t, err) + require.NoError(t, err) conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) defer cancel() _, err = conn.readWireMessage(ctx) @@ -1261,7 +1334,7 @@ func TestBackgroundRead(t *testing.T) { ) assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) - noerr(t, err) + require.NoError(t, err) var bgErrs []error select { case bgErrs = <-errsCh: @@ -1293,7 +1366,7 @@ func TestBackgroundRead(t *testing.T) { // message. time.Sleep(timeout * 2) _, err := nc.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0}) - noerr(t, err) + require.NoError(t, err) }) p := newPool( @@ -1301,10 +1374,10 @@ func TestBackgroundRead(t *testing.T) { ) defer p.close(context.Background()) err := p.ready() - noerr(t, err) + require.NoError(t, err) conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) defer cancel() _, err = conn.readWireMessage(ctx) @@ -1313,7 +1386,7 @@ func TestBackgroundRead(t *testing.T) { ) assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) - noerr(t, err) + require.NoError(t, err) var bgErrs []error select { case bgErrs = <-errsCh: @@ -1343,11 +1416,11 @@ func TestBackgroundRead(t *testing.T) { var err error _, err = nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1}) - noerr(t, err) + require.NoError(t, err) time.Sleep(timeout * 2) // write a complete message _, err = nc.Write([]byte{2, 3, 4}) - noerr(t, err) + require.NoError(t, err) }) p := newPool( @@ -1355,10 +1428,10 @@ func TestBackgroundRead(t *testing.T) { ) defer p.close(context.Background()) err := p.ready() - noerr(t, err) + require.NoError(t, err) conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) defer cancel() _, err = conn.readWireMessage(ctx) @@ -1367,7 +1440,7 @@ func TestBackgroundRead(t *testing.T) { ) assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) - noerr(t, err) + require.NoError(t, err) var bgErrs []error select { case bgErrs = <-errsCh: @@ -1393,11 +1466,11 @@ func TestBackgroundRead(t *testing.T) { var err error _, err = nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1}) - noerr(t, err) + require.NoError(t, err) time.Sleep(timeout * 2) // write an incomplete message _, err = nc.Write([]byte{2}) - noerr(t, err) + require.NoError(t, err) }) p := newPool( @@ -1405,10 +1478,10 @@ func TestBackgroundRead(t *testing.T) { ) defer p.close(context.Background()) err := p.ready() - noerr(t, err) + require.NoError(t, err) conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) defer cancel() _, err = conn.readWireMessage(ctx) @@ -1417,7 +1490,7 @@ func TestBackgroundRead(t *testing.T) { ) assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) - noerr(t, err) + require.NoError(t, err) var bgErrs []error select { case bgErrs = <-errsCh: diff --git a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go index a418e690a5..1c10d6188a 100644 --- a/x/mongo/driver/topology/server_test.go +++ b/x/mongo/driver/topology/server_test.go @@ -608,11 +608,11 @@ func TestServer(t *testing.T) { })) s.state = serverConnected err := s.pool.ready() - noerr(t, err) + require.NoError(t, err) defer s.pool.close(context.Background()) conn, err := s.Connection(context.Background()) - noerr(t, err) + require.NoError(t, err) if d.lenopened() != 1 { t.Errorf("Should have opened 1 connections, but didn't. got %d; want %d", d.lenopened(), 1) } @@ -634,7 +634,7 @@ func TestServer(t *testing.T) { <-ch runtime.Gosched() err = conn.Close() - noerr(t, err) + require.NoError(t, err) wg.Wait() close(cleanup) }) diff --git a/x/mongo/driver/topology/topology_errors_test.go b/x/mongo/driver/topology/topology_errors_test.go index c7dc7336e9..1831a16e72 100644 --- a/x/mongo/driver/topology/topology_errors_test.go +++ b/x/mongo/driver/topology/topology_errors_test.go @@ -17,6 +17,7 @@ import ( "time" "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/description" ) @@ -28,7 +29,7 @@ func TestTopologyErrors(t *testing.T) { t.Run("errors are wrapped", func(t *testing.T) { t.Run("server selection error", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) atomic.StoreInt64(&topo.state, topologyConnected) desc := description.Topology{ diff --git a/x/mongo/driver/topology/topology_test.go b/x/mongo/driver/topology/topology_test.go index fd0703b97b..ad91d95e04 100644 --- a/x/mongo/driver/topology/topology_test.go +++ b/x/mongo/driver/topology/topology_test.go @@ -33,14 +33,6 @@ import ( const testTimeout = 2 * time.Second -func noerr(t *testing.T, err error) { - t.Helper() - if err != nil { - t.Errorf("Unexpected error: %v", err) - t.FailNow() - } -} - func compareErrors(err1, err2 error) bool { if err1 == nil && err2 == nil { return true @@ -74,7 +66,7 @@ func TestServerSelection(t *testing.T) { t.Run("Success", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) desc := description.Topology{ Servers: []description.Server{ {Addr: address.Address("one"), Kind: description.Standalone}, @@ -87,7 +79,7 @@ func TestServerSelection(t *testing.T) { state := newServerSelectionState(selectFirst, nil) srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, state) - noerr(t, err) + require.NoError(t, err) if len(srvs) != 1 { t.Errorf("Incorrect number of descriptions returned. got %d; want %d", len(srvs), 1) } @@ -97,7 +89,7 @@ func TestServerSelection(t *testing.T) { }) t.Run("Compatibility Error Min Version Too High", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) desc := description.Topology{ Kind: description.Single, Servers: []description.Server{ @@ -120,7 +112,7 @@ func TestServerSelection(t *testing.T) { }) t.Run("Compatibility Error Max Version Too Low", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) desc := description.Topology{ Kind: description.Single, Servers: []description.Server{ @@ -143,7 +135,7 @@ func TestServerSelection(t *testing.T) { }) t.Run("Updated", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) desc := description.Topology{Servers: []description.Server{}} subCh := make(chan description.Topology, 1) subCh <- desc @@ -152,7 +144,7 @@ func TestServerSelection(t *testing.T) { go func() { state := newServerSelectionState(selectFirst, nil) srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, state) - noerr(t, err) + require.NoError(t, err) resp <- srvs }() @@ -192,7 +184,7 @@ func TestServerSelection(t *testing.T) { }, } topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) subCh := make(chan description.Topology, 1) subCh <- desc resp := make(chan error) @@ -229,7 +221,7 @@ func TestServerSelection(t *testing.T) { }, } topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) subCh := make(chan description.Topology, 1) subCh <- desc resp := make(chan error) @@ -265,7 +257,7 @@ func TestServerSelection(t *testing.T) { }, } topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) subCh := make(chan description.Topology, 1) subCh <- desc resp := make(chan error) @@ -288,10 +280,10 @@ func TestServerSelection(t *testing.T) { }) t.Run("findServer returns topology kind", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) atomic.StoreInt64(&topo.state, topologyConnected) srvr, err := ConnectServer(address.Address("one"), topo.updateCallback, topo.id) - noerr(t, err) + require.NoError(t, err) topo.servers[address.Address("one")] = srvr desc := topo.desc.Load().(description.Topology) desc.Kind = description.Single @@ -300,14 +292,14 @@ func TestServerSelection(t *testing.T) { selected := description.Server{Addr: address.Address("one")} ss, err := topo.FindServer(selected) - noerr(t, err) + require.NoError(t, err) if ss.Kind != description.Single { t.Errorf("findServer does not properly set the topology description kind. got %v; want %v", ss.Kind, description.Single) } }) t.Run("Update on not primary error", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) atomic.StoreInt64(&topo.state, topologyConnected) addr1 := address.Address("one") @@ -324,7 +316,7 @@ func TestServerSelection(t *testing.T) { // manually add the servers to the topology for _, srv := range desc.Servers { s, err := ConnectServer(srv.Addr, topo.updateCallback, topo.id) - noerr(t, err) + require.NoError(t, err) topo.servers[srv.Addr] = s } @@ -342,7 +334,7 @@ func TestServerSelection(t *testing.T) { // send a not primary error to the server forcing an update serv, err := topo.FindServer(desc.Servers[0]) - noerr(t, err) + require.NoError(t, err) atomic.StoreInt64(&serv.state, serverConnected) _ = serv.ProcessError(driver.Error{Message: driver.LegacyNotPrimaryErrMsg}, initConnection{}) @@ -352,7 +344,7 @@ func TestServerSelection(t *testing.T) { // server selection should discover the new topology state := newServerSelectionState(description.WriteSelector(), nil) srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, state) - noerr(t, err) + require.NoError(t, err) resp <- srvs }() @@ -373,7 +365,7 @@ func TestServerSelection(t *testing.T) { t.Run("fast path does not subscribe or check timeouts", func(t *testing.T) { // Assert that the server selection fast path does not create a Subscription or check for timeout errors. topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) atomic.StoreInt64(&topo.state, topologyConnected) primaryAddr := address.Address("one") @@ -385,7 +377,7 @@ func TestServerSelection(t *testing.T) { topo.desc.Store(desc) for _, srv := range desc.Servers { s, err := ConnectServer(srv.Addr, topo.updateCallback, topo.id) - noerr(t, err) + require.NoError(t, err) topo.servers[srv.Addr] = s } @@ -395,13 +387,13 @@ func TestServerSelection(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() selectedServer, err := topo.SelectServer(ctx, description.WriteSelector()) - noerr(t, err) + require.NoError(t, err) selectedAddr := selectedServer.(*SelectedServer).address assert.Equal(t, primaryAddr, selectedAddr, "expected address %v, got %v", primaryAddr, selectedAddr) }) t.Run("default to selecting from subscription if fast path fails", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) atomic.StoreInt64(&topo.state, topologyConnected) desc := description.Topology{ @@ -420,7 +412,7 @@ func TestSessionTimeout(t *testing.T) { t.Run("UpdateSessionTimeout", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) topo.servers["foo"] = nil topo.fsm.Servers = []description.Server{ { @@ -449,7 +441,7 @@ func TestSessionTimeout(t *testing.T) { }) t.Run("MultipleUpdates", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) topo.fsm.Kind = description.ReplicaSetWithPrimary topo.servers["foo"] = nil topo.servers["bar"] = nil @@ -496,7 +488,7 @@ func TestSessionTimeout(t *testing.T) { }) t.Run("NoUpdate", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) topo.servers["foo"] = nil topo.servers["bar"] = nil topo.fsm.Servers = []description.Server{ @@ -542,7 +534,7 @@ func TestSessionTimeout(t *testing.T) { }) t.Run("TimeoutDataBearing", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) topo.servers["foo"] = nil topo.servers["bar"] = nil topo.fsm.Servers = []description.Server{ @@ -588,7 +580,7 @@ func TestSessionTimeout(t *testing.T) { }) t.Run("MixedSessionSupport", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) topo.fsm.Kind = description.ReplicaSetWithPrimary topo.servers["one"] = nil topo.servers["two"] = nil