Skip to content

Commit

Permalink
Make it harder to cause a race condition to begin with
Browse files Browse the repository at this point in the history
Instead of returning a struct with mutable struct, we can make it harder
to cause a race condition by wrapping all of that state instead of a
function closure which cannot be modified from the outside, and by using
sync.Once() to ensure that only one request is ever handled.

We only need to have the sync.Once() for close()ing the channel to
permit a user to permit a user to free the resources manually. It is
still possible to cause a panic by closing this channel before receiving
a request.
  • Loading branch information
punmechanic committed Mar 26, 2024
1 parent 6000b9d commit 4aedf66
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 86 deletions.
120 changes: 63 additions & 57 deletions cli/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,67 +57,68 @@ func DiscoverOAuth2Config(ctx context.Context, domain, clientID string) (*oauth2
return &cfg, nil
}

type OAuth2CallbackInfo struct {
Code string
State string
Error string
ErrorDescription string
}

func ParseCallbackRequest(r *http.Request) (OAuth2CallbackInfo, error) {
info := OAuth2CallbackInfo{
Error: r.FormValue("error"),
ErrorDescription: r.FormValue("error_description"),
State: r.FormValue("state"),
Code: r.FormValue("code"),
}

return info, nil
// OAuth2CallbackState encapsulates all of the information from an oauth2 callback.
//
// To retrieve the Code from the struct, you must use the Verify(string) function.
type OAuth2CallbackState struct {
code string
state string
errorMessage string
errorDescription string
}

// OAuth2Listener will listen for a single callback request from a web server and return the code if it matched, or an error otherwise.
type OAuth2Listener struct {
once sync.Once
callbackCh chan OAuth2CallbackInfo
// FromRequest parses the given http.Request and populates the OAuth2CallbackState with those values.
func (o *OAuth2CallbackState) FromRequest(r *http.Request) {
o.errorMessage = r.FormValue("error")
o.errorDescription = r.FormValue("error_description")
o.state = r.FormValue("state")
o.code = r.FormValue("code")
}

func NewOAuth2Listener() OAuth2Listener {
return OAuth2Listener{
callbackCh: make(chan OAuth2CallbackInfo, 1),
// Verify safely compares the given state with the state from the OAuth2 callback.
//
// If they match, the code is returned, with a nil value. Otherwise, an empty string and an error is returned.
func (o OAuth2CallbackState) Verify(expectedState string) (string, error) {
if o.errorMessage != "" {
return "", OAuth2Error{Reason: o.errorMessage, Description: o.errorDescription}
}
}

func (o *OAuth2Listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// This can sometimes be called multiple times, depending on the browser.
// We will simply ignore any other requests and only serve the first.
o.once.Do(func() {
info, err := ParseCallbackRequest(r)
if err == nil {
// The only errors that might occur would be incorrectly formatted requests, which we will silently drop.
o.callbackCh <- info
}
close(o.callbackCh)
})
if strings.Compare(o.state, expectedState) != 0 {
return "", OAuth2Error{Reason: "invalid_state", Description: "state mismatch"}
}

// We still want to provide feedback to the end-user.
fmt.Fprintln(w, "You may close this window now.")
return o.code, nil
}

func (o *OAuth2Listener) WaitForAuthorizationCode(ctx context.Context, state string) (string, error) {
select {
case info := <-o.callbackCh:
if info.Error != "" {
return "", OAuth2Error{Reason: info.Error, Description: info.ErrorDescription}
}

if strings.Compare(info.State, state) != 0 {
return "", OAuth2Error{Reason: "invalid_state", Description: "state mismatch"}
}
// OAuth2CallbackHandler returns a http.Handler, channel and function triple.
//
// The http handler will accept exactly one request, which it will assume is an OAuth2 callback, parse it into an OAuth2CallbackState and then provide it to the given channel. Subsequent requests will be silently ignored.
//
// The function may be called to ensure that the channel is closed. The channel is closed when a request is received. In general, it is a good idea to ensure this function is called in a defer() block.
func OAuth2CallbackHandler() (http.Handler, <-chan OAuth2CallbackState, func()) {
ch := make(chan OAuth2CallbackState, 1)
var reqHandle, closeHandle sync.Once
closeFn := func() {
closeHandle.Do(func() {
close(ch)
})
}

return info.Code, nil
case <-ctx.Done():
return "", ctx.Err()
fn := func(w http.ResponseWriter, r *http.Request) {
// This can sometimes be called multiple times, depending on the browser.
// We will simply ignore any other requests and only serve the first.
reqHandle.Do(func() {
var state OAuth2CallbackState
state.FromRequest(r)
ch <- state
closeFn()
})

// We still want to provide feedback to the end-user.
fmt.Fprintln(w, "You may close this window now.")
}

return http.HandlerFunc(fn), ch, closeFn
}

type OAuth2Error struct {
Expand Down Expand Up @@ -219,21 +220,26 @@ func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challe
oauth2.SetAuthURLParam("code_challenge", challenge.Challenge),
)

listener := NewOAuth2Listener()
callbackHandler, ch, cancel := OAuth2CallbackHandler()
// TODO: This error probably should not be ignored if it is not http.ErrServerClosed
go http.Serve(sock, &listener)
go http.Serve(sock, callbackHandler)
defer cancel()

if err := r.OnDisplayURL(url); err != nil {
// This is unlikely to ever happen
return nil, fmt.Errorf("failed to display link: %w", err)
}

code, err := listener.WaitForAuthorizationCode(ctx, state)
if err != nil {
return nil, fmt.Errorf("failed to get authorization code: %w", err)
select {
case info := <-ch:
code, err := info.Verify(state)
if err != nil {
return nil, fmt.Errorf("failed to get authorization code: %w", err)
}
return r.Config.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", challenge.Verifier))
case <-ctx.Done():
return nil, ctx.Err()
}

return r.Config.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", challenge.Verifier))
}

func ExchangeAccessTokenForWebSSOToken(ctx context.Context, client *http.Client, oauthCfg *oauth2.Config, token *TokenSet, applicationID string) (*oauth2.Token, error) {
Expand Down
80 changes: 51 additions & 29 deletions cli/oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@ package main
import (
"context"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func sendRequestToListener(listener *OAuth2Listener, values url.Values) {
func sendOAuth2CallbackRequest(handler http.Handler, values url.Values) {
uri := url.URL{
Scheme: "http",
Host: "localhost",
Expand All @@ -22,56 +22,78 @@ func sendRequestToListener(listener *OAuth2Listener, values url.Values) {

w := httptest.NewRecorder()
req := httptest.NewRequest("GET", uri.String(), nil)
listener.ServeHTTP(w, req)
handler.ServeHTTP(w, req)
}

// This test ensures that the success flow of receiving an authorization code from a HTTP handler works correctly.
func Test_OAuth2Listener_WaitForAuthorizationCodeWorksCorrectly(t *testing.T) {
listener := NewOAuth2Listener()
func Test_OAuth2CallbackHandler_YieldsCorrectlyFormattedState(t *testing.T) {
handler, ch, cancel := OAuth2CallbackHandler()
t.Cleanup(func() {
cancel()
})

expectedState := "state goes here"
expectedCode := "code goes here"

go sendRequestToListener(&listener, url.Values{
go sendOAuth2CallbackRequest(handler, url.Values{
"code": []string{expectedCode},
"state": []string{expectedState},
})

deadline, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
code, err := listener.WaitForAuthorizationCode(deadline, expectedState)
callbackState := <-ch
code, err := callbackState.Verify(expectedState)
assert.NoError(t, err)
cancel()

assert.Equal(t, expectedCode, code)
}

func Test_OAuth2Listener_ZeroValueNeverPanics(t *testing.T) {
var listener OAuth2Listener
deadline, _ := context.WithTimeout(context.Background(), 500*time.Millisecond)
_, err := listener.WaitForAuthorizationCode(deadline, "")
// This will timeout because the OAuth2Listener will forever listen on a nil channel
assert.ErrorIs(t, context.DeadlineExceeded, err)
func Test_OAuth2CallbackState_VerifyWorksCorrectly(t *testing.T) {
t.Run("happy path", func(t *testing.T) {
expectedState := "state goes here"
expectedCode := "code goes here"
callbackState := OAuth2CallbackState{
code: expectedCode,
state: expectedState,
}
code, err := callbackState.Verify(expectedState)
assert.NoError(t, err)
assert.Equal(t, expectedCode, code)
})

t.Run("unhappy path", func(t *testing.T) {
expectedState := "state goes here"
expectedCode := "code goes here"
callbackState := OAuth2CallbackState{
code: expectedCode,
state: expectedState,
}
_, err := callbackState.Verify("mismatching state")
var oauthErr OAuth2Error
assert.ErrorAs(t, err, &oauthErr)
assert.Equal(t, "invalid_state", oauthErr.Reason)
})
}

// Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic prevents an issue where OAuth2Listener would send a request to a closed channel
func Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic(t *testing.T) {
listener := NewOAuth2Listener()
expectedState := "state goes here"
expectedCode := "code goes here"
handler, ch, cancel := OAuth2CallbackHandler()
t.Cleanup(func() {
cancel()
})

go sendRequestToListener(&listener, url.Values{
"code": []string{expectedCode},
"state": []string{expectedState},
go sendOAuth2CallbackRequest(handler, url.Values{
// We send empty values because we don't care about processing in this test
"code": []string{""},
"state": []string{""},
})

go sendRequestToListener(&listener, url.Values{
// We drain the channel of the first request so the handler completes.
// Without this step, we would get 'stuck' in the sync.Once().
<-ch

// We send this request synchronously to ensure that any panics are caught during the test.
sendOAuth2CallbackRequest(handler, url.Values{
"code": []string{"not the expected code and should be discarded"},
"state": []string{"not the expected state and should be discarded"},
})

deadlineCtx, _ := context.WithTimeout(context.Background(), 500*time.Millisecond)
code, err := listener.WaitForAuthorizationCode(deadlineCtx, expectedState)
assert.NoError(t, err)
assert.Equal(t, expectedCode, code)
}

// Test_ListenAnyPort_WorksCorrectly is going to be flaky because processes may open ports outside of our control.
Expand Down

0 comments on commit 4aedf66

Please sign in to comment.