From 13bb1073308dcd853b083ee1bf7c7d06713a3dad Mon Sep 17 00:00:00 2001 From: Eric Chlebek Date: Tue, 9 Apr 2024 14:53:52 -0700 Subject: [PATCH 1/5] Add additional test for redirect behaviour Verifies that redirect chains work --- client/wsclient_test.go | 53 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/client/wsclient_test.go b/client/wsclient_test.go index cc9fd87d..14506b30 100644 --- a/client/wsclient_test.go +++ b/client/wsclient_test.go @@ -392,6 +392,59 @@ func TestRedirectWS(t *testing.T) { } } +func TestRedirectWSFollowChain(t *testing.T) { + // test that redirect following is recursive + redirectee := internal.StartMockServer(t) + middle := redirectServer("http://"+redirectee.Endpoint, 302) + middleURL, err := url.Parse(middle.URL) + if err != nil { + // unlikely + t.Fatal(err) + } + redirector := redirectServer("http://"+middleURL.Host, 302) + + var conn atomic.Value + redirectee.OnWSConnect = func(c *websocket.Conn) { + conn.Store(c) + } + + // Start an OpAMP/WebSocket client. + var connected int64 + var connectErr atomic.Value + settings := types.StartSettings{ + Callbacks: types.CallbacksStruct{ + OnConnectFunc: func(ctx context.Context) { + atomic.StoreInt64(&connected, 1) + }, + OnConnectFailedFunc: func(ctx context.Context, err error) { + if err != websocket.ErrBadHandshake { + connectErr.Store(err) + } + }, + }, + } + reURL, err := url.Parse(redirector.URL) + if err != nil { + // unlikely + t.Fatal(err) + } + reURL.Scheme = "ws" + settings.OpAMPServerURL = reURL.String() + client := NewWebSocket(nil) + startClient(t, settings, client) + + // Wait for connection to be established. + eventually(t, func() bool { + return conn.Load() != nil || connectErr.Load() != nil || client.lastInternalErr.Load() != nil + }) + + assert.True(t, connectErr.Load() == nil) + + // Stop the client. + err = client.Stop(context.Background()) + assert.NoError(t, err) +} + func TestHandlesStopBeforeStart(t *testing.T) { client := NewWebSocket(nil) require.Error(t, client.Stop(context.Background())) From 591ff375c596feb34f2766cc35fc625cbbe72dec Mon Sep 17 00:00:00 2001 From: Eric Chlebek Date: Tue, 9 Apr 2024 14:54:35 -0700 Subject: [PATCH 2/5] Add CheckRedirect callback This commit adds a CheckRedirect callback that opamp-go will call before following a redirect from the server it's trying to connect to. Like in net/http, CheckRedirect can be used to observe the request chain that the client is taking while attempting to make a connection. The user can optionally terminate redirect following by returning an error from CheckRedirect. Unlike in net/http, the via parameter for CheckRedirect is a slice of responses. Since the user would have no other way to access these in the context of opamp-go, CheckRedirect makes them available so that users can know exactly what status codes and headers are set in the response. Another small improvement is that the error callback is no longer called when redirecting. This should help to prevent undue error logging by opamp-go consumers. Since the CheckRedirect callback is now available, it also doesn't represent any loss in functionality to opamp-go consumers. --- client/types/callbacks.go | 6 +++ client/wsclient.go | 80 ++++++++++++++++++++++++++++++++------- client/wsclient_test.go | 62 ++++++++++++++++++++++++++++-- go.mod | 3 +- go.sum | 2 + 5 files changed, 135 insertions(+), 18 deletions(-) diff --git a/client/types/callbacks.go b/client/types/callbacks.go index 02cef82d..9b08dd07 100644 --- a/client/types/callbacks.go +++ b/client/types/callbacks.go @@ -2,6 +2,7 @@ package types import ( "context" + "net/http" "github.com/open-telemetry/opamp-go/protobufs" ) @@ -110,6 +111,11 @@ type Callbacks struct { // OnCommand is called when the Server requests that the connected Agent perform a command. OnCommand func(ctx context.Context, command *protobufs.ServerToAgentCommand) error + + // CheckRedirect is called before following a redirect. It is similar in + // nature to the CheckRedirect in net/http's Client. If the value is nil, + // then the http client's CheckRedirect will not be altered. + CheckRedirect func(req *http.Request, via []*http.Response) error } func (c *Callbacks) SetDefaults() { diff --git a/client/wsclient.go b/client/wsclient.go index 6219a28f..22b2228b 100644 --- a/client/wsclient.go +++ b/client/wsclient.go @@ -48,6 +48,12 @@ type wsClient struct { // Network connection timeout used for the WebSocket closing handshake. // This field is currently only modified during testing. connShutdownTimeout time.Duration + + // responseChain is used for the "via" argument in CheckRedirect. + // It is appended to with every redirect followed, and zeroed on a succesful + // connection. responseChain should only be referred to by the goroutine that + // runs tryConnectOnce and its synchronous callees. + responseChain []*http.Response } // NewWebSocket creates a new OpAMP Client that uses WebSocket transport. @@ -151,11 +157,69 @@ func (c *wsClient) SendCustomMessage(message *protobufs.CustomMessage) (messageS return c.common.SendCustomMessage(message) } +// handleRedirect checks a failed websocket upgrade response for a 3xx response +// and a Location header. If found, it sets the URL to the location found in the +// header so that it is tried on the next retry, instead of the current URL. +func (c *wsClient) handleRedirect(ctx context.Context, resp *http.Response) error { + // append to the responseChain so that subsequent redirects will have access + c.responseChain = append(c.responseChain, resp) + + // very liberal handling of 3xx that largely ignores HTTP semantics + redirect, err := resp.Location() + if err != nil { + c.common.Logger.Errorf(ctx, "%d redirect, but no valid location: %s", resp.StatusCode, err) + return err + } + + // It's slightly tricky to make CheckRedirect work. The WS HTTP request is + // formed within the websocket library. To work around that, copy the + // previous request, available in the response, and set the URL to the new + // location. It should then result in the same URL that the websocket + // library will form. + nextRequest := resp.Request.Clone(ctx) + nextRequest.URL = redirect + + // if CheckRedirect results in an error, it gets returned, terminating + // redirection. As with stdlib, the error is wrapped in url.Error. + if c.common.Callbacks.CheckRedirect != nil { + if err := c.common.Callbacks.CheckRedirect(nextRequest, c.responseChain); err != nil { + return &url.Error{ + Op: "Get", + URL: nextRequest.URL.String(), + Err: err, + } + } + } + + // rewrite the scheme for the sake of tolerance + if redirect.Scheme == "http" { + redirect.Scheme = "ws" + } else if redirect.Scheme == "https" { + redirect.Scheme = "wss" + } + c.common.Logger.Debugf(ctx, "%d redirect to %s", resp.StatusCode, redirect) + + // Set the URL to the redirect, so that it connects to it on the + // next cycle. + c.url = redirect + + return nil +} + // Try to connect once. Returns an error if connection fails and optional retryAfter // duration to indicate to the caller to retry after the specified time as instructed // by the Server. func (c *wsClient) tryConnectOnce(ctx context.Context) (retryAfter sharedinternal.OptionalDuration, err error) { var resp *http.Response + var redirecting bool + defer func() { + if err != nil && !redirecting { + c.responseChain = nil + if c.common.Callbacks != nil && !c.common.IsStopping() { + c.common.Callbacks.OnConnectFailed(ctx, err) + } + } + }() conn, resp, err := c.dialer.DialContext(ctx, c.url.String(), c.getHeader()) if err != nil { if !c.common.IsStopping() { @@ -164,22 +228,10 @@ func (c *wsClient) tryConnectOnce(ctx context.Context) (retryAfter sharedinterna if resp != nil { duration := sharedinternal.ExtractRetryAfterHeader(resp) if resp.StatusCode >= 300 && resp.StatusCode < 400 { - // very liberal handling of 3xx that largely ignores HTTP semantics - redirect, err := resp.Location() - if err != nil { - c.common.Logger.Errorf(ctx, "%d redirect, but no valid location: %s", resp.StatusCode, err) + redirecting = true + if err := c.handleRedirect(ctx, resp); err != nil { return duration, err } - // rewrite the scheme for the sake of tolerance - if redirect.Scheme == "http" { - redirect.Scheme = "ws" - } else if redirect.Scheme == "https" { - redirect.Scheme = "wss" - } - c.common.Logger.Debugf(ctx, "%d redirect to %s", resp.StatusCode, redirect) - // Set the URL to the redirect, so that it connects to it on the - // next cycle. - c.url = redirect } else { c.common.Logger.Errorf(ctx, "Server responded with status=%v", resp.Status) } diff --git a/client/wsclient_test.go b/client/wsclient_test.go index 14506b30..cb964205 100644 --- a/client/wsclient_test.go +++ b/client/wsclient_test.go @@ -2,6 +2,7 @@ package client import ( "context" + "errors" "fmt" "net/http" "net/http/httptest" @@ -13,6 +14,7 @@ import ( "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" @@ -322,12 +324,44 @@ func errServer() *httptest.Server { })) } +type checkRedirectMock struct { + mock.Mock + t testing.TB + viaLen int +} + +func (c *checkRedirectMock) CheckRedirect(req *http.Request, via []*http.Response) error { + if req == nil { + c.t.Error("nil request in CheckRedirect") + } + if len(via) > c.viaLen { + c.t.Error("via should be shorter than viaLen") + } + location, err := via[len(via)-1].Location() + if err != nil { + c.t.Error(err) + } + // the URL of the request should match the location header of the last response + assert.Equal(c.t, req.URL, location, "request URL should equal the location in the response") + return c.Called(req, via).Error(0) +} + +func mockRedirect(t testing.TB, viaLen int, err error) *checkRedirectMock { + m := &checkRedirectMock{ + t: t, + viaLen: viaLen, + } + m.On("CheckRedirect", mock.Anything, mock.Anything).Return(err) + return m +} + func TestRedirectWS(t *testing.T) { redirectee := internal.StartMockServer(t) tests := []struct { - Name string - Redirector *httptest.Server - ExpError bool + Name string + Redirector *httptest.Server + ExpError bool + MockRedirect *checkRedirectMock }{ { Name: "redirect ws scheme", @@ -342,6 +376,17 @@ func TestRedirectWS(t *testing.T) { Redirector: errServer(), ExpError: true, }, + { + Name: "check redirect", + Redirector: redirectServer("ws://"+redirectee.Endpoint, 302), + MockRedirect: mockRedirect(t, 1, nil), + }, + { + Name: "check redirect returns error", + Redirector: redirectServer("ws://"+redirectee.Endpoint, 302), + MockRedirect: mockRedirect(t, 1, errors.New("hello")), + ExpError: true, + }, } for _, test := range tests { @@ -357,6 +402,8 @@ func TestRedirectWS(t *testing.T) { settings := types.StartSettings{ Callbacks: types.Callbacks{ OnConnect: func(ctx context.Context) { + Callbacks: &types.Callbacks{ + OnConnectFunc: func(ctx context.Context) { atomic.StoreInt64(&connected, 1) }, OnConnectFailed: func(ctx context.Context, err error) { @@ -366,6 +413,9 @@ func TestRedirectWS(t *testing.T) { }, }, } + if test.MockRedirect != nil { + settings.Callbacks.(*types.CallbacksStruct).CheckRedirectFunc = test.MockRedirect.CheckRedirect + } reURL, err := url.Parse(test.Redirector.URL) assert.NoError(t, err) reURL.Scheme = "ws" @@ -388,6 +438,10 @@ func TestRedirectWS(t *testing.T) { // Stop the client. err = client.Stop(context.Background()) assert.NoError(t, err) + + if test.MockRedirect != nil { + test.MockRedirect.AssertCalled(t, "CheckRedirect", mock.Anything, mock.Anything) + } }) } } @@ -411,6 +465,7 @@ func TestRedirectWSFollowChain(t *testing.T) { // Start an OpAMP/WebSocket client. var connected int64 var connectErr atomic.Value + mr := mockRedirect(t, 2, nil) settings := types.StartSettings{ Callbacks: types.CallbacksStruct{ OnConnectFunc: func(ctx context.Context) { @@ -421,6 +476,7 @@ func TestRedirectWSFollowChain(t *testing.T) { connectErr.Store(err) } }, + CheckRedirectFunc: mr.CheckRedirect, }, } reURL, err := url.Parse(redirector.URL) diff --git a/go.mod b/go.mod index 2742c8a9..4b9746d1 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.22 require ( github.com/cenkalti/backoff/v4 v4.3.0 + github.com/google/go-cmp v0.5.6 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/stretchr/testify v1.10.0 @@ -12,8 +13,8 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/google/go-cmp v0.5.6 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 3390120c..ea122d10 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= From da944f3acd32058ce502f5f3187eedb01606ab95 Mon Sep 17 00:00:00 2001 From: Eric Chlebek Date: Thu, 11 Apr 2024 15:36:35 -0700 Subject: [PATCH 3/5] Clean up lint --- client/wsclient_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/client/wsclient_test.go b/client/wsclient_test.go index cb964205..9dbbf804 100644 --- a/client/wsclient_test.go +++ b/client/wsclient_test.go @@ -333,6 +333,7 @@ type checkRedirectMock struct { func (c *checkRedirectMock) CheckRedirect(req *http.Request, via []*http.Response) error { if req == nil { c.t.Error("nil request in CheckRedirect") + return errors.New("nil request in CheckRedirect") } if len(via) > c.viaLen { c.t.Error("via should be shorter than viaLen") From 22851bbf0f7406115ad96e0d2be55490f5d5b647 Mon Sep 17 00:00:00 2001 From: Eric Chlebek Date: Thu, 28 Nov 2024 18:46:53 -0800 Subject: [PATCH 4/5] Implement CheckRedirect for HTTP This commit adds support for a CheckRedirect callback to the HTTP opamp client. It also unifies the API for CheckRedirect between WS and HTTP, so that the same callback can be used in either circumstance. Signed-off-by: Eric Chlebek --- client/httpclient_test.go | 88 +++++++++++++++++++++++++++++++++++ client/internal/httpsender.go | 8 ++++ client/types/callbacks.go | 16 +++++-- client/wsclient.go | 12 ++++- client/wsclient_test.go | 39 +++++++++------- 5 files changed, 141 insertions(+), 22 deletions(-) diff --git a/client/httpclient_test.go b/client/httpclient_test.go index a3845c45..4807d7a6 100644 --- a/client/httpclient_test.go +++ b/client/httpclient_test.go @@ -3,13 +3,17 @@ package client import ( "compress/gzip" "context" + "errors" "io" "net/http" + "net/http/httptest" + "net/url" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "google.golang.org/protobuf/proto" "github.com/open-telemetry/opamp-go/client/internal" @@ -223,3 +227,87 @@ func TestHTTPClientStartWithZeroHeartbeatInterval(t *testing.T) { // Shutdown the Server. srv.Close() } + +func mockRedirectHTTP(t testing.TB, viaLen int, err error) *checkRedirectMock { + m := &checkRedirectMock{ + t: t, + viaLen: viaLen, + http: true, + } + m.On("CheckRedirect", mock.Anything, mock.Anything, mock.Anything).Return(err) + return m +} + +func TestRedirectHTTP(t *testing.T) { + redirectee := internal.StartMockServer(t) + tests := []struct { + Name string + Redirector *httptest.Server + ExpError bool + MockRedirect *checkRedirectMock + }{ + { + Name: "simple redirect", + Redirector: redirectServer("http://"+redirectee.Endpoint, 302), + }, + { + Name: "check redirect", + Redirector: redirectServer("http://"+redirectee.Endpoint, 302), + MockRedirect: mockRedirectHTTP(t, 1, nil), + }, + { + Name: "check redirect returns error", + Redirector: redirectServer("http://"+redirectee.Endpoint, 302), + MockRedirect: mockRedirect(t, 1, errors.New("hello")), + ExpError: true, + }, + } + + for _, test := range tests { + t.Run(test.Name, func(t *testing.T) { + var connectErr atomic.Value + var connected atomic.Value + + settings := &types.StartSettings{ + Callbacks: types.Callbacks{ + OnConnect: func(ctx context.Context) { + connected.Store(1) + }, + OnConnectFailed: func(ctx context.Context, err error) { + connectErr.Store(err) + }, + }, + } + if test.MockRedirect != nil { + settings.Callbacks = types.Callbacks{ + OnConnect: func(ctx context.Context) { + connected.Store(1) + }, + OnConnectFailed: func(ctx context.Context, err error) { + connectErr.Store(err) + }, + CheckRedirect: test.MockRedirect.CheckRedirect, + } + } + reURL, _ := url.Parse(test.Redirector.URL) // err can't be non-nil + settings.OpAMPServerURL = reURL.String() + client := NewHTTP(nil) + prepareClient(t, settings, client) + + err := client.Start(context.Background(), *settings) + if err != nil { + t.Fatal(err) + } + defer client.Stop(context.Background()) + // Wait for connection to be established. + eventually(t, func() bool { + return connected.Load() != nil || connectErr.Load() != nil + }) + if test.ExpError && connectErr.Load() == nil { + t.Error("expected non-nil error") + } else if err := connectErr.Load(); !test.ExpError && err != nil { + t.Fatal(err) + } + }) + } +} diff --git a/client/internal/httpsender.go b/client/internal/httpsender.go index 502bf7e4..a97e1311 100644 --- a/client/internal/httpsender.go +++ b/client/internal/httpsender.go @@ -98,6 +98,14 @@ func (h *HTTPSender) Run( h.callbacks = callbacks h.receiveProcessor = newReceivedProcessor(h.logger, callbacks, h, clientSyncedState, packagesStateProvider, capabilities, packageSyncMutex) + // we need to detect if the redirect was ever set, if not, we want default behaviour + if callbacks.CheckRedirect != nil { + h.client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + // viaResp only non-nil for ws client + return callbacks.CheckRedirect(req, via, nil) + } + } + for { pollingTimer := time.NewTimer(time.Millisecond * time.Duration(atomic.LoadInt64(&h.pollingIntervalMs))) select { diff --git a/client/types/callbacks.go b/client/types/callbacks.go index 9b08dd07..801f91cd 100644 --- a/client/types/callbacks.go +++ b/client/types/callbacks.go @@ -112,10 +112,18 @@ type Callbacks struct { // OnCommand is called when the Server requests that the connected Agent perform a command. OnCommand func(ctx context.Context, command *protobufs.ServerToAgentCommand) error - // CheckRedirect is called before following a redirect. It is similar in - // nature to the CheckRedirect in net/http's Client. If the value is nil, - // then the http client's CheckRedirect will not be altered. - CheckRedirect func(req *http.Request, via []*http.Response) error + // CheckRedirect is called before following a redirect, allowing the client + // the opportunity to observe the redirect chain, and optionally terminate + // following redirects early. + // + // CheckRedirect is intended to be similar, although not exactly equivalent, + // to net/http.Client's CheckRedirect feature. Unlike in net/http, the via + // parameter is a slice of HTTP responses, instead of requests. This gives + // an opportunity to users to know what the exact response headers and + // status were. The request itself can be obtained from the response. + // + // The responses in the via parameter are passed with their bodies closed. + CheckRedirect func(req *http.Request, viaReq []*http.Request, via []*http.Response) error } func (c *Callbacks) SetDefaults() { diff --git a/client/wsclient.go b/client/wsclient.go index 22b2228b..f19d8ab4 100644 --- a/client/wsclient.go +++ b/client/wsclient.go @@ -157,6 +157,14 @@ func (c *wsClient) SendCustomMessage(message *protobufs.CustomMessage) (messageS return c.common.SendCustomMessage(message) } +func viaReq(resps []*http.Response) []*http.Request { + reqs := make([]*http.Request, 0, len(resps)) + for _, resp := range resps { + reqs = append(reqs, resp.Request) + } + return reqs +} + // handleRedirect checks a failed websocket upgrade response for a 3xx response // and a Location header. If found, it sets the URL to the location found in the // header so that it is tried on the next retry, instead of the current URL. @@ -182,7 +190,7 @@ func (c *wsClient) handleRedirect(ctx context.Context, resp *http.Response) erro // if CheckRedirect results in an error, it gets returned, terminating // redirection. As with stdlib, the error is wrapped in url.Error. if c.common.Callbacks.CheckRedirect != nil { - if err := c.common.Callbacks.CheckRedirect(nextRequest, c.responseChain); err != nil { + if err := c.common.Callbacks.CheckRedirect(nextRequest, viaReq(c.responseChain), c.responseChain); err != nil { return &url.Error{ Op: "Get", URL: nextRequest.URL.String(), @@ -215,7 +223,7 @@ func (c *wsClient) tryConnectOnce(ctx context.Context) (retryAfter sharedinterna defer func() { if err != nil && !redirecting { c.responseChain = nil - if c.common.Callbacks != nil && !c.common.IsStopping() { + if !c.common.IsStopping() { c.common.Callbacks.OnConnectFailed(ctx, err) } } diff --git a/client/wsclient_test.go b/client/wsclient_test.go index 9dbbf804..436ceb55 100644 --- a/client/wsclient_test.go +++ b/client/wsclient_test.go @@ -328,22 +328,31 @@ type checkRedirectMock struct { mock.Mock t testing.TB viaLen int + http bool } -func (c *checkRedirectMock) CheckRedirect(req *http.Request, via []*http.Response) error { +func (c *checkRedirectMock) CheckRedirect(req *http.Request, viaReq []*http.Request, via []*http.Response) error { if req == nil { c.t.Error("nil request in CheckRedirect") return errors.New("nil request in CheckRedirect") } - if len(via) > c.viaLen { - c.t.Error("via should be shorter than viaLen") + if len(viaReq) > c.viaLen { + c.t.Error("viaReq should be shorter than viaLen") } - location, err := via[len(via)-1].Location() - if err != nil { - c.t.Error(err) + if !c.http { + // websocket transport + if len(via) > c.viaLen { + c.t.Error("via should be shorter than viaLen") + } + } + if !c.http && len(via) > 0 { + location, err := via[len(via)-1].Location() + if err != nil { + c.t.Error(err) + } + // the URL of the request should match the location header of the last response + assert.Equal(c.t, req.URL, location, "request URL should equal the location in the response") } - // the URL of the request should match the location header of the last response - assert.Equal(c.t, req.URL, location, "request URL should equal the location in the response") return c.Called(req, via).Error(0) } @@ -352,7 +361,7 @@ func mockRedirect(t testing.TB, viaLen int, err error) *checkRedirectMock { t: t, viaLen: viaLen, } - m.On("CheckRedirect", mock.Anything, mock.Anything).Return(err) + m.On("CheckRedirect", mock.Anything, mock.Anything, mock.Anything).Return(err) return m } @@ -403,8 +412,6 @@ func TestRedirectWS(t *testing.T) { settings := types.StartSettings{ Callbacks: types.Callbacks{ OnConnect: func(ctx context.Context) { - Callbacks: &types.Callbacks{ - OnConnectFunc: func(ctx context.Context) { atomic.StoreInt64(&connected, 1) }, OnConnectFailed: func(ctx context.Context, err error) { @@ -415,7 +422,7 @@ func TestRedirectWS(t *testing.T) { }, } if test.MockRedirect != nil { - settings.Callbacks.(*types.CallbacksStruct).CheckRedirectFunc = test.MockRedirect.CheckRedirect + settings.Callbacks.CheckRedirect = test.MockRedirect.CheckRedirect } reURL, err := url.Parse(test.Redirector.URL) assert.NoError(t, err) @@ -468,16 +475,16 @@ func TestRedirectWSFollowChain(t *testing.T) { var connectErr atomic.Value mr := mockRedirect(t, 2, nil) settings := types.StartSettings{ - Callbacks: types.CallbacksStruct{ - OnConnectFunc: func(ctx context.Context) { + Callbacks: types.Callbacks{ + OnConnect: func(ctx context.Context) { atomic.StoreInt64(&connected, 1) }, - OnConnectFailedFunc: func(ctx context.Context, err error) { + OnConnectFailed: func(ctx context.Context, err error) { if err != websocket.ErrBadHandshake { connectErr.Store(err) } }, - CheckRedirectFunc: mr.CheckRedirect, + CheckRedirect: mr.CheckRedirect, }, } reURL, err := url.Parse(redirector.URL) From 42050ae3b388c4873b17bcb77ef2b35c32414038 Mon Sep 17 00:00:00 2001 From: Eric Chlebek Date: Thu, 2 Jan 2025 12:19:40 -0800 Subject: [PATCH 5/5] Use mockRedirectHTTP in http tests Signed-off-by: Eric Chlebek --- client/httpclient_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/httpclient_test.go b/client/httpclient_test.go index 4807d7a6..fc670411 100644 --- a/client/httpclient_test.go +++ b/client/httpclient_test.go @@ -258,7 +258,7 @@ func TestRedirectHTTP(t *testing.T) { { Name: "check redirect returns error", Redirector: redirectServer("http://"+redirectee.Endpoint, 302), - MockRedirect: mockRedirect(t, 1, errors.New("hello")), + MockRedirect: mockRedirectHTTP(t, 1, errors.New("hello")), ExpError: true, }, }