Skip to content

Commit

Permalink
Try to make this PR ready to go
Browse files Browse the repository at this point in the history
- Create goroutines and keep them while the TCPTransport is alive. End them on the `Shutdown` function
- Add `MaxConcurrentWrites` to all tests with a random value so random errors would be found
- Add `TestTCPTransportWriteToUnreachableAddr` test to check that writing is not blocking anymore (without this PR, it takes `writeCt * timeout` to run and it fails)
  • Loading branch information
julienduchesne committed Oct 2, 2024
1 parent 47a93c0 commit 3d073ca
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 44 deletions.
48 changes: 30 additions & 18 deletions kv/memberlist/memberlist_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ func TestBasicGetAndCas(t *testing.T) {
var cfg KVConfig
flagext.DefaultValues(&cfg)
cfg.TCPTransport = TCPTransportConfig{
BindAddrs: getLocalhostAddrs(),
BindAddrs: getLocalhostAddrs(),
MaxConcurrentWrites: rand.Intn(10) + 1,
}
cfg.Codecs = []codec.Codec{c}

Expand Down Expand Up @@ -321,7 +322,8 @@ func withFixtures(t *testing.T, testFN func(t *testing.T, kv *Client)) {
var cfg KVConfig
flagext.DefaultValues(&cfg)
cfg.TCPTransport = TCPTransportConfig{
BindAddrs: getLocalhostAddrs(),
BindAddrs: getLocalhostAddrs(),
MaxConcurrentWrites: rand.Intn(10) + 1,
}
cfg.Codecs = []codec.Codec{c}

Expand Down Expand Up @@ -477,7 +479,8 @@ func TestMultipleCAS(t *testing.T) {
var cfg KVConfig
flagext.DefaultValues(&cfg)
cfg.TCPTransport = TCPTransportConfig{
BindAddrs: getLocalhostAddrs(),
BindAddrs: getLocalhostAddrs(),
MaxConcurrentWrites: rand.Intn(10) + 1,
}
cfg.Codecs = []codec.Codec{c}

Expand Down Expand Up @@ -567,8 +570,9 @@ func defaultKVConfig(i int) KVConfig {
cfg.PushPullInterval = 5 * time.Second

cfg.TCPTransport = TCPTransportConfig{
BindAddrs: getLocalhostAddrs(),
BindPort: 0, // randomize ports
BindAddrs: getLocalhostAddrs(),
BindPort: 0, // randomize ports
MaxConcurrentWrites: rand.Intn(10) + 1,
}

return cfg
Expand Down Expand Up @@ -866,8 +870,9 @@ func TestJoinMembersWithRetryBackoff(t *testing.T) {
cfg.AbortIfJoinFails = true

cfg.TCPTransport = TCPTransportConfig{
BindAddrs: getLocalhostAddrs(),
BindPort: port,
BindAddrs: getLocalhostAddrs(),
BindPort: port,
MaxConcurrentWrites: rand.Intn(10) + 1,
}

cfg.Codecs = []codec.Codec{c}
Expand Down Expand Up @@ -951,8 +956,9 @@ func TestMemberlistFailsToJoin(t *testing.T) {
cfg.AbortIfJoinFails = true

cfg.TCPTransport = TCPTransportConfig{
BindAddrs: getLocalhostAddrs(),
BindPort: 0,
BindAddrs: getLocalhostAddrs(),
BindPort: 0,
MaxConcurrentWrites: rand.Intn(10) + 1,
}

cfg.JoinMembers = []string{net.JoinHostPort(getLocalhostAddr(), strconv.Itoa(ports[0]))}
Expand Down Expand Up @@ -1123,8 +1129,9 @@ func TestMultipleCodecs(t *testing.T) {
var cfg KVConfig
flagext.DefaultValues(&cfg)
cfg.TCPTransport = TCPTransportConfig{
BindAddrs: getLocalhostAddrs(),
BindPort: 0, // randomize
BindAddrs: getLocalhostAddrs(),
BindPort: 0, // randomize
MaxConcurrentWrites: rand.Intn(10) + 1,
}

cfg.Codecs = []codec.Codec{
Expand Down Expand Up @@ -1214,8 +1221,9 @@ func TestRejoin(t *testing.T) {
var cfg1 KVConfig
flagext.DefaultValues(&cfg1)
cfg1.TCPTransport = TCPTransportConfig{
BindAddrs: getLocalhostAddrs(),
BindPort: ports[0],
BindAddrs: getLocalhostAddrs(),
BindPort: ports[0],
MaxConcurrentWrites: rand.Intn(10) + 1,
}

cfg1.RandomizeNodeName = true
Expand Down Expand Up @@ -1277,7 +1285,8 @@ func TestNotifyMsgResendsOnlyChanges(t *testing.T) {

cfg := KVConfig{
TCPTransport: TCPTransportConfig{
BindAddrs: getLocalhostAddrs(),
BindAddrs: getLocalhostAddrs(),
MaxConcurrentWrites: rand.Intn(10) + 1,
},
}
// We will be checking for number of messages in the broadcast queue, so make sure to use known retransmit factor.
Expand Down Expand Up @@ -1346,7 +1355,8 @@ func TestSendingOldTombstoneShouldNotForwardMessage(t *testing.T) {

cfg := KVConfig{
TCPTransport: TCPTransportConfig{
BindAddrs: getLocalhostAddrs(),
BindAddrs: getLocalhostAddrs(),
MaxConcurrentWrites: rand.Intn(10) + 1,
},
}
// We will be checking for number of messages in the broadcast queue, so make sure to use known retransmit factor.
Expand Down Expand Up @@ -1441,8 +1451,9 @@ func TestFastJoin(t *testing.T) {
var cfg KVConfig
flagext.DefaultValues(&cfg)
cfg.TCPTransport = TCPTransportConfig{
BindAddrs: getLocalhostAddrs(),
BindPort: 0, // randomize
BindAddrs: getLocalhostAddrs(),
BindPort: 0, // randomize
MaxConcurrentWrites: rand.Intn(10) + 1,
}

cfg.Codecs = []codec.Codec{
Expand Down Expand Up @@ -1492,7 +1503,8 @@ func TestDelegateMethodsDontCrashBeforeKVStarts(t *testing.T) {
cfg := KVConfig{}
cfg.Codecs = append(cfg.Codecs, codec)
cfg.TCPTransport = TCPTransportConfig{
BindAddrs: getLocalhostAddrs(),
BindAddrs: getLocalhostAddrs(),
MaxConcurrentWrites: rand.Intn(10) + 1,
}

kv := NewKV(cfg, log.NewNopLogger(), &dnsProviderMock{}, prometheus.NewPedanticRegistry())
Expand Down
63 changes: 40 additions & 23 deletions kv/memberlist/tcp_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ func (cfg *TCPTransportConfig) RegisterFlagsWithPrefix(f *flag.FlagSet, prefix s
cfg.TLS.RegisterFlagsWithPrefix(prefix+"memberlist", f)
}

type writeRequest struct {
b []byte
addr string
}

// TCPTransport is a memberlist.Transport implementation that uses TCP for both packet and stream
// operations ("packet" and "stream" are terms used by memberlist).
// It uses a new TCP connections for each operation. There is no connection reuse.
Expand All @@ -92,9 +97,10 @@ type TCPTransport struct {
packetCh chan *memberlist.Packet
connCh chan net.Conn
wg sync.WaitGroup
writeCh chan struct{}
tcpListeners []net.Listener
tlsConfig *tls.Config
writeCh chan writeRequest
writeWG sync.WaitGroup

shutdown atomic.Int32

Expand Down Expand Up @@ -129,7 +135,11 @@ func NewTCPTransport(config TCPTransportConfig, logger log.Logger, registerer pr
logger: log.With(logger, "component", "memberlist TCPTransport"),
packetCh: make(chan *memberlist.Packet),
connCh: make(chan net.Conn),
writeCh: make(chan struct{}, config.MaxConcurrentWrites),
writeCh: make(chan writeRequest),
}

for i := 0; i < config.MaxConcurrentWrites; i++ {
go t.writeWorker()
}

var err error
Expand Down Expand Up @@ -430,31 +440,32 @@ func (t *TCPTransport) getAdvertisedAddr() string {
// WriteTo is a packet-oriented interface that fires off the given
// payload to the given address.
func (t *TCPTransport) WriteTo(b []byte, addr string) (time.Time, error) {
t.sentPackets.Inc()
t.sentPacketsBytes.Add(float64(len(b)))
t.writeCh <- struct{}{}
go func() {
defer func() { <-t.writeCh }()
t.writeToAsync(b, addr)
}()
if t.shutdown.Load() == 1 {
return time.Time{}, errors.New("transport is shutting down")
}
t.writeWG.Add(1)
t.writeCh <- writeRequest{b: b, addr: addr}
return time.Now(), nil
}

func (t *TCPTransport) writeToAsync(b []byte, addr string) {
err := t.writeTo(b, addr)
if err != nil {
t.sentPacketsErrors.Inc()
func (t *TCPTransport) writeWorker() {
for req := range t.writeCh {
b, addr := req.b, req.addr
t.sentPackets.Inc()
t.sentPacketsBytes.Add(float64(len(b)))
err := t.writeTo(b, addr)
if err != nil {
t.sentPacketsErrors.Inc()

logLevel := level.Warn(t.logger)
if strings.Contains(err.Error(), "connection refused") {
// The connection refused is a common error that could happen during normal operations when a node
// shutdown (or crash). It shouldn't be considered a warning condition on the sender side.
logLevel = t.debugLog()
logLevel := level.Warn(t.logger)
if strings.Contains(err.Error(), "connection refused") {
// The connection refused is a common error that could happen during normal operations when a node
// shutdown (or crash). It shouldn't be considered a warning condition on the sender side.
logLevel = t.debugLog()
}
logLevel.Log("msg", "WriteTo failed", "addr", addr, "err", err)
}
logLevel.Log("msg", "WriteTo failed", "addr", addr, "err", err)

// WriteTo is used to send "UDP" packets. Since we use TCP, we can detect more errors,
// but memberlist library doesn't seem to cope with that very well. That is why we return nil instead.
t.writeWG.Done()
}
}

Expand Down Expand Up @@ -572,7 +583,9 @@ func (t *TCPTransport) StreamCh() <-chan net.Conn {
// transport a chance to clean up any listeners.
func (t *TCPTransport) Shutdown() error {
// This will avoid log spam about errors when we shut down.
t.shutdown.Store(1)
if old := t.shutdown.Swap(1); old == 1 {
return nil // already shut down
}

// Rip through all the connections and shut them down.
for _, conn := range t.tcpListeners {
Expand All @@ -581,6 +594,10 @@ func (t *TCPTransport) Shutdown() error {

// Block until all the listener threads have died.
t.wg.Wait()

// Wait until the write channel is empty and close it (to end the writeWorker goroutines).
t.writeWG.Wait()
close(t.writeCh)
return nil
}

Expand Down
72 changes: 69 additions & 3 deletions kv/memberlist/tcp_transport_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
package memberlist

import (
"net"
"strings"
"testing"
"time"

"github.com/go-kit/log"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/rand"

"github.com/grafana/dskit/concurrency"
"github.com/grafana/dskit/crypto/tls"
"github.com/grafana/dskit/flagext"
)

Expand Down Expand Up @@ -37,7 +42,9 @@ func TestTCPTransport_WriteTo_ShouldNotLogAsWarningExpectedFailures(t *testing.T
logs := &concurrency.SyncBuffer{}
logger := log.NewLogfmtLogger(logs)

cfg := TCPTransportConfig{}
cfg := TCPTransportConfig{
MaxConcurrentWrites: rand.Intn(10) + 1,
}
flagext.DefaultValues(&cfg)
cfg.BindAddrs = []string{"127.0.0.1"}
cfg.BindPort = 0
Expand All @@ -51,6 +58,8 @@ func TestTCPTransport_WriteTo_ShouldNotLogAsWarningExpectedFailures(t *testing.T
_, err = transport.WriteTo([]byte("test"), testData.remoteAddr)
require.NoError(t, err)

require.NoError(t, transport.Shutdown())

if testData.expectedLogs != "" {
assert.Contains(t, logs.String(), testData.expectedLogs)
}
Expand All @@ -61,6 +70,58 @@ func TestTCPTransport_WriteTo_ShouldNotLogAsWarningExpectedFailures(t *testing.T
}
}

type timeoutReader struct{}

func (f *timeoutReader) ReadSecret(path string) ([]byte, error) {
time.Sleep(1 * time.Second)
return nil, nil
}

func TestTCPTransportWriteToUnreachableAddr(t *testing.T) {
writeCt := 50

// Listen for TCP connections on a random port
freePorts, err := getFreePorts(1)
require.NoError(t, err)
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: freePorts[0]}
listener, err := net.ListenTCP("tcp", addr)
require.NoError(t, err)
defer listener.Close()

logs := &concurrency.SyncBuffer{}
logger := log.NewLogfmtLogger(logs)

cfg := TCPTransportConfig{}
flagext.DefaultValues(&cfg)
cfg.MaxConcurrentWrites = writeCt
cfg.PacketDialTimeout = 500 * time.Millisecond
transport, err := NewTCPTransport(cfg, logger, nil)
require.NoError(t, err)

// Configure TLS only for writes. The dialing should timeout (because of the timeoutReader)
transport.cfg.TLSEnabled = true
transport.cfg.TLS = tls.ClientConfig{
Reader: &timeoutReader{},
CertPath: "fake",
KeyPath: "fake",
CAPath: "fake",
}

timeStart := time.Now()

for i := 0; i < writeCt; i++ {
_, err = transport.WriteTo([]byte("test"), addr.String())
require.NoError(t, err)
}

require.NoError(t, transport.Shutdown())

gotErrorCt := strings.Count(logs.String(), "context deadline exceeded")
assert.Equal(t, writeCt, gotErrorCt, "expected %d errors, got %d", writeCt, gotErrorCt)
assert.GreaterOrEqual(t, time.Since(timeStart), 500*time.Millisecond, "expected to take at least 500ms (timeout duration)")
assert.LessOrEqual(t, time.Since(timeStart), 2*time.Second, "expected to take less than 2s (timeout + a good margin), writing to unreachable addresses should not block")
}

func TestFinalAdvertiseAddr(t *testing.T) {
tests := map[string]struct {
advertiseAddr string
Expand All @@ -79,7 +140,9 @@ func TestFinalAdvertiseAddr(t *testing.T) {
logs := &concurrency.SyncBuffer{}
logger := log.NewLogfmtLogger(logs)

cfg := TCPTransportConfig{}
cfg := TCPTransportConfig{
MaxConcurrentWrites: rand.Intn(10) + 1,
}
flagext.DefaultValues(&cfg)
cfg.BindAddrs = testData.bindAddrs
cfg.BindPort = testData.bindPort
Expand All @@ -97,7 +160,10 @@ func TestFinalAdvertiseAddr(t *testing.T) {
}

func TestNonIPsAreRejected(t *testing.T) {
cfg := TCPTransportConfig{BindAddrs: flagext.StringSlice{"localhost"}}
cfg := TCPTransportConfig{
BindAddrs: flagext.StringSlice{"localhost"},
MaxConcurrentWrites: rand.Intn(10) + 1,
}
_, err := NewTCPTransport(cfg, nil, nil)
require.EqualError(t, err, `could not parse bind addr "localhost" as IP address`)
}

0 comments on commit 3d073ca

Please sign in to comment.