From 4aedf6664c132aa30326c42b600d8b7bdf2e1f18 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Mon, 25 Mar 2024 21:26:32 -0700 Subject: [PATCH] Make it harder to cause a race condition to begin with Instead of returning a struct with mutable struct, we can make it harder to cause a race condition by wrapping all of that state instead of a function closure which cannot be modified from the outside, and by using sync.Once() to ensure that only one request is ever handled. We only need to have the sync.Once() for close()ing the channel to permit a user to permit a user to free the resources manually. It is still possible to cause a panic by closing this channel before receiving a request. --- cli/oauth2.go | 120 ++++++++++++++++++++++++--------------------- cli/oauth2_test.go | 80 +++++++++++++++++++----------- 2 files changed, 114 insertions(+), 86 deletions(-) diff --git a/cli/oauth2.go b/cli/oauth2.go index 12a5cc30..21f349d6 100644 --- a/cli/oauth2.go +++ b/cli/oauth2.go @@ -57,67 +57,68 @@ func DiscoverOAuth2Config(ctx context.Context, domain, clientID string) (*oauth2 return &cfg, nil } -type OAuth2CallbackInfo struct { - Code string - State string - Error string - ErrorDescription string -} - -func ParseCallbackRequest(r *http.Request) (OAuth2CallbackInfo, error) { - info := OAuth2CallbackInfo{ - Error: r.FormValue("error"), - ErrorDescription: r.FormValue("error_description"), - State: r.FormValue("state"), - Code: r.FormValue("code"), - } - - return info, nil +// OAuth2CallbackState encapsulates all of the information from an oauth2 callback. +// +// To retrieve the Code from the struct, you must use the Verify(string) function. +type OAuth2CallbackState struct { + code string + state string + errorMessage string + errorDescription string } -// OAuth2Listener will listen for a single callback request from a web server and return the code if it matched, or an error otherwise. -type OAuth2Listener struct { - once sync.Once - callbackCh chan OAuth2CallbackInfo +// FromRequest parses the given http.Request and populates the OAuth2CallbackState with those values. +func (o *OAuth2CallbackState) FromRequest(r *http.Request) { + o.errorMessage = r.FormValue("error") + o.errorDescription = r.FormValue("error_description") + o.state = r.FormValue("state") + o.code = r.FormValue("code") } -func NewOAuth2Listener() OAuth2Listener { - return OAuth2Listener{ - callbackCh: make(chan OAuth2CallbackInfo, 1), +// Verify safely compares the given state with the state from the OAuth2 callback. +// +// If they match, the code is returned, with a nil value. Otherwise, an empty string and an error is returned. +func (o OAuth2CallbackState) Verify(expectedState string) (string, error) { + if o.errorMessage != "" { + return "", OAuth2Error{Reason: o.errorMessage, Description: o.errorDescription} } -} -func (o *OAuth2Listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // This can sometimes be called multiple times, depending on the browser. - // We will simply ignore any other requests and only serve the first. - o.once.Do(func() { - info, err := ParseCallbackRequest(r) - if err == nil { - // The only errors that might occur would be incorrectly formatted requests, which we will silently drop. - o.callbackCh <- info - } - close(o.callbackCh) - }) + if strings.Compare(o.state, expectedState) != 0 { + return "", OAuth2Error{Reason: "invalid_state", Description: "state mismatch"} + } - // We still want to provide feedback to the end-user. - fmt.Fprintln(w, "You may close this window now.") + return o.code, nil } -func (o *OAuth2Listener) WaitForAuthorizationCode(ctx context.Context, state string) (string, error) { - select { - case info := <-o.callbackCh: - if info.Error != "" { - return "", OAuth2Error{Reason: info.Error, Description: info.ErrorDescription} - } - - if strings.Compare(info.State, state) != 0 { - return "", OAuth2Error{Reason: "invalid_state", Description: "state mismatch"} - } +// OAuth2CallbackHandler returns a http.Handler, channel and function triple. +// +// The http handler will accept exactly one request, which it will assume is an OAuth2 callback, parse it into an OAuth2CallbackState and then provide it to the given channel. Subsequent requests will be silently ignored. +// +// The function may be called to ensure that the channel is closed. The channel is closed when a request is received. In general, it is a good idea to ensure this function is called in a defer() block. +func OAuth2CallbackHandler() (http.Handler, <-chan OAuth2CallbackState, func()) { + ch := make(chan OAuth2CallbackState, 1) + var reqHandle, closeHandle sync.Once + closeFn := func() { + closeHandle.Do(func() { + close(ch) + }) + } - return info.Code, nil - case <-ctx.Done(): - return "", ctx.Err() + fn := func(w http.ResponseWriter, r *http.Request) { + // This can sometimes be called multiple times, depending on the browser. + // We will simply ignore any other requests and only serve the first. + reqHandle.Do(func() { + var state OAuth2CallbackState + state.FromRequest(r) + ch <- state + closeFn() + }) + + // We still want to provide feedback to the end-user. + fmt.Fprintln(w, "You may close this window now.") } + + return http.HandlerFunc(fn), ch, closeFn } type OAuth2Error struct { @@ -219,21 +220,26 @@ func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challe oauth2.SetAuthURLParam("code_challenge", challenge.Challenge), ) - listener := NewOAuth2Listener() + callbackHandler, ch, cancel := OAuth2CallbackHandler() // TODO: This error probably should not be ignored if it is not http.ErrServerClosed - go http.Serve(sock, &listener) + go http.Serve(sock, callbackHandler) + defer cancel() if err := r.OnDisplayURL(url); err != nil { // This is unlikely to ever happen return nil, fmt.Errorf("failed to display link: %w", err) } - code, err := listener.WaitForAuthorizationCode(ctx, state) - if err != nil { - return nil, fmt.Errorf("failed to get authorization code: %w", err) + select { + case info := <-ch: + code, err := info.Verify(state) + if err != nil { + return nil, fmt.Errorf("failed to get authorization code: %w", err) + } + return r.Config.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", challenge.Verifier)) + case <-ctx.Done(): + return nil, ctx.Err() } - - return r.Config.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", challenge.Verifier)) } func ExchangeAccessTokenForWebSSOToken(ctx context.Context, client *http.Client, oauthCfg *oauth2.Config, token *TokenSet, applicationID string) (*oauth2.Token, error) { diff --git a/cli/oauth2_test.go b/cli/oauth2_test.go index 39e90300..260f310d 100644 --- a/cli/oauth2_test.go +++ b/cli/oauth2_test.go @@ -3,16 +3,16 @@ package main import ( "context" "net" + "net/http" "net/http/httptest" "net/url" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func sendRequestToListener(listener *OAuth2Listener, values url.Values) { +func sendOAuth2CallbackRequest(handler http.Handler, values url.Values) { uri := url.URL{ Scheme: "http", Host: "localhost", @@ -22,56 +22,78 @@ func sendRequestToListener(listener *OAuth2Listener, values url.Values) { w := httptest.NewRecorder() req := httptest.NewRequest("GET", uri.String(), nil) - listener.ServeHTTP(w, req) + handler.ServeHTTP(w, req) } -// This test ensures that the success flow of receiving an authorization code from a HTTP handler works correctly. -func Test_OAuth2Listener_WaitForAuthorizationCodeWorksCorrectly(t *testing.T) { - listener := NewOAuth2Listener() +func Test_OAuth2CallbackHandler_YieldsCorrectlyFormattedState(t *testing.T) { + handler, ch, cancel := OAuth2CallbackHandler() + t.Cleanup(func() { + cancel() + }) + expectedState := "state goes here" expectedCode := "code goes here" - go sendRequestToListener(&listener, url.Values{ + go sendOAuth2CallbackRequest(handler, url.Values{ "code": []string{expectedCode}, "state": []string{expectedState}, }) - deadline, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) - code, err := listener.WaitForAuthorizationCode(deadline, expectedState) + callbackState := <-ch + code, err := callbackState.Verify(expectedState) assert.NoError(t, err) - cancel() - assert.Equal(t, expectedCode, code) } -func Test_OAuth2Listener_ZeroValueNeverPanics(t *testing.T) { - var listener OAuth2Listener - deadline, _ := context.WithTimeout(context.Background(), 500*time.Millisecond) - _, err := listener.WaitForAuthorizationCode(deadline, "") - // This will timeout because the OAuth2Listener will forever listen on a nil channel - assert.ErrorIs(t, context.DeadlineExceeded, err) +func Test_OAuth2CallbackState_VerifyWorksCorrectly(t *testing.T) { + t.Run("happy path", func(t *testing.T) { + expectedState := "state goes here" + expectedCode := "code goes here" + callbackState := OAuth2CallbackState{ + code: expectedCode, + state: expectedState, + } + code, err := callbackState.Verify(expectedState) + assert.NoError(t, err) + assert.Equal(t, expectedCode, code) + }) + + t.Run("unhappy path", func(t *testing.T) { + expectedState := "state goes here" + expectedCode := "code goes here" + callbackState := OAuth2CallbackState{ + code: expectedCode, + state: expectedState, + } + _, err := callbackState.Verify("mismatching state") + var oauthErr OAuth2Error + assert.ErrorAs(t, err, &oauthErr) + assert.Equal(t, "invalid_state", oauthErr.Reason) + }) } // Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic prevents an issue where OAuth2Listener would send a request to a closed channel func Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic(t *testing.T) { - listener := NewOAuth2Listener() - expectedState := "state goes here" - expectedCode := "code goes here" + handler, ch, cancel := OAuth2CallbackHandler() + t.Cleanup(func() { + cancel() + }) - go sendRequestToListener(&listener, url.Values{ - "code": []string{expectedCode}, - "state": []string{expectedState}, + go sendOAuth2CallbackRequest(handler, url.Values{ + // We send empty values because we don't care about processing in this test + "code": []string{""}, + "state": []string{""}, }) - go sendRequestToListener(&listener, url.Values{ + // We drain the channel of the first request so the handler completes. + // Without this step, we would get 'stuck' in the sync.Once(). + <-ch + + // We send this request synchronously to ensure that any panics are caught during the test. + sendOAuth2CallbackRequest(handler, url.Values{ "code": []string{"not the expected code and should be discarded"}, "state": []string{"not the expected state and should be discarded"}, }) - - deadlineCtx, _ := context.WithTimeout(context.Background(), 500*time.Millisecond) - code, err := listener.WaitForAuthorizationCode(deadlineCtx, expectedState) - assert.NoError(t, err) - assert.Equal(t, expectedCode, code) } // Test_ListenAnyPort_WorksCorrectly is going to be flaky because processes may open ports outside of our control.