Skip to content

Commit

Permalink
[Security] SECVULN-8621: Fix XSS Vulnerability where content-type hea…
Browse files Browse the repository at this point in the history
…der wasn't explicitly set in API requests (#21930)

* Fix XSS Vulnerability where content-type header wasn't explicitly set in API requests

* fix failing unit test
  • Loading branch information
NiniOak authored Nov 27, 2024
1 parent 83b6d99 commit 4b7f7a8
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 12 deletions.
3 changes: 3 additions & 0 deletions .changelog/21930.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:security
api: Enforces strict content-type header validation to protect against XSS vulnerability.
```
23 changes: 16 additions & 7 deletions agent/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package agent
import (
"encoding/json"
"fmt"
"github.com/hashicorp/go-hclog"
"io"
"net"
"net/http"
Expand All @@ -20,6 +19,8 @@ import (
"sync/atomic"
"time"

"github.com/hashicorp/go-hclog"

"github.com/NYTimes/gziphandler"
"github.com/armon/go-metrics"
"github.com/armon/go-metrics/prometheus"
Expand Down Expand Up @@ -348,16 +349,24 @@ func withRemoteAddrHandler(next http.Handler) http.Handler {
})
}

// Injects content type explicitly if not already set into response to prevent XSS
// ensureContentTypeHeader injects content-type explicitly if not already set into response to prevent XSS
func ensureContentTypeHeader(next http.Handler, logger hclog.Logger) http.Handler {

return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
next.ServeHTTP(resp, req)

val := resp.Header().Get(contentTypeHeader)
if val == "" {
resp.Header().Set(contentTypeHeader, plainContentType)
logger.Debug("warning: content-type header not explicitly set.", "request-path", req.URL)
contentType := api.GetContentType(req)

if req != nil {
logger.Debug("warning: request content-type is not supported", "request-path", req.URL)
req.Header.Set(contentTypeHeader, contentType)
}

if resp != nil {
respContentType := resp.Header().Get(contentTypeHeader)
if respContentType == "" || respContentType != contentType {
logger.Debug("warning: response content-type header not explicitly set.", "request-path", req.URL)
resp.Header().Set(contentTypeHeader, contentType)
}
}
})
}
Expand Down
86 changes: 83 additions & 3 deletions agent/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,6 @@ func TestHTTPAPI_DefaultACLPolicy(t *testing.T) {
})
}
}

func TestHTTPAPIResponseHeaders(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
Expand Down Expand Up @@ -646,6 +645,87 @@ func TestHTTPAPIResponseHeaders(t *testing.T) {
requireHasHeadersSet(t, a, "/", "text/plain; charset=utf-8")
}

func TestHTTPAPIValidateContentTypeHeaders(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}

t.Parallel()
type testcase struct {
name string
endpoint string
method string
requestBody io.Reader
expectedContentType string
}

cases := []testcase{
{
name: "snapshot endpoint expect non-default content type",
method: http.MethodPut,
endpoint: "/v1/snapshot",
requestBody: bytes.NewBuffer([]byte("test")),
expectedContentType: "application/octet-stream",
},
{
name: "kv endpoint expect non-default content type",
method: http.MethodPut,
endpoint: "/v1/kv",
requestBody: bytes.NewBuffer([]byte("test")),
expectedContentType: "application/octet-stream",
},
{
name: "event/fire endpoint expect default content type",
method: http.MethodPut,
endpoint: "/v1/event/fire",
requestBody: bytes.NewBuffer([]byte("test")),
expectedContentType: "application/octet-stream",
},
{
name: "peering/token endpoint expect default content type",
method: http.MethodPost,
endpoint: "/v1/peering/token",
requestBody: bytes.NewBuffer([]byte("test")),
expectedContentType: "application/json",
},
}

for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

a := NewTestAgent(t, "")
defer a.Shutdown()

requireContentTypeHeadersSet(t, a, tc.method, tc.endpoint, tc.requestBody, tc.expectedContentType)
})
}
}

func requireContentTypeHeadersSet(t *testing.T, a *TestAgent, method, path string, body io.Reader, contentType string) {
t.Helper()

resp := httptest.NewRecorder()
req, _ := http.NewRequest(method, path, body)
a.enableDebug.Store(true)

a.srv.handler().ServeHTTP(resp, req)

reqHdrs := req.Header
respHdrs := resp.Header()

// require request content-type
require.NotEmpty(t, reqHdrs.Get("Content-Type"))
require.Equal(t, contentType, reqHdrs.Get("Content-Type"),
"Request Header Content-Type value incorrect")

// require response content-type
require.NotEmpty(t, respHdrs.Get("Content-Type"))
require.Equal(t, contentType, respHdrs.Get("Content-Type"),
"Response Header Content-Type value incorrect")
}

func requireHasHeadersSet(t *testing.T, a *TestAgent, path string, contentType string) {
t.Helper()

Expand All @@ -663,7 +743,7 @@ func requireHasHeadersSet(t *testing.T, a *TestAgent, path string, contentType s
"X-XSS-Protection header value incorrect")

require.Equal(t, contentType, hdrs.Get("Content-Type"),
"")
"Response Content-Type header value incorrect")
}

func TestUIResponseHeaders(t *testing.T) {
Expand Down Expand Up @@ -704,7 +784,7 @@ func TestErrorContentTypeHeaderSet(t *testing.T) {
`)
defer a.Shutdown()

