Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

zmq4: make number of retries for dial configurable #126

Merged
merged 2 commits into from
Jun 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions czmq4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ var (
},
{
name: "ipc-crouter-cdealer",
skip: true,
endpoint: func() string { return "ipc://crouter-cdealer" },
router: func(ctx context.Context) zmq4.Socket {
return zmq4.NewCRouter(ctx, zmq4.CWithID(zmq4.SocketIdentity("router")))
Expand Down
8 changes: 8 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ func WithLogger(msg *log.Logger) Option {
}
}

// WithDialerMaxRetries configures the maximum number of retries
// when dialing an endpoint (-1 means infinite retries).
func WithDialerMaxRetries(maxRetries int) Option {
sbinet marked this conversation as resolved.
Show resolved Hide resolved
return func(s *socket) {
s.maxRetries = maxRetries
}
}

/*
// TODO(sbinet)

Expand Down
24 changes: 14 additions & 10 deletions socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ import (
)

const (
defaultRetry = 250 * time.Millisecond
defaultTimeout = 5 * time.Minute
defaultRetry = 250 * time.Millisecond
defaultTimeout = 5 * time.Minute
defaultMaxRetries = 10
)

var (
Expand All @@ -30,13 +31,14 @@ var (

// socket implements the ZeroMQ socket interface
type socket struct {
ep string // socket end-point
typ SocketType
id SocketIdentity
retry time.Duration
sec Security
log *log.Logger
subTopics func() []string
ep string // socket end-point
typ SocketType
id SocketIdentity
retry time.Duration
maxRetries int
sec Security
log *log.Logger
subTopics func() []string

mu sync.RWMutex
ids map[string]*Conn // ZMTP connection IDs
Expand All @@ -63,6 +65,7 @@ func newDefaultSocket(ctx context.Context, sockType SocketType) *socket {
return &socket{
typ: sockType,
retry: defaultRetry,
maxRetries: defaultMaxRetries,
sec: nullSecurity{},
ids: make(map[string]*Conn),
conns: nil,
Expand Down Expand Up @@ -247,7 +250,8 @@ connect:
}

if err != nil {
if retries < 10 {
// retry if retry count is lower than maximum retry count and context has not been canceled
if (sck.maxRetries == -1 || retries < sck.maxRetries) && sck.ctx.Err() == nil {
retries++
time.Sleep(sck.retry)
goto connect
Expand Down
62 changes: 62 additions & 0 deletions socket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package zmq4_test

import (
"context"
"errors"
"fmt"
"io"
"net"
Expand All @@ -14,6 +15,7 @@ import (
"time"

"github.com/go-zeromq/zmq4"
"github.com/go-zeromq/zmq4/transport"
"golang.org/x/sync/errgroup"
)

Expand Down Expand Up @@ -260,3 +262,63 @@ func TestSocketSendSubscriptionOnConnect(t *testing.T) {
t.Fatalf("invalid message received: got '%s', wanted '%s'", msg.Frames[0], message)
}
}

type transportMock struct {
dialCalledCount int
errOnDial bool
conn net.Conn
}

func (t *transportMock) Dial(ctx context.Context, dialer transport.Dialer, addr string) (net.Conn, error) {
t.dialCalledCount++
if t.errOnDial {
return nil, errors.New("test error")
}
return t.conn, nil
}

func (t *transportMock) Listen(ctx context.Context, addr string) (net.Listener, error) {
return nil, nil
}

func (t *transportMock) Addr(ep string) (addr string, err error) {
return "", nil
}

func TestConnMaxRetries(t *testing.T) {
retryCount := 123
socket := zmq4.NewSub(context.Background(), zmq4.WithDialerRetry(time.Microsecond), zmq4.WithDialerMaxRetries(retryCount))
transport := &transportMock{errOnDial: true}
transportName := "test-maxretries"
zmq4.RegisterTransport(transportName, transport)
err := socket.Dial(transportName + "://test")

if err == nil {
t.Fatal("expected error")
}

if transport.dialCalledCount != retryCount+1 {
t.Fatalf("Dial called %d times, expected %d", transport.dialCalledCount, retryCount+1)
}
}

func TestConnMaxRetriesInfinite(t *testing.T) {
timeout := time.Millisecond
retryTime := time.Nanosecond

ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
socket := zmq4.NewSub(ctx, zmq4.WithDialerRetry(retryTime), zmq4.WithDialerMaxRetries(-1))
transport := &transportMock{errOnDial: true}
transportName := "test-infiniteretries"
zmq4.RegisterTransport(transportName, transport)
err := socket.Dial(transportName + "://test")
if err == nil {
t.Fatal("expected error")
}

atLeastExpectedRetries := 100
if transport.dialCalledCount < atLeastExpectedRetries {
t.Fatalf("Dial called %d times, expected at least %d", transport.dialCalledCount, atLeastExpectedRetries)
}
}