Skip to content

Commit

Permalink
zmq4: make number of retries for dial configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
thielepaul committed Jun 17, 2022
1 parent 04c84de commit c33556f
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 10 deletions.
8 changes: 8 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ func WithLogger(msg *log.Logger) Option {
}
}

// WithDialerMaxRetries configures the maximum number of retries
// when dialing an endpoint (-1 means infinite retries).
func WithDialerMaxRetries(maxRetries int) Option {
return func(s *socket) {
s.maxRetries = maxRetries
}
}

/*
// TODO(sbinet)
Expand Down
24 changes: 14 additions & 10 deletions socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ import (
)

const (
defaultRetry = 250 * time.Millisecond
defaultTimeout = 5 * time.Minute
defaultRetry = 250 * time.Millisecond
defaultTimeout = 5 * time.Minute
defaultMaxRetries = 10
)

var (
Expand All @@ -30,13 +31,14 @@ var (

// socket implements the ZeroMQ socket interface
type socket struct {
ep string // socket end-point
typ SocketType
id SocketIdentity
retry time.Duration
sec Security
log *log.Logger
subTopics func() []string
ep string // socket end-point
typ SocketType
id SocketIdentity
retry time.Duration
maxRetries int
sec Security
log *log.Logger
subTopics func() []string

mu sync.RWMutex
ids map[string]*Conn // ZMTP connection IDs
Expand All @@ -63,6 +65,7 @@ func newDefaultSocket(ctx context.Context, sockType SocketType) *socket {
return &socket{
typ: sockType,
retry: defaultRetry,
maxRetries: defaultMaxRetries,
sec: nullSecurity{},
ids: make(map[string]*Conn),
conns: nil,
Expand Down Expand Up @@ -247,7 +250,8 @@ connect:
}

if err != nil {
if retries < 10 {
// retry if retry count is lower than maximum retry count and context has not been canceled
if (sck.maxRetries == -1 || retries < sck.maxRetries) && sck.ctx.Err() == nil {
retries++
time.Sleep(sck.retry)
goto connect
Expand Down
71 changes: 71 additions & 0 deletions socket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package zmq4_test

import (
"context"
"errors"
"fmt"
"io"
"net"
Expand All @@ -14,6 +15,7 @@ import (
"time"

"github.com/go-zeromq/zmq4"
"github.com/go-zeromq/zmq4/transport"
"golang.org/x/sync/errgroup"
)

Expand Down Expand Up @@ -260,3 +262,72 @@ func TestSocketSendSubscriptionOnConnect(t *testing.T) {
t.Fatalf("invalid message received: got '%s', wanted '%s'", msg.Frames[0], message)
}
}

type transportMock struct {
dialCalledCount int
errOnDial bool
conn net.Conn
}

func (t *transportMock) Dial(ctx context.Context, dialer transport.Dialer, addr string) (net.Conn, error) {
t.dialCalledCount++
if t.errOnDial {
return nil, errors.New("test error")
}
return t.conn, nil
}

func (t *transportMock) Listen(ctx context.Context, addr string) (net.Listener, error) {
return nil, nil
}

func (t *transportMock) Addr(ep string) (addr string, err error) {
return "", nil
}

func (t *transportMock) init() {
t.dialCalledCount = 0
t.errOnDial = false
t.conn = nil
zmq4.RegisterTransport("test", testTransport)
}

var testTransport = &transportMock{}

func TestConnMaxRetries(t *testing.T) {
retryCount := 123
socket := zmq4.NewSub(context.Background(), zmq4.WithDialerRetry(time.Microsecond), zmq4.WithDialerMaxRetries(retryCount))
testTransport.init()
testTransport.errOnDial = true
err := socket.Dial("test://test")

if err == nil {
t.Fatal("expected error")
}

if testTransport.dialCalledCount != retryCount+1 {
t.Fatalf("Dial called %d times, expected %d", testTransport.dialCalledCount, retryCount+1)
}
}

func TestConnMaxRetriesInfinite(t *testing.T) {
timeout := time.Millisecond
retryTime := time.Nanosecond

ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
socket := zmq4.NewSub(ctx, zmq4.WithDialerRetry(retryTime), zmq4.WithDialerMaxRetries(-1))
testTransport.init()
testTransport.errOnDial = true

err := socket.Dial("test://test")

if err == nil {
t.Fatal("expected error")
}

atLeastExpectedRetries := 100
if testTransport.dialCalledCount < atLeastExpectedRetries {
t.Fatalf("Dial called %d times, expected at least %d", testTransport.dialCalledCount, atLeastExpectedRetries)
}
}

0 comments on commit c33556f

Please sign in to comment.