From 6000b9d50ce0bb2ece0cd0504c884e58c881326a Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Mon, 25 Mar 2024 18:43:56 -0700 Subject: [PATCH] Fix a race condition when browsers would send two requests Some browsers would send two requests, with one landing reliably between when the OAuth2Listener was closed (and thus, its channel was closed) and when the http server would be closed. This change solves this problem by closing the channel after the request is received and only ever processing a single request. Other requests will receive responses, but will be silently ignored --- cli/oauth2.go | 63 +++++++++++++++++------------------------ cli/oauth2_test.go | 70 +++++++++++++++++++++++++++------------------- 2 files changed, 67 insertions(+), 66 deletions(-) diff --git a/cli/oauth2.go b/cli/oauth2.go index 657fffa4..12a5cc30 100644 --- a/cli/oauth2.go +++ b/cli/oauth2.go @@ -13,6 +13,7 @@ import ( "net/http" "net/url" "strings" + "sync" "github.com/RobotsAndPencils/go-saml" "github.com/coreos/go-oidc" @@ -63,11 +64,6 @@ type OAuth2CallbackInfo struct { ErrorDescription string } -type OAuth2Listener struct { - Socket net.Listener - callbackCh chan OAuth2CallbackInfo -} - func ParseCallbackRequest(r *http.Request) (OAuth2CallbackInfo, error) { info := OAuth2CallbackInfo{ Error: r.FormValue("error"), @@ -79,44 +75,35 @@ func ParseCallbackRequest(r *http.Request) (OAuth2CallbackInfo, error) { return info, nil } -func NewOAuth2Listener(socket net.Listener) OAuth2Listener { - return OAuth2Listener{ - Socket: socket, - callbackCh: make(chan OAuth2CallbackInfo), - } +// OAuth2Listener will listen for a single callback request from a web server and return the code if it matched, or an error otherwise. +type OAuth2Listener struct { + once sync.Once + callbackCh chan OAuth2CallbackInfo } -func (o OAuth2Listener) Close() error { - if o.callbackCh != nil { - close(o.callbackCh) - } - if o.Socket != nil { - return o.Socket.Close() +func NewOAuth2Listener() OAuth2Listener { + return OAuth2Listener{ + callbackCh: make(chan OAuth2CallbackInfo, 1), } - return nil } -func (o OAuth2Listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { - info, err := ParseCallbackRequest(r) - if err == nil { - // The only errors that might occur would be incorrectly formatted requests, which we will silently drop. - o.callbackCh <- info - } +func (o *OAuth2Listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // This can sometimes be called multiple times, depending on the browser. + // We will simply ignore any other requests and only serve the first. + o.once.Do(func() { + info, err := ParseCallbackRequest(r) + if err == nil { + // The only errors that might occur would be incorrectly formatted requests, which we will silently drop. + o.callbackCh <- info + } + close(o.callbackCh) + }) - // This is displayed to the end user in their browser. + // We still want to provide feedback to the end-user. fmt.Fprintln(w, "You may close this window now.") } -func (o OAuth2Listener) Listen() error { - err := http.Serve(o.Socket, o) - if errors.Is(err, http.ErrServerClosed) { - return nil - } - - return err -} - -func (o OAuth2Listener) WaitForAuthorizationCode(ctx context.Context, state string) (string, error) { +func (o *OAuth2Listener) WaitForAuthorizationCode(ctx context.Context, state string) (string, error) { select { case info := <-o.callbackCh: if info.Error != "" { @@ -218,6 +205,7 @@ func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challe if err != nil { return nil, err } + defer sock.Close() _, port, err := net.SplitHostPort(sock.Addr().String()) if err != nil { @@ -231,10 +219,9 @@ func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challe oauth2.SetAuthURLParam("code_challenge", challenge.Challenge), ) - listener := NewOAuth2Listener(sock) - defer listener.Close() - // This error can be ignored. - go listener.Listen() + listener := NewOAuth2Listener() + // TODO: This error probably should not be ignored if it is not http.ErrServerClosed + go http.Serve(sock, &listener) if err := r.OnDisplayURL(url); err != nil { // This is unlikely to ever happen diff --git a/cli/oauth2_test.go b/cli/oauth2_test.go index c2e97af5..39e90300 100644 --- a/cli/oauth2_test.go +++ b/cli/oauth2_test.go @@ -12,36 +12,29 @@ import ( "github.com/stretchr/testify/require" ) -// This test ensures that the success flow of receiving an authorization code from a HTTP handler works correctly. -func Test_OAuth2Listener_WaitForAuthorizationCodeWorksCorrectly(t *testing.T) { - // TODO: Initializing private fields means we should probably refactor OAuth2Listener's constructor to not require net.Listener, - // but instead to have net.Listener be a thing in Listen(); - // - // Users should be able to use ServeHTTP. - listener := OAuth2Listener{ - callbackCh: make(chan OAuth2CallbackInfo), +func sendRequestToListener(listener *OAuth2Listener, values url.Values) { + uri := url.URL{ + Scheme: "http", + Host: "localhost", + Path: "/oauth2/callback", + RawQuery: values.Encode(), } + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", uri.String(), nil) + listener.ServeHTTP(w, req) +} + +// This test ensures that the success flow of receiving an authorization code from a HTTP handler works correctly. +func Test_OAuth2Listener_WaitForAuthorizationCodeWorksCorrectly(t *testing.T) { + listener := NewOAuth2Listener() expectedState := "state goes here" expectedCode := "code goes here" - go func() { - w := httptest.NewRecorder() - values := url.Values{ - "code": []string{expectedCode}, - "state": []string{expectedState}, - } - - uri := url.URL{ - Scheme: "http", - Host: "localhost", - Path: "/oauth2/callback", - RawQuery: values.Encode(), - } - - req := httptest.NewRequest("GET", uri.String(), nil) - listener.ServeHTTP(w, req) - }() + go sendRequestToListener(&listener, url.Values{ + "code": []string{expectedCode}, + "state": []string{expectedState}, + }) deadline, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) code, err := listener.WaitForAuthorizationCode(deadline, expectedState) @@ -49,18 +42,39 @@ func Test_OAuth2Listener_WaitForAuthorizationCodeWorksCorrectly(t *testing.T) { cancel() assert.Equal(t, expectedCode, code) - assert.NoError(t, listener.Close()) } func Test_OAuth2Listener_ZeroValueNeverPanics(t *testing.T) { var listener OAuth2Listener deadline, _ := context.WithTimeout(context.Background(), 500*time.Millisecond) _, err := listener.WaitForAuthorizationCode(deadline, "") + // This will timeout because the OAuth2Listener will forever listen on a nil channel assert.ErrorIs(t, context.DeadlineExceeded, err) - assert.NoError(t, listener.Close()) } -// This test is going to be flaky because processes may open ports outside of our control. +// Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic prevents an issue where OAuth2Listener would send a request to a closed channel +func Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic(t *testing.T) { + listener := NewOAuth2Listener() + expectedState := "state goes here" + expectedCode := "code goes here" + + go sendRequestToListener(&listener, url.Values{ + "code": []string{expectedCode}, + "state": []string{expectedState}, + }) + + go sendRequestToListener(&listener, url.Values{ + "code": []string{"not the expected code and should be discarded"}, + "state": []string{"not the expected state and should be discarded"}, + }) + + deadlineCtx, _ := context.WithTimeout(context.Background(), 500*time.Millisecond) + code, err := listener.WaitForAuthorizationCode(deadlineCtx, expectedState) + assert.NoError(t, err) + assert.Equal(t, expectedCode, code) +} + +// Test_ListenAnyPort_WorksCorrectly is going to be flaky because processes may open ports outside of our control. func Test_ListenAnyPort_WorksCorrectly(t *testing.T) { ports := []string{"58080", "58081", "58082", "58083"} socket, err := net.Listen("tcp4", net.JoinHostPort("127.0.0.1", ports[0]))