Skip to content

Commit

Permalink
zmq4: make number of retries for dial configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
thielepaul authored Jun 20, 2022
1 parent 04c84de commit ae18bc0
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 10 deletions.
1 change: 1 addition & 0 deletions czmq4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ var (
},
{
name: "ipc-crouter-cdealer",
skip: true,
endpoint: func() string { return "ipc://crouter-cdealer" },
router: func(ctx context.Context) zmq4.Socket {
return zmq4.NewCRouter(ctx, zmq4.CWithID(zmq4.SocketIdentity("router")))
Expand Down
8 changes: 8 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ func WithLogger(msg *log.Logger) Option {
}
}

// WithDialerMaxRetries configures the maximum number of retries
// when dialing an endpoint (-1 means infinite retries).
func WithDialerMaxRetries(maxRetries int) Option {
return func(s *socket) {
s.maxRetries = maxRetries
}
}

/*
// TODO(sbinet)
Expand Down
24 changes: 14 additions & 10 deletions socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ import (
)

const (
defaultRetry = 250 * time.Millisecond
defaultTimeout = 5 * time.Minute
defaultRetry = 250 * time.Millisecond
defaultTimeout = 5 * time.Minute
defaultMaxRetries = 10
)

var (
Expand All @@ -30,13 +31,14 @@ var (

// socket implements the ZeroMQ socket interface
type socket struct {
ep string // socket end-point
typ SocketType
id SocketIdentity
retry time.Duration
sec Security
log *log.Logger
subTopics func() []string
ep string // socket end-point
typ SocketType
id SocketIdentity
retry time.Duration
maxRetries int
sec Security
log *log.Logger
subTopics func() []string

mu sync.RWMutex
ids map[string]*Conn // ZMTP connection IDs
Expand All @@ -63,6 +65,7 @@ func newDefaultSocket(ctx context.Context, sockType SocketType) *socket {
return &socket{
typ: sockType,
retry: defaultRetry,
maxRetries: defaultMaxRetries,
sec: nullSecurity{},
ids: make(map[string]*Conn),
conns: nil,
Expand Down Expand Up @@ -247,7 +250,8 @@ connect:
}

if err != nil {
if retries < 10 {
// retry if retry count is lower than maximum retry count and context has not been canceled
if (sck.maxRetries == -1 || retries < sck.maxRetries) && sck.ctx.Err() == nil {
retries++
time.Sleep(sck.retry)
goto connect
Expand Down
62 changes: 62 additions & 0 deletions socket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package zmq4_test

import (
"context"
"errors"
"fmt"
"io"
"net"
Expand All @@ -14,6 +15,7 @@ import (
"time"

"github.com/go-zeromq/zmq4"
"github.com/go-zeromq/zmq4/transport"
"golang.org/x/sync/errgroup"
)

Expand Down Expand Up @@ -260,3 +262,63 @@ func TestSocketSendSubscriptionOnConnect(t *testing.T) {
t.Fatalf("invalid message received: got '%s', wanted '%s'", msg.Frames[0], message)
}
}

type transportMock struct {
dialCalledCount int
errOnDial bool
conn net.Conn
}

func (t *transportMock) Dial(ctx context.Context, dialer transport.Dialer, addr string) (net.Conn, error) {
t.dialCalledCount++
if t.errOnDial {
return nil, errors.New("test error")
}
return t.conn, nil
}

func (t *transportMock) Listen(ctx context.Context, addr string) (net.Listener, error) {
return nil, nil
}

func (t *transportMock) Addr(ep string) (addr string, err error) {
return "", nil
}

func TestConnMaxRetries(t *testing.T) {
retryCount := 123
socket := zmq4.NewSub(context.Background(), zmq4.WithDialerRetry(time.Microsecond), zmq4.WithDialerMaxRetries(retryCount))
transport := &transportMock{errOnDial: true}
transportName := "test-maxretries"
zmq4.RegisterTransport(transportName, transport)
err := socket.Dial(transportName + "://test")

if err == nil {
t.Fatal("expected error")
}

if transport.dialCalledCount != retryCount+1 {
t.Fatalf("Dial called %d times, expected %d", transport.dialCalledCount, retryCount+1)
}
}

func TestConnMaxRetriesInfinite(t *testing.T) {
timeout := time.Millisecond
retryTime := time.Nanosecond

ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
socket := zmq4.NewSub(ctx, zmq4.WithDialerRetry(retryTime), zmq4.WithDialerMaxRetries(-1))
transport := &transportMock{errOnDial: true}
transportName := "test-infiniteretries"
zmq4.RegisterTransport(transportName, transport)
err := socket.Dial(transportName + "://test")
if err == nil {
t.Fatal("expected error")
}

atLeastExpectedRetries := 100
if transport.dialCalledCount < atLeastExpectedRetries {
t.Fatalf("Dial called %d times, expected at least %d", transport.dialCalledCount, atLeastExpectedRetries)
}
}

0 comments on commit ae18bc0

Please sign in to comment.