Skip to content

Commit

Permalink
Fix a race condition when browsers would send two requests
Browse files Browse the repository at this point in the history
Some browsers would send two requests, with one landing reliably between
when the OAuth2Listener was closed (and thus, its channel was closed)
and when the http server would be closed.

This change solves this problem by closing the channel after the
request is received and only ever processing a single request.

Other requests will receive responses, but will be silently ignored
  • Loading branch information
punmechanic committed Mar 26, 2024
1 parent 7555b79 commit 6000b9d
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 66 deletions.
63 changes: 25 additions & 38 deletions cli/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"net/http"
"net/url"
"strings"
"sync"

"github.com/RobotsAndPencils/go-saml"
"github.com/coreos/go-oidc"
Expand Down Expand Up @@ -63,11 +64,6 @@ type OAuth2CallbackInfo struct {
ErrorDescription string
}

type OAuth2Listener struct {
Socket net.Listener
callbackCh chan OAuth2CallbackInfo
}

func ParseCallbackRequest(r *http.Request) (OAuth2CallbackInfo, error) {
info := OAuth2CallbackInfo{
Error: r.FormValue("error"),
Expand All @@ -79,44 +75,35 @@ func ParseCallbackRequest(r *http.Request) (OAuth2CallbackInfo, error) {
return info, nil
}

func NewOAuth2Listener(socket net.Listener) OAuth2Listener {
return OAuth2Listener{
Socket: socket,
callbackCh: make(chan OAuth2CallbackInfo),
}
// OAuth2Listener will listen for a single callback request from a web server and return the code if it matched, or an error otherwise.
type OAuth2Listener struct {
once sync.Once
callbackCh chan OAuth2CallbackInfo
}

func (o OAuth2Listener) Close() error {
if o.callbackCh != nil {
close(o.callbackCh)
}
if o.Socket != nil {
return o.Socket.Close()
func NewOAuth2Listener() OAuth2Listener {
return OAuth2Listener{
callbackCh: make(chan OAuth2CallbackInfo, 1),
}
return nil
}

func (o OAuth2Listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
info, err := ParseCallbackRequest(r)
if err == nil {
// The only errors that might occur would be incorrectly formatted requests, which we will silently drop.
o.callbackCh <- info
}
func (o *OAuth2Listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// This can sometimes be called multiple times, depending on the browser.
// We will simply ignore any other requests and only serve the first.
o.once.Do(func() {
info, err := ParseCallbackRequest(r)
if err == nil {
// The only errors that might occur would be incorrectly formatted requests, which we will silently drop.
o.callbackCh <- info
}
close(o.callbackCh)
})

// This is displayed to the end user in their browser.
// We still want to provide feedback to the end-user.
fmt.Fprintln(w, "You may close this window now.")
}

func (o OAuth2Listener) Listen() error {
err := http.Serve(o.Socket, o)
if errors.Is(err, http.ErrServerClosed) {
return nil
}

return err
}

func (o OAuth2Listener) WaitForAuthorizationCode(ctx context.Context, state string) (string, error) {
func (o *OAuth2Listener) WaitForAuthorizationCode(ctx context.Context, state string) (string, error) {
select {
case info := <-o.callbackCh:
if info.Error != "" {
Expand Down Expand Up @@ -218,6 +205,7 @@ func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challe
if err != nil {
return nil, err
}
defer sock.Close()

_, port, err := net.SplitHostPort(sock.Addr().String())
if err != nil {
Expand All @@ -231,10 +219,9 @@ func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challe
oauth2.SetAuthURLParam("code_challenge", challenge.Challenge),
)

listener := NewOAuth2Listener(sock)
defer listener.Close()
// This error can be ignored.
go listener.Listen()
listener := NewOAuth2Listener()
// TODO: This error probably should not be ignored if it is not http.ErrServerClosed
go http.Serve(sock, &listener)

if err := r.OnDisplayURL(url); err != nil {
// This is unlikely to ever happen
Expand Down
70 changes: 42 additions & 28 deletions cli/oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,55 +12,69 @@ import (
"github.com/stretchr/testify/require"
)

// This test ensures that the success flow of receiving an authorization code from a HTTP handler works correctly.
func Test_OAuth2Listener_WaitForAuthorizationCodeWorksCorrectly(t *testing.T) {
// TODO: Initializing private fields means we should probably refactor OAuth2Listener's constructor to not require net.Listener,
// but instead to have net.Listener be a thing in Listen();
//
// Users should be able to use ServeHTTP.
listener := OAuth2Listener{
callbackCh: make(chan OAuth2CallbackInfo),
func sendRequestToListener(listener *OAuth2Listener, values url.Values) {
uri := url.URL{
Scheme: "http",
Host: "localhost",
Path: "/oauth2/callback",
RawQuery: values.Encode(),
}

w := httptest.NewRecorder()
req := httptest.NewRequest("GET", uri.String(), nil)
listener.ServeHTTP(w, req)
}

// This test ensures that the success flow of receiving an authorization code from a HTTP handler works correctly.
func Test_OAuth2Listener_WaitForAuthorizationCodeWorksCorrectly(t *testing.T) {
listener := NewOAuth2Listener()
expectedState := "state goes here"
expectedCode := "code goes here"

go func() {
w := httptest.NewRecorder()
values := url.Values{
"code": []string{expectedCode},
"state": []string{expectedState},
}

uri := url.URL{
Scheme: "http",
Host: "localhost",
Path: "/oauth2/callback",
RawQuery: values.Encode(),
}

req := httptest.NewRequest("GET", uri.String(), nil)
listener.ServeHTTP(w, req)
}()
go sendRequestToListener(&listener, url.Values{
"code": []string{expectedCode},
"state": []string{expectedState},
})

deadline, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
code, err := listener.WaitForAuthorizationCode(deadline, expectedState)
assert.NoError(t, err)
cancel()

assert.Equal(t, expectedCode, code)
assert.NoError(t, listener.Close())
}

func Test_OAuth2Listener_ZeroValueNeverPanics(t *testing.T) {
var listener OAuth2Listener
deadline, _ := context.WithTimeout(context.Background(), 500*time.Millisecond)
_, err := listener.WaitForAuthorizationCode(deadline, "")
// This will timeout because the OAuth2Listener will forever listen on a nil channel
assert.ErrorIs(t, context.DeadlineExceeded, err)
assert.NoError(t, listener.Close())
}

// This test is going to be flaky because processes may open ports outside of our control.
// Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic prevents an issue where OAuth2Listener would send a request to a closed channel
func Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic(t *testing.T) {
listener := NewOAuth2Listener()
expectedState := "state goes here"
expectedCode := "code goes here"

go sendRequestToListener(&listener, url.Values{
"code": []string{expectedCode},
"state": []string{expectedState},
})

go sendRequestToListener(&listener, url.Values{
"code": []string{"not the expected code and should be discarded"},
"state": []string{"not the expected state and should be discarded"},
})

deadlineCtx, _ := context.WithTimeout(context.Background(), 500*time.Millisecond)
code, err := listener.WaitForAuthorizationCode(deadlineCtx, expectedState)
assert.NoError(t, err)
assert.Equal(t, expectedCode, code)
}

// Test_ListenAnyPort_WorksCorrectly is going to be flaky because processes may open ports outside of our control.
func Test_ListenAnyPort_WorksCorrectly(t *testing.T) {
ports := []string{"58080", "58081", "58082", "58083"}
socket, err := net.Listen("tcp4", net.JoinHostPort("127.0.0.1", ports[0]))
Expand Down

0 comments on commit 6000b9d

Please sign in to comment.