diff --git a/example_test.go b/example_test.go index 6aa93636c..3ad367320 100644 --- a/example_test.go +++ b/example_test.go @@ -89,6 +89,19 @@ func ExampleConn_Subscribe() { }) } +func ExampleConn_ForceReconnect() { + nc, _ := nats.Connect(nats.DefaultURL) + defer nc.Close() + + nc.Subscribe("foo", func(m *nats.Msg) { + fmt.Printf("Received a message: %s\n", string(m.Data)) + }) + + // Reconnect to the server. + // the subscription will be recreated after the reconnect. + nc.ForceReconnect() +} + // This Example shows a synchronous subscriber. func ExampleConn_SubscribeSync() { nc, _ := nats.Connect(nats.DefaultURL) diff --git a/nats.go b/nats.go index 8c0796a89..d94c9a9c7 100644 --- a/nats.go +++ b/nats.go @@ -2161,6 +2161,47 @@ func (nc *Conn) waitForExits() { nc.wg.Wait() } +// ForceReconnect forces a reconnect attempt to the server. +// This is a non-blocking call and will start the reconnect +// process without waiting for it to complete. +// +// If the connection is already in the process of reconnecting, +// this call will force an immediate reconnect attempt (bypassing +// the current reconnect delay). +func (nc *Conn) ForceReconnect() error { + nc.mu.Lock() + defer nc.mu.Unlock() + + if nc.isClosed() { + return ErrConnectionClosed + } + if nc.isReconnecting() { + // if we're already reconnecting, force a reconnect attempt + // even if we're in the middle of a backoff + if nc.rqch != nil { + close(nc.rqch) + } + return nil + } + + // Clear any queued pongs + nc.clearPendingFlushCalls() + + // Clear any queued and blocking requests. + nc.clearPendingRequestCalls() + + // Stop ping timer if set. + nc.stopPingTimer() + + // Go ahead and make sure we have flushed the outbound + nc.bw.flush() + nc.conn.Close() + + nc.changeConnStatus(RECONNECTING) + go nc.doReconnect(nil, true) + return nil +} + // ConnectedUrl reports the connected server's URL func (nc *Conn) ConnectedUrl() string { if nc == nil { @@ -2420,7 +2461,7 @@ func (nc *Conn) connect() (bool, error) { nc.setup() nc.changeConnStatus(RECONNECTING) nc.bw.switchToPending() - go nc.doReconnect(ErrNoServers) + go nc.doReconnect(ErrNoServers, false) err = nil } else { nc.current = nil @@ -2720,7 +2761,7 @@ func (nc *Conn) stopPingTimer() { // Try to reconnect using the option parameters. // This function assumes we are allowed to reconnect. -func (nc *Conn) doReconnect(err error) { +func (nc *Conn) doReconnect(err error, forceReconnect bool) { // We want to make sure we have the other watchers shutdown properly // here before we proceed past this point. nc.waitForExits() @@ -2776,7 +2817,8 @@ func (nc *Conn) doReconnect(err error) { break } - doSleep := i+1 >= len(nc.srvPool) + doSleep := i+1 >= len(nc.srvPool) && !forceReconnect + forceReconnect = false nc.mu.Unlock() if !doSleep { @@ -2803,6 +2845,12 @@ func (nc *Conn) doReconnect(err error) { select { case <-rqch: rt.Stop() + + // we need to reset the rqch channel to avoid + // closing a closed channel in the next iteration + nc.mu.Lock() + nc.rqch = make(chan struct{}) + nc.mu.Unlock() case <-rt.C: } } @@ -2872,9 +2920,6 @@ func (nc *Conn) doReconnect(err error) { // Done with the pending buffer nc.bw.doneWithPending() - // This is where we are truly connected. - nc.status = CONNECTED - // Queue up the correct callback. If we are in initial connect state // (using retry on failed connect), we will call the ConnectedCB, // otherwise the ReconnectedCB. @@ -2930,7 +2975,7 @@ func (nc *Conn) processOpErr(err error) { // Clear any queued pongs, e.g. pending flush calls. nc.clearPendingFlushCalls() - go nc.doReconnect(err) + go nc.doReconnect(err, false) nc.mu.Unlock() return } diff --git a/test/conn_test.go b/test/conn_test.go index 7e5fcab01..afc5025b3 100644 --- a/test/conn_test.go +++ b/test/conn_test.go @@ -2946,16 +2946,6 @@ func TestRetryOnFailedConnectWithTLSError(t *testing.T) { } func TestConnStatusChangedEvents(t *testing.T) { - waitForStatus := func(t *testing.T, ch chan nats.Status, expected nats.Status) { - select { - case s := <-ch: - if s != expected { - t.Fatalf("Expected status: %s; got: %s", expected, s) - } - case <-time.After(5 * time.Second): - t.Fatalf("Timeout waiting for status %q", expected) - } - } t.Run("default events", func(t *testing.T) { s := RunDefaultServer() nc, err := nats.Connect(s.ClientURL()) @@ -2978,15 +2968,15 @@ func TestConnStatusChangedEvents(t *testing.T) { time.Sleep(50 * time.Millisecond) s.Shutdown() - waitForStatus(t, newStatus, nats.RECONNECTING) + WaitOnChannel(t, newStatus, nats.RECONNECTING) s = RunDefaultServer() defer s.Shutdown() - waitForStatus(t, newStatus, nats.CONNECTED) + WaitOnChannel(t, newStatus, nats.CONNECTED) nc.Close() - waitForStatus(t, newStatus, nats.CLOSED) + WaitOnChannel(t, newStatus, nats.CLOSED) select { case s := <-newStatus: @@ -3019,7 +3009,7 @@ func TestConnStatusChangedEvents(t *testing.T) { s = RunDefaultServer() defer s.Shutdown() nc.Close() - waitForStatus(t, newStatus, nats.CLOSED) + WaitOnChannel(t, newStatus, nats.CLOSED) select { case s := <-newStatus: diff --git a/test/helper_test.go b/test/helper_test.go index 9c04a40f9..7f2aedf0c 100644 --- a/test/helper_test.go +++ b/test/helper_test.go @@ -54,6 +54,18 @@ func WaitTime(ch chan bool, timeout time.Duration) error { return errors.New("timeout") } +func WaitOnChannel[T comparable](t *testing.T, ch <-chan T, expected T) { + t.Helper() + select { + case s := <-ch: + if s != expected { + t.Fatalf("Expected result: %v; got: %v", expected, s) + } + case <-time.After(5 * time.Second): + t.Fatalf("Timeout waiting for result %v", expected) + } +} + func stackFatalf(t tLogger, f string, args ...any) { lines := make([]string, 0, 32) msg := fmt.Sprintf(f, args...) diff --git a/test/reconnect_test.go b/test/reconnect_test.go index 66cc9b6ca..e543db72e 100644 --- a/test/reconnect_test.go +++ b/test/reconnect_test.go @@ -853,7 +853,7 @@ func TestAuthExpiredReconnect(t *testing.T) { jwtCB := func() (string, error) { claims := jwt.NewUserClaims("test") - claims.Expires = time.Now().Add(500 * time.Millisecond).Unix() + claims.Expires = time.Now().Add(time.Second).Unix() claims.Subject = upub jwt, err := claims.Encode(akp) if err != nil { @@ -884,21 +884,218 @@ func TestAuthExpiredReconnect(t *testing.T) { case <-time.After(2 * time.Second): t.Fatal("Did not get the auth expired error") } - select { - case s := <-stasusCh: - if s != nats.RECONNECTING { - t.Fatalf("Expected to be in reconnecting state after jwt expires, got %v", s) + WaitOnChannel(t, stasusCh, nats.RECONNECTING) + WaitOnChannel(t, stasusCh, nats.CONNECTED) + nc.Close() +} + +func TestForceReconnect(t *testing.T) { + s := RunDefaultServer() + + nc, err := nats.Connect(s.ClientURL(), nats.ReconnectWait(10*time.Second)) + if err != nil { + t.Fatalf("Unexpected error on connect: %v", err) + } + + statusCh := nc.StatusChanged(nats.RECONNECTING, nats.CONNECTED) + defer close(statusCh) + newStatus := make(chan nats.Status, 10) + // non-blocking channel, so we need to be constantly listening + go func() { + for { + s, ok := <-statusCh + if !ok { + return + } + newStatus <- s } - case <-time.After(2 * time.Second): - t.Fatal("Did not get the status change") + }() + + sub, err := nc.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + if err := nc.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + _, err = sub.NextMsg(time.Second) + if err != nil { + t.Fatalf("Error getting message: %v", err) + } + + // Force a reconnect + err = nc.ForceReconnect() + if err != nil { + t.Fatalf("Unexpected error on reconnect: %v", err) + } + + WaitOnChannel(t, newStatus, nats.RECONNECTING) + WaitOnChannel(t, newStatus, nats.CONNECTED) + + if err := nc.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + _, err = sub.NextMsg(time.Second) + if err != nil { + t.Fatalf("Error getting message: %v", err) + } + + // shutdown server and then force a reconnect + s.Shutdown() + WaitOnChannel(t, newStatus, nats.RECONNECTING) + _, err = sub.NextMsg(100 * time.Millisecond) + if err == nil { + t.Fatal("Expected error getting message") + } + + // restart server + s = RunDefaultServer() + defer s.Shutdown() + + if err := nc.ForceReconnect(); err != nil { + t.Fatalf("Unexpected error on reconnect: %v", err) + } + // wait for the reconnect + // because the connection has long ReconnectWait, + // if force reconnect does not work, the test will timeout + WaitOnChannel(t, newStatus, nats.CONNECTED) + + if err := nc.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + _, err = sub.NextMsg(time.Second) + if err != nil { + t.Fatalf("Error getting message: %v", err) + } + nc.Close() +} + +func TestForceReconnectDisallowReconnect(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + nc, err := nats.Connect(s.ClientURL(), nats.NoReconnect()) + if err != nil { + t.Fatalf("Unexpected error on connect: %v", err) + } + defer nc.Close() + + statusCh := nc.StatusChanged(nats.RECONNECTING, nats.CONNECTED) + defer close(statusCh) + newStatus := make(chan nats.Status, 10) + // non-blocking channel, so we need to be constantly listening + go func() { + for { + s, ok := <-statusCh + if !ok { + return + } + newStatus <- s + } + }() + + sub, err := nc.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + if err := nc.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + _, err = sub.NextMsg(time.Second) + if err != nil { + t.Fatalf("Error getting message: %v", err) + } + + // Force a reconnect + err = nc.ForceReconnect() + if err != nil { + t.Fatalf("Unexpected error on reconnect: %v", err) } + + WaitOnChannel(t, newStatus, nats.RECONNECTING) + WaitOnChannel(t, newStatus, nats.CONNECTED) + + if err := nc.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + _, err = sub.NextMsg(time.Second) + if err != nil { + t.Fatalf("Error getting message: %v", err) + } + +} + +func TestAuthExpiredForceReconnect(t *testing.T) { + ts := runTrustServer() + defer ts.Shutdown() + + _, err := nats.Connect(ts.ClientURL()) + if err == nil { + t.Fatalf("Expecting an error on connect") + } + ukp, err := nkeys.FromSeed(uSeed) + if err != nil { + t.Fatalf("Error creating user key pair: %v", err) + } + upub, err := ukp.PublicKey() + if err != nil { + t.Fatalf("Error getting user public key: %v", err) + } + akp, err := nkeys.FromSeed(aSeed) + if err != nil { + t.Fatalf("Error creating account key pair: %v", err) + } + + jwtCB := func() (string, error) { + claims := jwt.NewUserClaims("test") + claims.Expires = time.Now().Add(time.Second).Unix() + claims.Subject = upub + jwt, err := claims.Encode(akp) + if err != nil { + return "", err + } + return jwt, nil + } + sigCB := func(nonce []byte) ([]byte, error) { + kp, _ := nkeys.FromSeed(uSeed) + sig, _ := kp.Sign(nonce) + return sig, nil + } + + errCh := make(chan error, 1) + nc, err := nats.Connect(ts.ClientURL(), nats.UserJWT(jwtCB, sigCB), nats.ReconnectWait(10*time.Second), + nats.ErrorHandler(func(_ *nats.Conn, _ *nats.Subscription, err error) { + errCh <- err + })) + if err != nil { + t.Fatalf("Expected to connect, got %v", err) + } + defer nc.Close() + statusCh := nc.StatusChanged(nats.RECONNECTING, nats.CONNECTED) + defer close(statusCh) + newStatus := make(chan nats.Status, 10) + // non-blocking channel, so we need to be constantly listening + go func() { + for { + s, ok := <-statusCh + if !ok { + return + } + newStatus <- s + } + }() + time.Sleep(100 * time.Millisecond) select { - case s := <-stasusCh: - if s != nats.CONNECTED { - t.Fatalf("Expected to reconnect, got %v", s) + case err := <-errCh: + if !errors.Is(err, nats.ErrAuthExpired) { + t.Fatalf("Expected auth expired error, got %v", err) } case <-time.After(2 * time.Second): - t.Fatal("Did not get the status change") + t.Fatal("Did not get the auth expired error") } - nc.Close() + if err := nc.ForceReconnect(); err != nil { + t.Fatalf("Unexpected error on reconnect: %v", err) + } + WaitOnChannel(t, newStatus, nats.RECONNECTING) + WaitOnChannel(t, newStatus, nats.CONNECTED) } diff --git a/test/sub_test.go b/test/sub_test.go index 0bf2880c1..f0f83a8d5 100644 --- a/test/sub_test.go +++ b/test/sub_test.go @@ -1617,18 +1617,6 @@ func TestSubscribe_ClosedHandler(t *testing.T) { } func TestSubscriptionEvents(t *testing.T) { - - waitForStatus := func(t *testing.T, ch <-chan nats.SubStatus, expected nats.SubStatus) { - t.Helper() - select { - case s := <-ch: - if s != expected { - t.Fatalf("Expected status: %s; got: %s", expected, s) - } - case <-time.After(5 * time.Second): - t.Fatalf("Timeout waiting for status %q", expected) - } - } t.Run("default events", func(t *testing.T) { s := RunDefaultServer() defer s.Shutdown() @@ -1651,19 +1639,19 @@ func TestSubscriptionEvents(t *testing.T) { status := sub.StatusChanged() // initial status - waitForStatus(t, status, nats.SubscriptionActive) + WaitOnChannel(t, status, nats.SubscriptionActive) for i := 0; i < 11; i++ { nc.Publish("foo", []byte("Hello")) } - waitForStatus(t, status, nats.SubscriptionSlowConsumer) + WaitOnChannel(t, status, nats.SubscriptionSlowConsumer) close(blockChan) sub.Drain() - waitForStatus(t, status, nats.SubscriptionDraining) + WaitOnChannel(t, status, nats.SubscriptionDraining) - waitForStatus(t, status, nats.SubscriptionClosed) + WaitOnChannel(t, status, nats.SubscriptionClosed) }) t.Run("slow consumer event only", func(t *testing.T) { @@ -1691,7 +1679,7 @@ func TestSubscriptionEvents(t *testing.T) { for i := 0; i < 20; i++ { nc.Publish("foo", []byte("Hello")) } - waitForStatus(t, status, nats.SubscriptionSlowConsumer) + WaitOnChannel(t, status, nats.SubscriptionSlowConsumer) close(blockChan) // now try with sync sub @@ -1706,7 +1694,7 @@ func TestSubscriptionEvents(t *testing.T) { for i := 0; i < 20; i++ { nc.Publish("foo", []byte("Hello")) } - waitForStatus(t, status, nats.SubscriptionSlowConsumer) + WaitOnChannel(t, status, nats.SubscriptionSlowConsumer) }) t.Run("do not block channel if it's not read", func(t *testing.T) { @@ -1730,7 +1718,7 @@ func TestSubscriptionEvents(t *testing.T) { } sub.SetPendingLimits(10, 1024) status := sub.StatusChanged() - waitForStatus(t, status, nats.SubscriptionActive) + WaitOnChannel(t, status, nats.SubscriptionActive) // chan length is 10, so make sure we switch state more times for i := 0; i < 20; i++ {