From a01997deecbd29aff958826263227e1e65aaae5a Mon Sep 17 00:00:00 2001 From: Julien Duchesne Date: Tue, 8 Oct 2024 18:13:38 -0400 Subject: [PATCH] Address PR comments - Move variables around - Add timeout before dropping requests. This prevents blocking on the `WriteTo` function --- kv/memberlist/tcp_transport.go | 27 +++++++++++++++---- kv/memberlist/tcp_transport_test.go | 41 +++++++++++++++++++++++++---- 2 files changed, 58 insertions(+), 10 deletions(-) diff --git a/kv/memberlist/tcp_transport.go b/kv/memberlist/tcp_transport.go index bd13d9393..2010d3919 100644 --- a/kv/memberlist/tcp_transport.go +++ b/kv/memberlist/tcp_transport.go @@ -54,6 +54,9 @@ type TCPTransportConfig struct { // Maximum number of concurrent writes to other nodes. MaxConcurrentWrites int `yaml:"max_concurrent_writes" category:"advanced"` + // Timeout for acquiring one of the concurrent write slots. + AcquireWriterTimeout time.Duration `yaml:"acquire_writer_timeout" category:"advanced"` + // Transport logs lots of messages at debug level, so it deserves an extra flag for turning it on TransportDebug bool `yaml:"-" category:"advanced"` @@ -76,6 +79,7 @@ func (cfg *TCPTransportConfig) RegisterFlagsWithPrefix(f *flag.FlagSet, prefix s f.DurationVar(&cfg.PacketDialTimeout, prefix+"memberlist.packet-dial-timeout", 2*time.Second, "Timeout used when connecting to other nodes to send packet.") f.DurationVar(&cfg.PacketWriteTimeout, prefix+"memberlist.packet-write-timeout", 5*time.Second, "Timeout for writing 'packet' data.") f.IntVar(&cfg.MaxConcurrentWrites, prefix+"memberlist.max-concurrent-writes", 3, "Maximum number of concurrent writes to other nodes.") + f.DurationVar(&cfg.AcquireWriterTimeout, prefix+"memberlist.acquire-writer-timeout", 250*time.Millisecond, "Timeout for acquiring one of the concurrent write slots. After this time, the message will be dropped.") f.BoolVar(&cfg.TransportDebug, prefix+"memberlist.transport-debug", false, "Log debug transport messages. Note: global log.level must be at debug level as well.") f.BoolVar(&cfg.TLSEnabled, prefix+"memberlist.tls-enabled", false, "Enable TLS on the memberlist transport layer.") @@ -99,11 +103,11 @@ type TCPTransport struct { tcpListeners []net.Listener tlsConfig *tls.Config - writeCh chan writeRequest - writeWG sync.WaitGroup - - shutdown bool shutdownMu sync.RWMutex + shutdown bool + writeCh chan writeRequest // this channel is protected by shutdownMu + + writeWG sync.WaitGroup advertiseMu sync.RWMutex advertiseAddr string @@ -454,7 +458,20 @@ func (t *TCPTransport) WriteTo(b []byte, addr string) (time.Time, error) { if t.shutdown { return time.Time{}, errors.New("transport is shutting down") } - t.writeCh <- writeRequest{b: b, addr: addr} + + // Send the packet to the write workers + // If this blocks for too long (as configured), abort and log an error. + select { + case <-time.After(t.cfg.AcquireWriterTimeout): + level.Warn(t.logger).Log("msg", "WriteTo failed to acquire a writer. Dropping message", "timeout", t.cfg.AcquireWriterTimeout, "addr", addr) + t.sentPacketsErrors.Inc() + // WriteTo is used to send "UDP" packets. Since we use TCP, we can detect more errors, + // but memberlist library doesn't seem to cope with that very well. That is why we return nil instead. + return time.Now(), nil + case t.writeCh <- writeRequest{b: b, addr: addr}: + // OK + } + return time.Now(), nil } diff --git a/kv/memberlist/tcp_transport_test.go b/kv/memberlist/tcp_transport_test.go index 282bbc693..739fcc322 100644 --- a/kv/memberlist/tcp_transport_test.go +++ b/kv/memberlist/tcp_transport_test.go @@ -3,6 +3,7 @@ package memberlist import ( "net" "strings" + "sync" "testing" "time" @@ -78,10 +79,7 @@ func TestTCPTransportWriteToUnreachableAddr(t *testing.T) { writeCt := 50 // Listen for TCP connections on a random port - freePorts, err := getFreePorts(1) - require.NoError(t, err) - addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: freePorts[0]} - listener, err := net.ListenTCP("tcp", addr) + listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) defer listener.Close() @@ -107,7 +105,7 @@ func TestTCPTransportWriteToUnreachableAddr(t *testing.T) { timeStart := time.Now() for i := 0; i < writeCt; i++ { - _, err = transport.WriteTo([]byte("test"), addr.String()) + _, err = transport.WriteTo([]byte("test"), listener.Addr().String()) require.NoError(t, err) } @@ -119,6 +117,39 @@ func TestTCPTransportWriteToUnreachableAddr(t *testing.T) { assert.LessOrEqual(t, time.Since(timeStart), 2*time.Second, "expected to take less than 2s (timeout + a good margin), writing to unreachable addresses should not block") } +func TestTCPTransportWriterAcquireTimeout(t *testing.T) { + // Listen for TCP connections on a random port + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + logs := &concurrency.SyncBuffer{} + logger := log.NewLogfmtLogger(logs) + + cfg := TCPTransportConfig{} + flagext.DefaultValues(&cfg) + cfg.MaxConcurrentWrites = 1 + cfg.AcquireWriterTimeout = 1 * time.Millisecond // very short timeout + transport, err := NewTCPTransport(cfg, logger, nil) + require.NoError(t, err) + + writeCt := 100 + var reqWg sync.WaitGroup + for i := 0; i < writeCt; i++ { + reqWg.Add(1) + go func() { + defer reqWg.Done() + transport.WriteTo([]byte("test"), listener.Addr().String()) + }() + } + reqWg.Wait() + + require.NoError(t, transport.Shutdown()) + gotErrorCt := strings.Count(logs.String(), "WriteTo failed to acquire a writer. Dropping message") + assert.Less(t, gotErrorCt, writeCt, "expected to have less errors (%d) than total writes (%d). Some writes should pass.", gotErrorCt, writeCt) + assert.NotZero(t, gotErrorCt, "expected errors, got none") +} + func TestFinalAdvertiseAddr(t *testing.T) { tests := map[string]struct { advertiseAddr string