From 870724c994ff492e2952417e665e1325762f8edb Mon Sep 17 00:00:00 2001 From: Thom Seddon Date: Mon, 29 Jun 2020 21:02:45 +0100 Subject: [PATCH] Fail if there is an error retrieving the user + extra test (#142) Previously this would fail, but permit the request, which isn't normally what you'd want. --- internal/server.go | 1 + internal/server_test.go | 73 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/internal/server.go b/internal/server.go index 3fd76500..8ac03131 100644 --- a/internal/server.go +++ b/internal/server.go @@ -167,6 +167,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { user, err := p.GetUser(token) if err != nil { logger.WithField("error", err).Error("Error getting user") + http.Error(w, "Service unavailable", 503) return } diff --git a/internal/server_test.go b/internal/server_test.go index be1a7661..2e543400 100644 --- a/internal/server_test.go +++ b/internal/server_test.go @@ -158,6 +158,12 @@ func TestServerAuthCallback(t *testing.T) { res, _ = doHttpRequest(req, c) assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised") + // Should catch invalid provider cookie + req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:invalid:http://redirect") + c = MakeCSRFCookie(req, "12345678901234567890123456789012") + res, _ = doHttpRequest(req, c) + assert.Equal(401, res.StatusCode, "auth callback with invalid provider shouldn't be authorised") + // Should redirect valid request req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect") c = MakeCSRFCookie(req, "12345678901234567890123456789012") @@ -170,6 +176,58 @@ func TestServerAuthCallback(t *testing.T) { assert.Equal("", fwd.Path, "valid request should be redirected to return url") } +func TestServerAuthCallbackExchangeFailure(t *testing.T) { + assert := assert.New(t) + config = newDefaultConfig() + + // Setup OAuth server + server, serverURL := NewFailingOAuthServer(t) + defer server.Close() + config.Providers.Google.TokenURL = &url.URL{ + Scheme: serverURL.Scheme, + Host: serverURL.Host, + Path: "/token", + } + config.Providers.Google.UserURL = &url.URL{ + Scheme: serverURL.Scheme, + Host: serverURL.Host, + Path: "/userinfo", + } + + // Should handle failed code exchange + req := newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect") + c := MakeCSRFCookie(req, "12345678901234567890123456789012") + res, _ := doHttpRequest(req, c) + assert.Equal(503, res.StatusCode, "auth callback should handle failed code exchange") +} + +func TestServerAuthCallbackUserFailure(t *testing.T) { + assert := assert.New(t) + config = newDefaultConfig() + + // Setup OAuth server + server, serverURL := NewOAuthServer(t) + defer server.Close() + config.Providers.Google.TokenURL = &url.URL{ + Scheme: serverURL.Scheme, + Host: serverURL.Host, + Path: "/token", + } + serverFail, serverFailURL := NewFailingOAuthServer(t) + defer serverFail.Close() + config.Providers.Google.UserURL = &url.URL{ + Scheme: serverFailURL.Scheme, + Host: serverFailURL.Host, + Path: "/userinfo", + } + + // Should handle failed user request + req := newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect") + c := MakeCSRFCookie(req, "12345678901234567890123456789012") + res, _ := doHttpRequest(req, c) + assert.Equal(503, res.StatusCode, "auth callback should handle failed user request") +} + func TestServerLogout(t *testing.T) { require := require.New(t) assert := assert.New(t) @@ -398,10 +456,16 @@ func TestServerRouteQuery(t *testing.T) { */ type OAuthServer struct { - t *testing.T + t *testing.T + fail bool } func (s *OAuthServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if s.fail { + http.Error(w, "Service unavailable", 500) + return + } + if r.URL.Path == "/token" { fmt.Fprintf(w, `{"access_token":"123456789"}`) } else if r.URL.Path == "/userinfo" { @@ -423,6 +487,13 @@ func NewOAuthServer(t *testing.T) (*httptest.Server, *url.URL) { return server, serverURL } +func NewFailingOAuthServer(t *testing.T) (*httptest.Server, *url.URL) { + handler := &OAuthServer{fail: true} + server := httptest.NewServer(handler) + serverURL, _ := url.Parse(server.URL) + return server, serverURL +} + func doHttpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) { w := httptest.NewRecorder()