requireHasHeadersSet(t, a, "/fake-path-doesn't-exist", "text/plain; charset=utf-8")
requireHasHeadersSet(t, a, "/fake-path-doesn't-exist", "application/json")
}

func TestAcceptEncodingGzip(t *testing.T) {
Expand Down
15 changes: 15 additions & 0 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1087,8 +1087,23 @@ func (c *Client) doRequest(r *request) (time.Duration, *http.Response, error) {
if err != nil {
return 0, nil, err
}

contentType := GetContentType(req)

if req != nil {
req.Header.Set(contentTypeHeader, contentType)
}

start := time.Now()
resp, err := c.config.HttpClient.Do(req)

if resp != nil {
respContentType := resp.Header.Get(contentTypeHeader)
if respContentType == "" || respContentType != contentType {
resp.Header.Set(contentTypeHeader, contentType)
}
}

diff := time.Since(start)
return diff, resp, err
}
Expand Down
4 changes: 2 additions & 2 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -935,11 +935,11 @@ func TestAPI_Headers(t *testing.T) {

_, _, err = kv.Get("test-headers", nil)
require.NoError(t, err)
require.Equal(t, "", request.Header.Get("Content-Type"))
require.Equal(t, "application/json", request.Header.Get("Content-Type"))

_, err = kv.Delete("test-headers", nil)
require.NoError(t, err)
require.Equal(t, "", request.Header.Get("Content-Type"))
require.Equal(t, "application/json", request.Header.Get("Content-Type"))

err = c.Snapshot().Restore(nil, strings.NewReader("foo"))
require.Error(t, err)
Expand Down
81 changes: 81 additions & 0 deletions api/content_type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package api

import (
"net/http"
"strings"
)

const (
contentTypeHeader = "Content-Type"
plainContentType = "text/plain; charset=utf-8"
octetStream = "application/octet-stream"
jsonContentType = "application/json" // Default content type
)

// ContentTypeRule defines a rule for determining the content type of an HTTP request.
// This rule is based on the combination of the HTTP path, method, and the desired content type.
type ContentTypeRule struct {
path string
httpMethod string
contentType string
}

var ContentTypeRules = []ContentTypeRule{
{
path: "/v1/snapshot",
httpMethod: http.MethodPut,
contentType: octetStream,
},
{
path: "/v1/kv",
httpMethod: http.MethodPut,
contentType: octetStream,
},
{
path: "/v1/event/fire",
httpMethod: http.MethodPut,
contentType: octetStream,
},
}

// GetContentType returns the content type for a request
// This function isused as routing logic or middleware to determine and enforce
// the appropriate content type for HTTP requests.
func GetContentType(req *http.Request) string {
reqContentType := req.Header.Get(contentTypeHeader)

if isIndexPage(req) {
return plainContentType
}

// For GET, DELETE, or internal API paths, ensure a valid Content-Type is returned.
if req.Method == http.MethodGet || req.Method == http.MethodDelete || strings.HasPrefix(req.URL.Path, "/v1/internal") {
if reqContentType == "" {
// Default to JSON Content-Type if no Content-Type is provided.
return jsonContentType
}
// Return the provided Content-Type if it exists.
return reqContentType
}

for _, rule := range ContentTypeRules {
if matchesRule(req, rule) {
return rule.contentType
}
}
return jsonContentType
}

// matchesRule checks if a request matches a content type rule
func matchesRule(req *http.Request, rule ContentTypeRule) bool {
return strings.HasPrefix(req.URL.Path, rule.path) &&
(rule.httpMethod == "" || req.Method == rule.httpMethod)
}

// isIndexPage checks if the request is for the index page
func isIndexPage(req *http.Request) bool {
return req.URL.Path == "/" || req.URL.Path == "/ui"
}

0 comments on commit 4b7f7a8

Please sign in to comment.