Skip to content

Commit

Permalink
Try to make this PR ready to go
Browse files Browse the repository at this point in the history
- Create goroutines and keep them while the TCPTransport is alive. End them on the `Shutdown` function
- Add `TestTCPTransportWriteToUnreachableAddr` test to check that writing is not blocking anymore (without this PR, it takes `writeCt * timeout` to run and it fails)
  • Loading branch information
julienduchesne committed Oct 3, 2024
1 parent 92c9075 commit 1807e0c
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 24 deletions.
76 changes: 52 additions & 24 deletions kv/memberlist/tcp_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,18 @@ func (cfg *TCPTransportConfig) RegisterFlagsWithPrefix(f *flag.FlagSet, prefix s
f.IntVar(&cfg.BindPort, prefix+"memberlist.bind-port", 7946, "Port to listen on for gossip messages.")
f.DurationVar(&cfg.PacketDialTimeout, prefix+"memberlist.packet-dial-timeout", 2*time.Second, "Timeout used when connecting to other nodes to send packet.")
f.DurationVar(&cfg.PacketWriteTimeout, prefix+"memberlist.packet-write-timeout", 5*time.Second, "Timeout for writing 'packet' data.")
f.IntVar(&cfg.MaxConcurrentWrites, prefix+"memberlist.max-concurrent-writes", 1, "Maximum number of concurrent writes to other nodes.")
f.IntVar(&cfg.MaxConcurrentWrites, prefix+"memberlist.max-concurrent-writes", 3, "Maximum number of concurrent writes to other nodes.")
f.BoolVar(&cfg.TransportDebug, prefix+"memberlist.transport-debug", false, "Log debug transport messages. Note: global log.level must be at debug level as well.")

f.BoolVar(&cfg.TLSEnabled, prefix+"memberlist.tls-enabled", false, "Enable TLS on the memberlist transport layer.")
cfg.TLS.RegisterFlagsWithPrefix(prefix+"memberlist", f)
}

type writeRequest struct {
b []byte
addr string
}

