diff --git a/kv/memberlist/tcp_transport.go b/kv/memberlist/tcp_transport.go index abd7d1c90..c74e50d17 100644 --- a/kv/memberlist/tcp_transport.go +++ b/kv/memberlist/tcp_transport.go @@ -76,13 +76,18 @@ func (cfg *TCPTransportConfig) RegisterFlagsWithPrefix(f *flag.FlagSet, prefix s f.IntVar(&cfg.BindPort, prefix+"memberlist.bind-port", 7946, "Port to listen on for gossip messages.") f.DurationVar(&cfg.PacketDialTimeout, prefix+"memberlist.packet-dial-timeout", 2*time.Second, "Timeout used when connecting to other nodes to send packet.") f.DurationVar(&cfg.PacketWriteTimeout, prefix+"memberlist.packet-write-timeout", 5*time.Second, "Timeout for writing 'packet' data.") - f.IntVar(&cfg.MaxConcurrentWrites, prefix+"memberlist.max-concurrent-writes", 1, "Maximum number of concurrent writes to other nodes.") + f.IntVar(&cfg.MaxConcurrentWrites, prefix+"memberlist.max-concurrent-writes", 3, "Maximum number of concurrent writes to other nodes.") f.BoolVar(&cfg.TransportDebug, prefix+"memberlist.transport-debug", false, "Log debug transport messages. Note: global log.level must be at debug level as well.") f.BoolVar(&cfg.TLSEnabled, prefix+"memberlist.tls-enabled", false, "Enable TLS on the memberlist transport layer.") cfg.TLS.RegisterFlagsWithPrefix(prefix+"memberlist", f) } +type writeRequest struct { /* one WriteTo call queued on TCPTransport.writeCh for a writeWorker goroutine */ + b []byte /* packet payload exactly as passed to WriteTo */ + addr string /* destination address exactly as passed to WriteTo */ +} + // TCPTransport is a memberlist.Transport implementation that uses TCP for both packet and stream // operations ("packet" and "stream" are terms used by memberlist). // It uses a new TCP connections for each operation. There is no connection reuse. 
@@ -92,10 +97,13 @@ type TCPTransport struct { packetCh chan *memberlist.Packet connCh chan net.Conn wg sync.WaitGroup - writeCh chan struct{} tcpListeners []net.Listener tlsConfig *tls.Config + writeMu sync.RWMutex + writeCh chan writeRequest + writeWG sync.WaitGroup + shutdown atomic.Int32 advertiseMu sync.RWMutex @@ -124,12 +132,20 @@ func NewTCPTransport(config TCPTransportConfig, logger log.Logger, registerer pr // Build out the new transport. var ok bool + concurrentWrites := config.MaxConcurrentWrites + if concurrentWrites <= 0 { + concurrentWrites = 1 + } t := TCPTransport{ cfg: config, logger: log.With(logger, "component", "memberlist TCPTransport"), packetCh: make(chan *memberlist.Packet), connCh: make(chan net.Conn), - writeCh: make(chan struct{}, config.MaxConcurrentWrites), + writeCh: make(chan writeRequest), + } + + for i := 0; i < concurrentWrites; i++ { + go t.writeWorker() } var err error @@ -430,31 +446,34 @@ func (t *TCPTransport) getAdvertisedAddr() string { // WriteTo is a packet-oriented interface that fires off the given // payload to the given address. 
func (t *TCPTransport) WriteTo(b []byte, addr string) (time.Time, error) { - t.sentPackets.Inc() - t.sentPacketsBytes.Add(float64(len(b))) - t.writeCh <- struct{}{} - go func() { - defer func() { <-t.writeCh }() - t.writeToAsync(b, addr) - }() + if t.shutdown.Load() == 1 { + return time.Time{}, errors.New("transport is shutting down") + } + t.writeMu.RLock() + defer t.writeMu.RUnlock() + t.writeWG.Add(1) + t.writeCh <- writeRequest{b: b, addr: addr} return time.Now(), nil } -func (t *TCPTransport) writeToAsync(b []byte, addr string) { - err := t.writeTo(b, addr) - if err != nil { - t.sentPacketsErrors.Inc() +func (t *TCPTransport) writeWorker() { + for req := range t.writeCh { + b, addr := req.b, req.addr + t.sentPackets.Inc() + t.sentPacketsBytes.Add(float64(len(b))) + err := t.writeTo(b, addr) + if err != nil { + t.sentPacketsErrors.Inc() - logLevel := level.Warn(t.logger) - if strings.Contains(err.Error(), "connection refused") { - // The connection refused is a common error that could happen during normal operations when a node - // shutdown (or crash). It shouldn't be considered a warning condition on the sender side. - logLevel = t.debugLog() + logLevel := level.Warn(t.logger) + if strings.Contains(err.Error(), "connection refused") { + // The connection refused is a common error that could happen during normal operations when a node + // shutdown (or crash). It shouldn't be considered a warning condition on the sender side. + logLevel = t.debugLog() + } + logLevel.Log("msg", "WriteTo failed", "addr", addr, "err", err) } - logLevel.Log("msg", "WriteTo failed", "addr", addr, "err", err) - - // WriteTo is used to send "UDP" packets. Since we use TCP, we can detect more errors, - // but memberlist library doesn't seem to cope with that very well. That is why we return nil instead. 
+ t.writeWG.Done() } } @@ -570,9 +589,12 @@ func (t *TCPTransport) StreamCh() <-chan net.Conn { // Shutdown is called when memberlist is shutting down; this gives the // transport a chance to clean up any listeners. +// This will avoid log spam about errors when we shut down. func (t *TCPTransport) Shutdown() error { // This will avoid log spam about errors when we shut down. - t.shutdown.Store(1) + if old := t.shutdown.Swap(1); old == 1 { + return nil // already shut down + } // Rip through all the connections and shut them down. for _, conn := range t.tcpListeners { @@ -581,6 +603,12 @@ func (t *TCPTransport) Shutdown() error { // Block until all the listener threads have died. t.wg.Wait() + + // Wait until the write channel is empty and close it (to end the writeWorker goroutines). + t.writeMu.Lock() + defer t.writeMu.Unlock() + t.writeWG.Wait() + close(t.writeCh) return nil } diff --git a/kv/memberlist/tcp_transport_test.go b/kv/memberlist/tcp_transport_test.go index 310e11ecb..282bbc693 100644 --- a/kv/memberlist/tcp_transport_test.go +++ b/kv/memberlist/tcp_transport_test.go @@ -1,7 +1,10 @@ package memberlist import ( + "net" + "strings" "testing" + "time" "github.com/go-kit/log" "github.com/prometheus/client_golang/prometheus" @@ -9,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/grafana/dskit/concurrency" + "github.com/grafana/dskit/crypto/tls" "github.com/grafana/dskit/flagext" ) @@ -51,6 +55,8 @@ func TestTCPTransport_WriteTo_ShouldNotLogAsWarningExpectedFailures(t *testing.T _, err = transport.WriteTo([]byte("test"), testData.remoteAddr) require.NoError(t, err) + require.NoError(t, transport.Shutdown()) + if testData.expectedLogs != "" { assert.Contains(t, logs.String(), testData.expectedLogs) } @@ -61,6 +67,58 @@ func TestTCPTransport_WriteTo_ShouldNotLogAsWarningExpectedFailures(t *testing.T } } +type timeoutReader struct{} + +func (f *timeoutReader) ReadSecret(_ string) ([]byte, error) { + time.Sleep(1 * time.Second) + return 
nil, nil +} + +func TestTCPTransportWriteToUnreachableAddr(t *testing.T) { + writeCt := 50 + + // Listen for TCP connections on a random port + freePorts, err := getFreePorts(1) + require.NoError(t, err) + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: freePorts[0]} + listener, err := net.ListenTCP("tcp", addr) + require.NoError(t, err) + defer listener.Close() + + logs := &concurrency.SyncBuffer{} + logger := log.NewLogfmtLogger(logs) + + cfg := TCPTransportConfig{} + flagext.DefaultValues(&cfg) + cfg.MaxConcurrentWrites = writeCt + cfg.PacketDialTimeout = 500 * time.Millisecond + transport, err := NewTCPTransport(cfg, logger, nil) + require.NoError(t, err) + + // Configure TLS only for writes. The dialing should timeout (because of the timeoutReader) + transport.cfg.TLSEnabled = true + transport.cfg.TLS = tls.ClientConfig{ + Reader: &timeoutReader{}, + CertPath: "fake", + KeyPath: "fake", + CAPath: "fake", + } + + timeStart := time.Now() + + for i := 0; i < writeCt; i++ { + _, err = transport.WriteTo([]byte("test"), addr.String()) + require.NoError(t, err) + } + + require.NoError(t, transport.Shutdown()) + + gotErrorCt := strings.Count(logs.String(), "context deadline exceeded") + assert.Equal(t, writeCt, gotErrorCt, "expected %d errors, got %d", writeCt, gotErrorCt) + assert.GreaterOrEqual(t, time.Since(timeStart), 500*time.Millisecond, "expected to take at least 500ms (timeout duration)") + assert.LessOrEqual(t, time.Since(timeStart), 2*time.Second, "expected to take less than 2s (timeout + a good margin), writing to unreachable addresses should not block") +} + func TestFinalAdvertiseAddr(t *testing.T) { tests := map[string]struct { advertiseAddr string