diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index e960806444..daae9d6cfe 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -20,6 +20,7 @@ import ( "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" "github.com/libp2p/go-libp2p/p2p/muxer/yamux" "github.com/libp2p/go-libp2p/p2p/protocol/ping" "github.com/libp2p/go-libp2p/p2p/security/noise" @@ -357,8 +358,17 @@ func TestMoreStreamsThanOurLimits(t *testing.T) { const streamCount = 1024 for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { - listener := tc.HostGenerator(t, TransportTestCaseOpts{}) - dialer := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + listenerLimits := rcmgr.PartialLimitConfig{ + PeerDefault: rcmgr.ResourceLimits{ + Streams: 32, + StreamsInbound: 16, + StreamsOutbound: 16, + }, + } + r, err := rcmgr.NewResourceManager(rcmgr.NewFixedLimiter(listenerLimits.Build(rcmgr.DefaultLimits.AutoScale()))) + require.NoError(t, err) + listener := tc.HostGenerator(t, TransportTestCaseOpts{ResourceManager: r}) + dialer := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true, NoRcmgr: true}) defer listener.Close() defer dialer.Close() @@ -370,101 +380,132 @@ func TestMoreStreamsThanOurLimits(t *testing.T) { var handledStreams atomic.Int32 var sawFirstErr atomic.Bool - semaphore := make(chan struct{}, streamCount) - // Start with a single stream at a time. If that works, we'll increase the number of concurrent streams. - semaphore <- struct{}{} + workQueue := make(chan struct{}, streamCount) + for i := 0; i < streamCount; i++ { + workQueue <- struct{}{} + } + close(workQueue) listener.SetStreamHandler("echo", func(s network.Stream) { + // Wait a bit so that we have more parallel streams open at the same time + time.Sleep(time.Millisecond * 10) io.Copy(s, s) s.Close() }) wg := sync.WaitGroup{} - wg.Add(streamCount) errCh := make(chan error, 1) var completedStreams atomic.Int32 - for i := 0; i < streamCount; i++ { - go func() { - <-semaphore - var didErr bool - defer wg.Done() - defer completedStreams.Add(1) - defer func() { - select { - case semaphore <- struct{}{}: - default: - } - if !didErr && !sawFirstErr.Load() { - // No error! We can add one more stream to our concurrency limit. - select { - case semaphore <- struct{}{}: - default: - } - } - }() - var s network.Stream - var err error - // maxRetries is an arbitrary retry amount if there's any error. - maxRetries := streamCount * 4 - shouldRetry := func(err error) bool { - didErr = true - sawFirstErr.Store(true) - maxRetries-- - if maxRetries == 0 || len(errCh) > 0 { - select { - case errCh <- errors.New("max retries exceeded"): - default: - } - return false - } - return true + const maxWorkerCount = streamCount + workerCount := 4 + + var startWorker func(workerIdx int) + startWorker = func(workerIdx int) { + wg.Add(1) + defer wg.Done() + for { + _, ok := <-workQueue + if !ok { + return } - for { - s, err = dialer.NewStream(context.Background(), listener.ID(), "echo") - if err != nil { - if shouldRetry(err) { - time.Sleep(50 * time.Millisecond) - continue - } - } - err = func(s network.Stream) error { - defer s.Close() - _, err = s.Write([]byte("hello")) - if err != nil { - return err + // Inline function so we can use defer + func() { + var didErr bool + defer completedStreams.Add(1) + defer func() { + // Only the first worker adds more workers + if workerIdx == 0 && !didErr && !sawFirstErr.Load() { + nextWorkerCount := workerCount * 2 + if nextWorkerCount < maxWorkerCount { + for i := workerCount; i < nextWorkerCount; i++ { + go startWorker(i) + } + workerCount = nextWorkerCount + } } - - err = s.CloseWrite() - if err != nil { - return err + }() + + var s network.Stream + var err error + // maxRetries is an arbitrary retry amount if there's any error. + maxRetries := streamCount * 4 + shouldRetry := func(err error) bool { + didErr = true + sawFirstErr.Store(true) + maxRetries-- + if maxRetries == 0 || len(errCh) > 0 { + select { + case errCh <- errors.New("max retries exceeded"): + default: + } + return false } + return true + } - b, err := io.ReadAll(s) + for { + s, err = dialer.NewStream(context.Background(), listener.ID(), "echo") if err != nil { - return err + if shouldRetry(err) { + time.Sleep(50 * time.Millisecond) + continue + } } - if !bytes.Equal(b, []byte("hello")) { - return errors.New("received data does not match sent data") + err = func(s network.Stream) error { + defer s.Close() + err = s.SetDeadline(time.Now().Add(100 * time.Millisecond)) + if err != nil { + return err + } + + _, err = s.Write([]byte("hello")) + if err != nil { + return err + } + + err = s.CloseWrite() + if err != nil { + return err + } + + b, err := io.ReadAll(s) + if err != nil { + return err + } + if !bytes.Equal(b, []byte("hello")) { + return errors.New("received data does not match sent data") + } + handledStreams.Add(1) + + return nil + }(s) + if err != nil && shouldRetry(err) { + time.Sleep(50 * time.Millisecond) + continue } - handledStreams.Add(1) + return - return nil - }(s) - if err != nil && shouldRetry(err) { - time.Sleep(50 * time.Millisecond) - continue } - return - } - }() + }() + } + } + + // Create any initial parallel workers + for i := 1; i < workerCount; i++ { + go startWorker(i) } + + // Start the first worker + startWorker(0) + wg.Wait() close(errCh) require.NoError(t, <-errCh) require.Equal(t, streamCount, int(handledStreams.Load())) + require.True(t, sawFirstErr.Load(), "Expected to see an error from the peer") }) } }