Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add concurrency to the memberlist transport's WriteTo method #525

Merged
merged 5 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@
* [ENHANCEMENT] Adapt `metrics.SendSumOfGaugesPerTenant` to use `metrics.MetricOption`. #584
* [ENHANCEMENT] Cache: Add `.Add()` and `.Set()` methods to cache clients. #591
* [ENHANCEMENT] Cache: Add `.Advance()` methods to mock cache clients for easier testing of TTLs. #601
* [ENHANCEMENT] Memberlist: Add concurrency to the transport's WriteTo method. #525
* [CHANGE] Backoff: added `Backoff.ErrCause()` which is like `Backoff.Err()` but returns the context cause if backoff is terminated because the context has been canceled. #538
* [BUGFIX] spanlogger: Support multiple tenant IDs. #59
* [BUGFIX] Memberlist: fixed corrupted packets when sending compound messages with more than 255 messages or messages bigger than 64KB. #85
Expand Down
99 changes: 81 additions & 18 deletions kv/memberlist/tcp_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"go.uber.org/atomic"

dstls "github.com/grafana/dskit/crypto/tls"
"github.com/grafana/dskit/flagext"
Expand Down Expand Up @@ -52,7 +51,13 @@ type TCPTransportConfig struct {
// Timeout for writing packet data. Zero = no timeout.
PacketWriteTimeout time.Duration `yaml:"packet_write_timeout" category:"advanced"`

// Transport logs lot of messages at debug level, so it deserves an extra flag for turning it on
// Maximum number of concurrent writes to other nodes.
MaxConcurrentWrites int `yaml:"max_concurrent_writes" category:"advanced"`

// Timeout for acquiring one of the concurrent write slots.
AcquireWriterTimeout time.Duration `yaml:"acquire_writer_timeout" category:"advanced"`

// Transport logs lots of messages at debug level, so it deserves an extra flag for turning it on
TransportDebug bool `yaml:"-" category:"advanced"`

// Where to put custom metrics. nil = don't register.
Expand All @@ -73,12 +78,19 @@ func (cfg *TCPTransportConfig) RegisterFlagsWithPrefix(f *flag.FlagSet, prefix s
f.IntVar(&cfg.BindPort, prefix+"memberlist.bind-port", 7946, "Port to listen on for gossip messages.")
f.DurationVar(&cfg.PacketDialTimeout, prefix+"memberlist.packet-dial-timeout", 2*time.Second, "Timeout used when connecting to other nodes to send packet.")
f.DurationVar(&cfg.PacketWriteTimeout, prefix+"memberlist.packet-write-timeout", 5*time.Second, "Timeout for writing 'packet' data.")
f.IntVar(&cfg.MaxConcurrentWrites, prefix+"memberlist.max-concurrent-writes", 3, "Maximum number of concurrent writes to other nodes.")
f.DurationVar(&cfg.AcquireWriterTimeout, prefix+"memberlist.acquire-writer-timeout", 250*time.Millisecond, "Timeout for acquiring one of the concurrent write slots. After this time, the message will be dropped.")
f.BoolVar(&cfg.TransportDebug, prefix+"memberlist.transport-debug", false, "Log debug transport messages. Note: global log.level must be at debug level as well.")

f.BoolVar(&cfg.TLSEnabled, prefix+"memberlist.tls-enabled", false, "Enable TLS on the memberlist transport layer.")
cfg.TLS.RegisterFlagsWithPrefix(prefix+"memberlist", f)
}

// writeRequest is one packet queued on the transport's write channel: WriteTo
// enqueues it and a write worker passes its fields to writeTo.
type writeRequest struct {
	// b is the packet payload to send.
	b []byte
	// addr is the address the payload is sent to.
	addr string
}

// TCPTransport is a memberlist.Transport implementation that uses TCP for both packet and stream
// operations ("packet" and "stream" are terms used by memberlist).
// It uses a new TCP connection for each operation. There is no connection reuse.
Expand All @@ -91,7 +103,11 @@ type TCPTransport struct {
tcpListeners []net.Listener
tlsConfig *tls.Config

shutdown atomic.Int32
shutdownMu sync.RWMutex
julienduchesne marked this conversation as resolved.
Show resolved Hide resolved
shutdown bool
writeCh chan writeRequest // this channel is protected by shutdownMu
julienduchesne marked this conversation as resolved.
Show resolved Hide resolved

writeWG sync.WaitGroup

advertiseMu sync.RWMutex
advertiseAddr string
Expand Down Expand Up @@ -119,11 +135,21 @@ func NewTCPTransport(config TCPTransportConfig, logger log.Logger, registerer pr

// Build out the new transport.
var ok bool
concurrentWrites := config.MaxConcurrentWrites
if concurrentWrites <= 0 {
concurrentWrites = 1
}
t := TCPTransport{
cfg: config,
logger: log.With(logger, "component", "memberlist TCPTransport"),
packetCh: make(chan *memberlist.Packet),
connCh: make(chan net.Conn),
writeCh: make(chan writeRequest),
}

for i := 0; i < concurrentWrites; i++ {
t.writeWG.Add(1)
go t.writeWorker()
}

var err error
Expand Down Expand Up @@ -205,7 +231,10 @@ func (t *TCPTransport) tcpListen(tcpLn net.Listener) {
for {
conn, err := tcpLn.Accept()
if err != nil {
if s := t.shutdown.Load(); s == 1 {
t.shutdownMu.RLock()
isShuttingDown := t.shutdown
t.shutdownMu.RUnlock()
if isShuttingDown {
break
}

Expand Down Expand Up @@ -424,29 +453,49 @@ func (t *TCPTransport) getAdvertisedAddr() string {
// WriteTo is a packet-oriented interface that fires off the given
// payload to the given address.
func (t *TCPTransport) WriteTo(b []byte, addr string) (time.Time, error) {
t.sentPackets.Inc()
t.sentPacketsBytes.Add(float64(len(b)))
t.shutdownMu.RLock()
defer t.shutdownMu.RUnlock() // Unlock at the end to protect the chan
if t.shutdown {
return time.Time{}, errors.New("transport is shutting down")
}

err := t.writeTo(b, addr)
if err != nil {
// Send the packet to the write workers
// If this blocks for too long (as configured), abort and log an error.
select {
case <-time.After(t.cfg.AcquireWriterTimeout):
level.Warn(t.logger).Log("msg", "WriteTo failed to acquire a writer. Dropping message", "timeout", t.cfg.AcquireWriterTimeout, "addr", addr)
t.sentPacketsErrors.Inc()

logLevel := level.Warn(t.logger)
if strings.Contains(err.Error(), "connection refused") {
// The connection refused is a common error that could happen during normal operations when a node
// shutdown (or crash). It shouldn't be considered a warning condition on the sender side.
logLevel = t.debugLog()
}
logLevel.Log("msg", "WriteTo failed", "addr", addr, "err", err)

// WriteTo is used to send "UDP" packets. Since we use TCP, we can detect more errors,
// but memberlist library doesn't seem to cope with that very well. That is why we return nil instead.
return time.Now(), nil
case t.writeCh <- writeRequest{b: b, addr: addr}:
// OK
}

return time.Now(), nil
}

// writeWorker receives write requests from t.writeCh and sends each one with
// writeTo, updating the sent-packet metrics as it goes. It returns once the
// channel has been closed and drained, marking itself done on t.writeWG.
func (t *TCPTransport) writeWorker() {
	defer t.writeWG.Done()

	for {
		req, open := <-t.writeCh
		if !open {
			return
		}

		t.sentPackets.Inc()
		t.sentPacketsBytes.Add(float64(len(req.b)))

		err := t.writeTo(req.b, req.addr)
		if err == nil {
			continue
		}
		t.sentPacketsErrors.Inc()

		sink := level.Warn(t.logger)
		if refused := strings.Contains(err.Error(), "connection refused"); refused {
			// A refused connection is common during normal operations, when a node
			// shuts down (or crashes), so the sender does not treat it as a warning.
			sink = t.debugLog()
		}
		sink.Log("msg", "WriteTo failed", "addr", req.addr, "err", err)
	}
}

func (t *TCPTransport) writeTo(b []byte, addr string) error {
// Open connection, write packet header and data, data hash, close. Simple.
c, err := t.getConnection(addr, t.cfg.PacketDialTimeout)
Expand Down Expand Up @@ -559,17 +608,31 @@ func (t *TCPTransport) StreamCh() <-chan net.Conn {

// Shutdown is called when memberlist is shutting down; this gives the
// transport a chance to clean up any listeners.
// This will avoid log spam about errors when we shut down.
func (t *TCPTransport) Shutdown() error {
t.shutdownMu.Lock()
// This will avoid log spam about errors when we shut down.
t.shutdown.Store(1)
if t.shutdown {
t.shutdownMu.Unlock()
return nil // already shut down
}

// Set the shutdown flag and close the write channel.
t.shutdown = true
close(t.writeCh)
t.shutdownMu.Unlock()

// Rip through all the connections and shut them down.
for _, conn := range t.tcpListeners {
_ = conn.Close()
}

// Wait until all write workers have finished.
t.writeWG.Wait()

// Block until all the listener threads have died.
t.wg.Wait()

return nil
}

Expand Down
89 changes: 89 additions & 0 deletions kv/memberlist/tcp_transport_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
package memberlist

import (
"net"
"strings"
"sync"
"testing"
"time"

"github.com/go-kit/log"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/grafana/dskit/concurrency"
"github.com/grafana/dskit/crypto/tls"
"github.com/grafana/dskit/flagext"
)

Expand Down Expand Up @@ -51,6 +56,8 @@ func TestTCPTransport_WriteTo_ShouldNotLogAsWarningExpectedFailures(t *testing.T
_, err = transport.WriteTo([]byte("test"), testData.remoteAddr)
require.NoError(t, err)

require.NoError(t, transport.Shutdown())

if testData.expectedLogs != "" {
assert.Contains(t, logs.String(), testData.expectedLogs)
}
Expand All @@ -61,6 +68,88 @@ func TestTCPTransport_WriteTo_ShouldNotLogAsWarningExpectedFailures(t *testing.T
}
}

// timeoutReader is a test double whose ReadSecret is deliberately slow.
type timeoutReader struct{}

// ReadSecret ignores the requested path, sleeps for one second and then
// returns no data and no error.
func (*timeoutReader) ReadSecret(_ string) ([]byte, error) {
	time.Sleep(time.Second)
	return nil, nil
}

func TestTCPTransportWriteToUnreachableAddr(t *testing.T) {
writeCt := 50

// Listen for TCP connections on a random port
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()

logs := &concurrency.SyncBuffer{}
logger := log.NewLogfmtLogger(logs)

cfg := TCPTransportConfig{}
flagext.DefaultValues(&cfg)
cfg.MaxConcurrentWrites = writeCt
cfg.PacketDialTimeout = 500 * time.Millisecond
transport, err := NewTCPTransport(cfg, logger, nil)
require.NoError(t, err)

// Configure TLS only for writes. The dialing should timeout (because of the timeoutReader)
transport.cfg.TLSEnabled = true
transport.cfg.TLS = tls.ClientConfig{
Reader: &timeoutReader{},
CertPath: "fake",
KeyPath: "fake",
CAPath: "fake",
}

timeStart := time.Now()

for i := 0; i < writeCt; i++ {
_, err = transport.WriteTo([]byte("test"), listener.Addr().String())
require.NoError(t, err)
}

require.NoError(t, transport.Shutdown())

gotErrorCt := strings.Count(logs.String(), "context deadline exceeded")
assert.Equal(t, writeCt, gotErrorCt, "expected %d errors, got %d", writeCt, gotErrorCt)
assert.GreaterOrEqual(t, time.Since(timeStart), 500*time.Millisecond, "expected to take at least 500ms (timeout duration)")
assert.LessOrEqual(t, time.Since(timeStart), 2*time.Second, "expected to take less than 2s (timeout + a good margin), writing to unreachable addresses should not block")
}

func TestTCPTransportWriterAcquireTimeout(t *testing.T) {
// Listen for TCP connections on a random port
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()

logs := &concurrency.SyncBuffer{}
logger := log.NewLogfmtLogger(logs)

cfg := TCPTransportConfig{}
flagext.DefaultValues(&cfg)
cfg.MaxConcurrentWrites = 1
cfg.AcquireWriterTimeout = 1 * time.Millisecond // very short timeout
transport, err := NewTCPTransport(cfg, logger, nil)
require.NoError(t, err)

writeCt := 100
var reqWg sync.WaitGroup
for i := 0; i < writeCt; i++ {
reqWg.Add(1)
go func() {
defer reqWg.Done()
transport.WriteTo([]byte("test"), listener.Addr().String()) // nolint:errcheck
}()
}
reqWg.Wait()

require.NoError(t, transport.Shutdown())
gotErrorCt := strings.Count(logs.String(), "WriteTo failed to acquire a writer. Dropping message")
assert.Less(t, gotErrorCt, writeCt, "expected to have less errors (%d) than total writes (%d). Some writes should pass.", gotErrorCt, writeCt)
assert.NotZero(t, gotErrorCt, "expected errors, got none")
}

func TestFinalAdvertiseAddr(t *testing.T) {
tests := map[string]struct {
advertiseAddr string
Expand Down