Skip to content

Commit

Permalink
Follow HTTP redirects after failed WS dials (#251)
Browse files Browse the repository at this point in the history
This commit allows the opamp client to follow redirects after websocket handshake failures. Redirect following is not implemented by gorilla/websocket, but can be handled by inspecting the returned response object for 3xx status and Location header.

Closes #250
  • Loading branch information
echlebek authored Feb 23, 2024
1 parent ce8a8dd commit 8b26910
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 1 deletion.
27 changes: 26 additions & 1 deletion client/wsclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"
"net/url"
"sync"
"sync/atomic"
"time"

"github.com/cenkalti/backoff/v4"
Expand Down Expand Up @@ -35,6 +36,10 @@ type wsClient struct {

// The sender is responsible for sending portion of the OpAMP protocol.
sender *internal.WSSender

// last non-nil internal error that was encountered in the conn retry loop,
// currently used only for testing.
lastInternalErr atomic.Pointer[error]
}

// NewWebSocket creates a new OpAMP Client that uses WebSocket transport.
Expand Down Expand Up @@ -131,8 +136,27 @@ func (c *wsClient) tryConnectOnce(ctx context.Context) (err error, retryAfter sh
c.common.Callbacks.OnConnectFailed(ctx, err)
}
if resp != nil {
c.common.Logger.Errorf(ctx, "Server responded with status=%v", resp.Status)
duration := sharedinternal.ExtractRetryAfterHeader(resp)
if resp.StatusCode >= 300 && resp.StatusCode < 400 {
// very liberal handling of 3xx that largely ignores HTTP semantics
redirect, err := resp.Location()
if err != nil {
c.common.Logger.Errorf(ctx, "%d redirect, but no valid location: %s", resp.StatusCode, err)
return err, duration
}
// rewrite the scheme for the sake of tolerance
if redirect.Scheme == "http" {
redirect.Scheme = "ws"
} else if redirect.Scheme == "https" {
redirect.Scheme = "wss"
}
c.common.Logger.Debugf(ctx, "%d redirect to %s", resp.StatusCode, redirect)
// Set the URL to the redirect, so that it connects to it on the
// next cycle.
c.url = redirect
} else {
c.common.Logger.Errorf(ctx, "Server responded with status=%v", resp.Status)
}
return err, duration
}
return err, sharedinternal.OptionalDuration{Defined: false}
Expand Down Expand Up @@ -167,6 +191,7 @@ func (c *wsClient) ensureConnected(ctx context.Context) error {
case <-timer.C:
{
if err, retryAfter := c.tryConnectOnce(ctx); err != nil {
c.lastInternalErr.Store(&err)
if errors.Is(err, context.Canceled) {
c.common.Logger.Debugf(ctx, "Client is stopped, will not try anymore.")
return err
Expand Down
85 changes: 85 additions & 0 deletions client/wsclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package client
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -177,3 +180,85 @@ func TestVerifyWSCompress(t *testing.T) {
})
}
}

func redirectServer(to string, status int) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
http.Redirect(w, req, to, http.StatusSeeOther)
}))
}

func errServer() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(302)
}))
}

func TestRedirectWS(t *testing.T) {
redirectee := internal.StartMockServer(t)
tests := []struct {
Name string
Redirector *httptest.Server
ExpError bool
}{
{
Name: "redirect ws scheme",
Redirector: redirectServer("ws://"+redirectee.Endpoint, 302),
},
{
Name: "redirect http scheme",
Redirector: redirectServer("http://"+redirectee.Endpoint, 302),
},
{
Name: "missing location header",
Redirector: errServer(),
ExpError: true,
},
}

for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
var conn atomic.Value
redirectee.OnWSConnect = func(c *websocket.Conn) {
conn.Store(c)
}

// Start an OpAMP/WebSocket client.
var connected int64
var connectErr atomic.Value
settings := types.StartSettings{
Callbacks: types.CallbacksStruct{
OnConnectFunc: func(ctx context.Context) {
atomic.StoreInt64(&connected, 1)
},
OnConnectFailedFunc: func(ctx context.Context, err error) {
if err != websocket.ErrBadHandshake {
connectErr.Store(err)
}
},
},
}
reURL, err := url.Parse(test.Redirector.URL)
assert.NoError(t, err)
reURL.Scheme = "ws"
settings.OpAMPServerURL = reURL.String()
client := NewWebSocket(nil)
startClient(t, settings, client)

// Wait for connection to be established.
eventually(t, func() bool {
return conn.Load() != nil || connectErr.Load() != nil || client.lastInternalErr.Load() != nil
})
if test.ExpError {
if connectErr.Load() == nil && client.lastInternalErr.Load() == nil {
t.Error("expected non-nil error")
}
} else {
assert.True(t, connectErr.Load() == nil)
}

// Stop the client.
err = client.Stop(context.Background())
assert.NoError(t, err)
})
}
}

0 comments on commit 8b26910

Please sign in to comment.