// TCPTransport is a memberlist.Transport implementation that uses TCP for both packet and stream
// operations ("packet" and "stream" are terms used by memberlist).
// It uses a new TCP connections for each operation. There is no connection reuse.
Expand All @@ -92,10 +97,13 @@ type TCPTransport struct {
packetCh chan *memberlist.Packet
connCh chan net.Conn
wg sync.WaitGroup
writeCh chan struct{}
tcpListeners []net.Listener
tlsConfig *tls.Config

writeMu sync.RWMutex
writeCh chan writeRequest
writeWG sync.WaitGroup

shutdown atomic.Int32

advertiseMu sync.RWMutex
Expand Down Expand Up @@ -124,12 +132,20 @@ func NewTCPTransport(config TCPTransportConfig, logger log.Logger, registerer pr

// Build out the new transport.
var ok bool
concurrentWrites := config.MaxConcurrentWrites
if concurrentWrites <= 0 {
concurrentWrites = 1
}
t := TCPTransport{
cfg: config,
logger: log.With(logger, "component", "memberlist TCPTransport"),
packetCh: make(chan *memberlist.Packet),
connCh: make(chan net.Conn),
writeCh: make(chan struct{}, config.MaxConcurrentWrites),
writeCh: make(chan writeRequest),
}

for i := 0; i < concurrentWrites; i++ {
go t.writeWorker()
}

var err error
Expand Down Expand Up @@ -430,31 +446,34 @@ func (t *TCPTransport) getAdvertisedAddr() string {
// WriteTo is a packet-oriented interface that fires off the given
// payload to the given address.
func (t *TCPTransport) WriteTo(b []byte, addr string) (time.Time, error) {
t.sentPackets.Inc()
t.sentPacketsBytes.Add(float64(len(b)))
t.writeCh <- struct{}{}
go func() {
defer func() { <-t.writeCh }()
t.writeToAsync(b, addr)
}()
if t.shutdown.Load() == 1 {
return time.Time{}, errors.New("transport is shutting down")
}
t.writeMu.RLock()
defer t.writeMu.RUnlock()
t.writeWG.Add(1)
t.writeCh <- writeRequest{b: b, addr: addr}
return time.Now(), nil
}

func (t *TCPTransport) writeToAsync(b []byte, addr string) {
err := t.writeTo(b, addr)
if err != nil {
t.sentPacketsErrors.Inc()
func (t *TCPTransport) writeWorker() {
for req := range t.writeCh {
b, addr := req.b, req.addr
t.sentPackets.Inc()
t.sentPacketsBytes.Add(float64(len(b)))
err := t.writeTo(b, addr)
if err != nil {
t.sentPacketsErrors.Inc()

logLevel := level.Warn(t.logger)
if strings.Contains(err.Error(), "connection refused") {
// The connection refused is a common error that could happen during normal operations when a node
// shutdown (or crash). It shouldn't be considered a warning condition on the sender side.
logLevel = t.debugLog()
logLevel := level.Warn(t.logger)
if strings.Contains(err.Error(), "connection refused") {
// The connection refused is a common error that could happen during normal operations when a node
// shutdown (or crash). It shouldn't be considered a warning condition on the sender side.
logLevel = t.debugLog()
}
logLevel.Log("msg", "WriteTo failed", "addr", addr, "err", err)
}
logLevel.Log("msg", "WriteTo failed", "addr", addr, "err", err)

// WriteTo is used to send "UDP" packets. Since we use TCP, we can detect more errors,
// but memberlist library doesn't seem to cope with that very well. That is why we return nil instead.
t.writeWG.Done()
}
}

Expand Down Expand Up @@ -570,9 +589,12 @@ func (t *TCPTransport) StreamCh() <-chan net.Conn {

// Shutdown is called when memberlist is shutting down; this gives the
// transport a chance to clean up any listeners.
// This will avoid log spam about errors when we shut down.
func (t *TCPTransport) Shutdown() error {
// This will avoid log spam about errors when we shut down.
t.shutdown.Store(1)
if old := t.shutdown.Swap(1); old == 1 {
return nil // already shut down
}

// Rip through all the connections and shut them down.
for _, conn := range t.tcpListeners {
Expand All @@ -581,6 +603,12 @@ func (t *TCPTransport) Shutdown() error {

// Block until all the listener threads have died.
t.wg.Wait()

// Wait until the write channel is empty and close it (to end the writeWorker goroutines).
t.writeMu.Lock()
defer t.writeMu.Unlock()
t.writeWG.Wait()
close(t.writeCh)
return nil
}

Expand Down
58 changes: 58 additions & 0 deletions kv/memberlist/tcp_transport_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
package memberlist

import (
"net"
"strings"
"testing"
"time"

"github.com/go-kit/log"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/grafana/dskit/concurrency"
"github.com/grafana/dskit/crypto/tls"
"github.com/grafana/dskit/flagext"
)

Expand Down Expand Up @@ -51,6 +55,8 @@ func TestTCPTransport_WriteTo_ShouldNotLogAsWarningExpectedFailures(t *testing.T
_, err = transport.WriteTo([]byte("test"), testData.remoteAddr)
require.NoError(t, err)

require.NoError(t, transport.Shutdown())

if testData.expectedLogs != "" {
assert.Contains(t, logs.String(), testData.expectedLogs)
}
Expand All @@ -61,6 +67,58 @@ func TestTCPTransport_WriteTo_ShouldNotLogAsWarningExpectedFailures(t *testing.T
}
}

type timeoutReader struct{}

func (f *timeoutReader) ReadSecret(_ string) ([]byte, error) {
time.Sleep(1 * time.Second)
return nil, nil
}

func TestTCPTransportWriteToUnreachableAddr(t *testing.T) {
writeCt := 50

// Listen for TCP connections on a random port
freePorts, err := getFreePorts(1)
require.NoError(t, err)
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: freePorts[0]}
listener, err := net.ListenTCP("tcp", addr)
require.NoError(t, err)
defer listener.Close()

logs := &concurrency.SyncBuffer{}
logger := log.NewLogfmtLogger(logs)

cfg := TCPTransportConfig{}
flagext.DefaultValues(&cfg)
cfg.MaxConcurrentWrites = writeCt
cfg.PacketDialTimeout = 500 * time.Millisecond
transport, err := NewTCPTransport(cfg, logger, nil)
require.NoError(t, err)

// Configure TLS only for writes. The dialing should timeout (because of the timeoutReader)
transport.cfg.TLSEnabled = true
transport.cfg.TLS = tls.ClientConfig{
Reader: &timeoutReader{},
CertPath: "fake",
KeyPath: "fake",
CAPath: "fake",
}

timeStart := time.Now()

for i := 0; i < writeCt; i++ {
_, err = transport.WriteTo([]byte("test"), addr.String())
require.NoError(t, err)
}

require.NoError(t, transport.Shutdown())

gotErrorCt := strings.Count(logs.String(), "context deadline exceeded")
assert.Equal(t, writeCt, gotErrorCt, "expected %d errors, got %d", writeCt, gotErrorCt)
assert.GreaterOrEqual(t, time.Since(timeStart), 500*time.Millisecond, "expected to take at least 500ms (timeout duration)")
assert.LessOrEqual(t, time.Since(timeStart), 2*time.Second, "expected to take less than 2s (timeout + a good margin), writing to unreachable addresses should not block")
}

func TestFinalAdvertiseAddr(t *testing.T) {
tests := map[string]struct {
advertiseAddr string
Expand Down

0 comments on commit 1807e0c

Please sign in to comment.