Skip to content

Commit

Permalink
Remove CSRF checking middleware (#50358)
Browse files Browse the repository at this point in the history
The remaining two endpoints that were checking the CSRF token were
both unauthenticated requests. We don't need a CSRF token here because
we require Content-Type: application/json for these requests.
  • Loading branch information
zmb3 authored Dec 18, 2024
1 parent 1345dbf commit c222030
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 116 deletions.
19 changes: 1 addition & 18 deletions lib/httplib/httplib.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ import (
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/observability/tracing"
tracehttp "github.com/gravitational/teleport/api/observability/tracing/http"
"github.com/gravitational/teleport/lib/httplib/csrf"
"github.com/gravitational/teleport/lib/utils"
)

Expand Down Expand Up @@ -155,23 +154,6 @@ func MakeStdHandlerWithErrorWriter(fn StdHandlerFunc, errWriter ErrorWriter) htt
}
}

// WithCSRFProtection ensures that request to unauthenticated API is checked against CSRF attacks
func WithCSRFProtection(fn HandlerFunc) httprouter.Handle {
handlerFn := MakeHandler(fn)
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
if r.Method != http.MethodGet && r.Method != http.MethodHead {
errHeader := csrf.VerifyHTTPHeader(r)
errForm := csrf.VerifyFormField(r)
if errForm != nil && errHeader != nil {
slog.WarnContext(r.Context(), "unable to validate CSRF token", "header_error", errHeader, "form_error", errForm)
trace.WriteError(w, trace.AccessDenied("access denied"))
return
}
}
handlerFn(w, r, p)
}
}

