Skip to content

Commit

Permalink
Fail if there is an error retrieving the user + extra test (#142)
Browse files Browse the repository at this point in the history
Previously this would fail, but permit the request, which isn't
normally what you'd want.
  • Loading branch information
thomseddon authored Jun 29, 2020
1 parent be2b4ba commit 870724c
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 1 deletion.
1 change: 1 addition & 0 deletions internal/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
user, err := p.GetUser(token)
if err != nil {
logger.WithField("error", err).Error("Error getting user")
http.Error(w, "Service unavailable", 503)
return
}

Expand Down
73 changes: 72 additions & 1 deletion internal/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ func TestServerAuthCallback(t *testing.T) {
res, _ = doHttpRequest(req, c)
assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised")

// Should catch invalid provider cookie
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:invalid:http://redirect")
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
res, _ = doHttpRequest(req, c)
assert.Equal(401, res.StatusCode, "auth callback with invalid provider shouldn't be authorised")

// Should redirect valid request
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect")
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
Expand All @@ -170,6 +176,58 @@ func TestServerAuthCallback(t *testing.T) {
assert.Equal("", fwd.Path, "valid request should be redirected to return url")
}

func TestServerAuthCallbackExchangeFailure(t *testing.T) {
assert := assert.New(t)
config = newDefaultConfig()

// Setup OAuth server
server, serverURL := NewFailingOAuthServer(t)
defer server.Close()
config.Providers.Google.TokenURL = &url.URL{
Scheme: serverURL.Scheme,
Host: serverURL.Host,
Path: "/token",
}
config.Providers.Google.UserURL = &url.URL{
Scheme: serverURL.Scheme,
Host: serverURL.Host,
Path: "/userinfo",
}

// Should handle failed code exchange
req := newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect")
c := MakeCSRFCookie(req, "12345678901234567890123456789012")
res, _ := doHttpRequest(req, c)
assert.Equal(503, res.StatusCode, "auth callback should handle failed code exchange")
}

func TestServerAuthCallbackUserFailure(t *testing.T) {
assert := assert.New(t)
config = newDefaultConfig()

// Setup OAuth server
server, serverURL := NewOAuthServer(t)
defer server.Close()
config.Providers.Google.TokenURL = &url.URL{
Scheme: serverURL.Scheme,
Host: serverURL.Host,
Path: "/token",
}
serverFail, serverFailURL := NewFailingOAuthServer(t)
defer serverFail.Close()
config.Providers.Google.UserURL = &url.URL{
Scheme: serverFailURL.Scheme,
Host: serverFailURL.Host,
Path: "/userinfo",
}

// Should handle failed user request
req := newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect")
c := MakeCSRFCookie(req, "12345678901234567890123456789012")
res, _ := doHttpRequest(req, c)
assert.Equal(503, res.StatusCode, "auth callback should handle failed user request")
}

func TestServerLogout(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
Expand Down Expand Up @@ -398,10 +456,16 @@ func TestServerRouteQuery(t *testing.T) {
*/

type OAuthServer struct {
t *testing.T
t *testing.T
fail bool
}

func (s *OAuthServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if s.fail {
http.Error(w, "Service unavailable", 500)
return
}

if r.URL.Path == "/token" {
fmt.Fprintf(w, `{"access_token":"123456789"}`)
} else if r.URL.Path == "/userinfo" {
Expand All @@ -423,6 +487,13 @@ func NewOAuthServer(t *testing.T) (*httptest.Server, *url.URL) {
return server, serverURL
}

func NewFailingOAuthServer(t *testing.T) (*httptest.Server, *url.URL) {
handler := &OAuthServer{fail: true}
server := httptest.NewServer(handler)
serverURL, _ := url.Parse(server.URL)
return server, serverURL
}

func doHttpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
w := httptest.NewRecorder()

Expand Down

0 comments on commit 870724c

Please sign in to comment.