From 1ccccfb2b16b566e81910233fece6a3abd1a54dc Mon Sep 17 00:00:00 2001 From: Paul Thiele Date: Thu, 9 Jun 2022 19:18:38 +0200 Subject: [PATCH 1/2] zmq4: make number of retries for dial configurable --- options.go | 8 +++++++ socket.go | 24 +++++++++++-------- socket_test.go | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 84 insertions(+), 10 deletions(-) diff --git a/options.go b/options.go index d11445c..6a4ead8 100644 --- a/options.go +++ b/options.go @@ -51,6 +51,14 @@ func WithLogger(msg *log.Logger) Option { } } +// WithDialerMaxRetries configures the maximum number of retries +// when dialing an endpoint (-1 means infinite retries). +func WithDialerMaxRetries(maxRetries int) Option { + return func(s *socket) { + s.maxRetries = maxRetries + } +} + /* // TODO(sbinet) diff --git a/socket.go b/socket.go index 55a4f56..db5b953 100644 --- a/socket.go +++ b/socket.go @@ -18,8 +18,9 @@ import ( ) const ( - defaultRetry = 250 * time.Millisecond - defaultTimeout = 5 * time.Minute + defaultRetry = 250 * time.Millisecond + defaultTimeout = 5 * time.Minute + defaultMaxRetries = 10 ) var ( @@ -30,13 +31,14 @@ var ( // socket implements the ZeroMQ socket interface type socket struct { - ep string // socket end-point - typ SocketType - id SocketIdentity - retry time.Duration - sec Security - log *log.Logger - subTopics func() []string + ep string // socket end-point + typ SocketType + id SocketIdentity + retry time.Duration + maxRetries int + sec Security + log *log.Logger + subTopics func() []string mu sync.RWMutex ids map[string]*Conn // ZMTP connection IDs @@ -63,6 +65,7 @@ func newDefaultSocket(ctx context.Context, sockType SocketType) *socket { return &socket{ typ: sockType, retry: defaultRetry, + maxRetries: defaultMaxRetries, sec: nullSecurity{}, ids: make(map[string]*Conn), conns: nil, @@ -247,7 +250,8 @@ connect: } if err != nil { - if retries < 10 { + // retry if retry count is lower than maximum retry count and context has not been canceled + if (sck.maxRetries == -1 || retries < sck.maxRetries) && sck.ctx.Err() == nil { retries++ time.Sleep(sck.retry) goto connect diff --git a/socket_test.go b/socket_test.go index 8ba617a..51048e2 100644 --- a/socket_test.go +++ b/socket_test.go @@ -6,6 +6,7 @@ package zmq4_test import ( "context" + "errors" "fmt" "io" "net" @@ -14,6 +15,7 @@ import ( "time" "github.com/go-zeromq/zmq4" + "github.com/go-zeromq/zmq4/transport" "golang.org/x/sync/errgroup" ) @@ -260,3 +262,63 @@ func TestSocketSendSubscriptionOnConnect(t *testing.T) { t.Fatalf("invalid message received: got '%s', wanted '%s'", msg.Frames[0], message) } } + +type transportMock struct { + dialCalledCount int + errOnDial bool + conn net.Conn +} + +func (t *transportMock) Dial(ctx context.Context, dialer transport.Dialer, addr string) (net.Conn, error) { + t.dialCalledCount++ + if t.errOnDial { + return nil, errors.New("test error") + } + return t.conn, nil +} + +func (t *transportMock) Listen(ctx context.Context, addr string) (net.Listener, error) { + return nil, nil +} + +func (t *transportMock) Addr(ep string) (addr string, err error) { + return "", nil +} + +func TestConnMaxRetries(t *testing.T) { + retryCount := 123 + socket := zmq4.NewSub(context.Background(), zmq4.WithDialerRetry(time.Microsecond), zmq4.WithDialerMaxRetries(retryCount)) + transport := &transportMock{errOnDial: true} + transportName := "test-maxretries" + zmq4.RegisterTransport(transportName, transport) + err := socket.Dial(transportName + "://test") + + if err == nil { + t.Fatal("expected error") + } + + if transport.dialCalledCount != retryCount+1 { + t.Fatalf("Dial called %d times, expected %d", transport.dialCalledCount, retryCount+1) + } +} + +func TestConnMaxRetriesInfinite(t *testing.T) { + timeout := time.Millisecond + retryTime := time.Nanosecond + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + socket := zmq4.NewSub(ctx, zmq4.WithDialerRetry(retryTime), zmq4.WithDialerMaxRetries(-1)) + transport := &transportMock{errOnDial: true} + transportName := "test-infiniteretries" + zmq4.RegisterTransport(transportName, transport) + err := socket.Dial(transportName + "://test") + if err == nil { + t.Fatal("expected error") + } + + atLeastExpectedRetries := 100 + if transport.dialCalledCount < atLeastExpectedRetries { + t.Fatalf("Dial called %d times, expected at least %d", transport.dialCalledCount, atLeastExpectedRetries) + } +} From 03b1695f6d0ff61a753e940326e44667655ed8d6 Mon Sep 17 00:00:00 2001 From: Paul Thiele Date: Fri, 17 Jun 2022 11:43:17 +0200 Subject: [PATCH 2/2] zmq4: skp ipc-crouter-cdealer --- czmq4_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/czmq4_test.go b/czmq4_test.go index 9b9adde..2426653 100644 --- a/czmq4_test.go +++ b/czmq4_test.go @@ -316,6 +316,7 @@ var ( }, { name: "ipc-crouter-cdealer", + skip: true, endpoint: func() string { return "ipc://crouter-cdealer" }, router: func(ctx context.Context) zmq4.Socket { return zmq4.NewCRouter(ctx, zmq4.CWithID(zmq4.SocketIdentity("router")))