// ReadJSON reads HTTP json request and unmarshals it
// into passed any obj. A reasonable maximum size is enforced
// to mitigate resource exhaustion attacks.
Expand All @@ -188,6 +170,7 @@ func ReadResourceJSON(r *http.Request, val any) error {

func readJSON(r *http.Request, val any, maxSize int64) error {
// Check content type to mitigate CSRF attack.
// (Form POST requests don't support application/json payloads.)
contentType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
slog.WarnContext(r.Context(), "Error parsing media type for reading JSON", "error", err)
Expand Down
29 changes: 2 additions & 27 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ func (h *Handler) bindDefaultEndpoints() {
h.POST("/webapi/sessions/app", h.WithAuth(h.createAppSession))

// Web sessions
h.POST("/webapi/sessions/web", httplib.WithCSRFProtection(h.WithLimiterHandlerFunc(h.createWebSession)))
h.POST("/webapi/sessions/web", h.WithLimiter(h.createWebSession))
h.DELETE("/webapi/sessions/web", h.WithAuth(h.deleteWebSession))
h.POST("/webapi/sessions/web/renew", h.WithAuth(h.renewWebSession))
h.POST("/webapi/users", h.WithAuth(h.createUserHandle))
Expand All @@ -793,7 +793,7 @@ func (h *Handler) bindDefaultEndpoints() {
// h.GET("/webapi/users/password/token/:token", h.WithLimiter(h.getResetPasswordTokenHandle))
h.GET("/webapi/users/*wildcard", h.handleGetUserOrResetToken)

h.PUT("/webapi/users/password/token", httplib.WithCSRFProtection(h.changeUserAuthentication))
h.PUT("/webapi/users/password/token", h.WithLimiter(h.changeUserAuthentication))
h.PUT("/webapi/users/password", h.WithAuth(h.changePassword))
h.POST("/webapi/users/password/token", h.WithAuth(h.createResetPasswordToken))
h.POST("/webapi/users/privilege/token", h.WithAuth(h.createPrivilegeTokenHandle))
Expand Down Expand Up @@ -1994,7 +1994,6 @@ func (h *Handler) githubLoginWeb(w http.ResponseWriter, r *http.Request, p httpr
}

response, err := h.cfg.ProxyClient.CreateGithubAuthRequest(r.Context(), types.GithubAuthRequest{
CSRFToken: req.CSRFToken,
ConnectorID: req.ConnectorID,
CreateWebSession: true,
ClientRedirectURL: req.ClientRedirectURL,
Expand All @@ -2004,7 +2003,6 @@ func (h *Handler) githubLoginWeb(w http.ResponseWriter, r *http.Request, p httpr
if err != nil {
logger.WithError(err).Error("Error creating auth request.")
return client.LoginFailedRedirectURL

}

return response.RedirectURL
Expand Down Expand Up @@ -4705,21 +4703,6 @@ func (h *Handler) WithSession(fn ContextHandler) httprouter.Handle {
})
}

// WithAuthCookieAndCSRF ensures that a request is authenticated
// for plain old non-AJAX requests (does not check the Bearer header).
// It enforces CSRF checks (except for "safe" methods).
func (h *Handler) WithAuthCookieAndCSRF(fn ContextHandler) httprouter.Handle {
f := func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
sctx, err := h.AuthenticateRequest(w, r, false)
if err != nil {
return nil, trace.Wrap(err)
}
return fn(w, r, p, sctx)
}

return httplib.WithCSRFProtection(f)
}

// WithUnauthenticatedLimiter adds a conditional IP-based rate limiting that will limit only unauthenticated requests.
// This is a good default to use as both Cluster and User auth are checked here, but `WithLimiter` can be used if
// you're certain that no authenticated requests will be made.
Expand Down Expand Up @@ -5054,8 +5037,6 @@ type SSORequestParams struct {
// ConnectorID identifies the SSO connector to use to log in, from
// the connector_id query parameter.
ConnectorID string
// CSRFToken is the token in the CSRF cookie header.
CSRFToken string
}

// ParseSSORequestParams extracts the SSO request parameters from an http.Request,
Expand Down Expand Up @@ -5088,15 +5069,9 @@ func ParseSSORequestParams(r *http.Request) (*SSORequestParams, error) {
return nil, trace.BadParameter("missing connector_id query parameter")
}

csrfToken, err := csrf.ExtractTokenFromCookie(r)
if err != nil {
return nil, trace.Wrap(err)
}

return &SSORequestParams{
ClientRedirectURL: clientRedirectURL,
ConnectorID: connectorID,
CSRFToken: csrfToken,
}, nil
}

Expand Down
64 changes: 11 additions & 53 deletions lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ import (
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/events/eventstest"
"github.com/gravitational/teleport/lib/httplib"
"github.com/gravitational/teleport/lib/httplib/csrf"
"github.com/gravitational/teleport/lib/inventory"
kubeproxy "github.com/gravitational/teleport/lib/kube/proxy"
"github.com/gravitational/teleport/lib/limiter"
Expand Down Expand Up @@ -947,50 +946,32 @@ func TestWebSessionsCRUD(t *testing.T) {
func TestCSRF(t *testing.T) {
t.Parallel()
s := newWebSuite(t)
type input struct {
reqToken string
cookieToken string
}

// create a valid user
user := "csrfuser"
pass := "abcdef123456"
otpSecret := newOTPSharedSecret()
s.createUser(t, user, user, pass, otpSecret)

encodedToken1 := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992"
encodedToken2 := "bf355921bbf3ef3672a03e410d4194077dfa5fe863c652521763b3e7f81e7b11"
invalid := []input{
{reqToken: encodedToken2, cookieToken: encodedToken1},
{reqToken: "", cookieToken: encodedToken1},
{reqToken: "", cookieToken: ""},
{reqToken: encodedToken1, cookieToken: ""},
}

clt := s.client(t)
ctx := context.Background()

// valid
validReq := loginWebOTPParams{
webClient: clt,
clock: s.clock,
user: user,
password: pass,
otpSecret: otpSecret,
cookieCSRF: &encodedToken1,
headerCSRF: &encodedToken1,
webClient: clt,
clock: s.clock,
user: user,
password: pass,
otpSecret: otpSecret,
}
loginWebOTP(t, ctx, validReq)

// invalid
for i := range invalid {
req := validReq
req.cookieCSRF = &invalid[i].cookieToken
req.headerCSRF = &invalid[i].reqToken
httpResp, _, err := rawLoginWebOTP(ctx, req)
require.NoError(t, err, "Login via /webapi/sessions/new failed unexpectedly")
assert.Equal(t, http.StatusForbidden, httpResp.StatusCode, "HTTP status code mismatch")
}
// invalid - wrong content-type header
invalidReq := validReq
invalidReq.overrideContentType = "multipart/form-data"
httpResp, _, err := rawLoginWebOTP(ctx, invalidReq)
require.NoError(t, err, "Login via /webapi/sessions/new failed unexpectedly")
require.Equal(t, http.StatusBadRequest, httpResp.StatusCode, "HTTP status code mismatch")
}

func TestPasswordChange(t *testing.T) {
Expand Down Expand Up @@ -5953,13 +5934,9 @@ func TestChangeUserAuthentication_WithPrivacyPolicyEnabledError(t *testing.T) {
httpReqData, err := json.Marshal(req)
require.NoError(t, err)

// CSRF protected endpoint.
csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992"
httpReq, err := http.NewRequest("PUT", clt.Endpoint("webapi", "users", "password", "token"), bytes.NewBuffer(httpReqData))
require.NoError(t, err)
addCSRFCookieToReq(httpReq, csrfToken)
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set(csrf.HeaderName, csrfToken)
httpRes, err := httplib.ConvertResponse(clt.RoundTrip(func() (*http.Response, error) {
return clt.HTTPClient().Do(httpReq)
}))
Expand Down Expand Up @@ -6104,10 +6081,6 @@ func TestChangeUserAuthentication_settingDefaultClusterAuthPreference(t *testing
req, err := http.NewRequest("PUT", clt.Endpoint("webapi", "users", "password", "token"), bytes.NewBuffer(body))
require.NoError(t, err)

csrfToken, err := csrf.GenerateToken()
require.NoError(t, err)
addCSRFCookieToReq(req, csrfToken)
req.Header.Set(csrf.HeaderName, csrfToken)
req.Header.Set("Content-Type", "application/json")

re, err := clt.Client.RoundTrip(func() (*http.Response, error) {
Expand All @@ -6129,8 +6102,6 @@ func TestChangeUserAuthentication_settingDefaultClusterAuthPreference(t *testing
func TestParseSSORequestParams(t *testing.T) {
t.Parallel()

token := "someMeaninglessTokenString"

tests := []struct {
name, url string
wantErr bool
Expand All @@ -6142,7 +6113,6 @@ func TestParseSSORequestParams(t *testing.T) {
expected: &SSORequestParams{
ClientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/nodes?search=tunnel&sort=hostname:asc",
ConnectorID: "oidc",
CSRFToken: token,
},
},
{
Expand All @@ -6151,7 +6121,6 @@ func TestParseSSORequestParams(t *testing.T) {
expected: &SSORequestParams{
ClientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/nodes?search=tunnel&sort=hostname:asc",
ConnectorID: "github",
CSRFToken: token,
},
},
{
Expand All @@ -6160,7 +6129,6 @@ func TestParseSSORequestParams(t *testing.T) {
expected: &SSORequestParams{
ClientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/apps?query=search(%22watermelon%22%2C%20%22this%22)%20%26%26%20labels%5B%22unique-id%22%5D%20%3D%3D%20%22hi%22&sort=name:asc",
ConnectorID: "saml",
CSRFToken: token,
},
},
{
Expand All @@ -6179,7 +6147,6 @@ func TestParseSSORequestParams(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req, err := http.NewRequest("", tc.url, nil)
require.NoError(t, err)
addCSRFCookieToReq(req, token)

params, err := ParseSSORequestParams(req)

Expand Down Expand Up @@ -7932,15 +7899,6 @@ func (s *WebSuite) url() *url.URL {
return u
}

func addCSRFCookieToReq(req *http.Request, token string) {
cookie := &http.Cookie{
Name: csrf.CookieName,
Value: token,
}

req.AddCookie(cookie)
}

func removeSpace(in string) string {
for _, c := range []string{"\n", "\r", "\t"} {
in = strings.Replace(in, c, " ", -1)
Expand Down
22 changes: 4 additions & 18 deletions lib/web/login_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package web

import (
"bytes"
"cmp"
"context"
"encoding/base32"
"encoding/json"
Expand All @@ -34,7 +35,6 @@ import (

"github.com/gravitational/teleport/lib/auth/mocku2f"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/httplib/csrf"
)

// newOTPSharedSecret returns an OTP shared secret, encoded as a base32 string.
Expand All @@ -54,9 +54,8 @@ type loginWebOTPParams struct {
// If empty then no OTP is sent in the request.
otpSecret string

userAgent string // Optional.

cookieCSRF, headerCSRF *string // Explicit CSRF tokens. Optional.
userAgent string // Optional.
overrideContentType string // Optional.
}

// DrainedHTTPResponse mimics an http.Response, but without a body.
Expand Down Expand Up @@ -124,24 +123,11 @@ func rawLoginWebOTP(ctx context.Context, params loginWebOTPParams) (resp *Draine
}

// Set assorted headers.
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Type", cmp.Or(params.overrideContentType, "application/json"))
if params.userAgent != "" {
req.Header.Set("User-Agent", params.userAgent)
}

// Set CSRF cookie and header.
const defaultCSRFToken = "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992"
cookieCSRF := defaultCSRFToken
if params.cookieCSRF != nil {
cookieCSRF = *params.cookieCSRF
}
addCSRFCookieToReq(req, cookieCSRF)
headerCSRF := defaultCSRFToken
if params.headerCSRF != nil {
headerCSRF = *params.headerCSRF
}
req.Header.Set(csrf.HeaderName, headerCSRF)

httpResp, err := webClient.HTTPClient().Do(req)
if err != nil {
return nil, nil, trace.Wrap(err, "do HTTP request")
Expand Down

0 comments on commit c222030

Please sign in to comment.