Skip to content

Commit

Permalink
Content type and duplicate headers
Browse files Browse the repository at this point in the history
Signed-off-by: Trevor Bramwell <[email protected]>
  • Loading branch information
bramwelt committed May 2, 2024
1 parent 62ff146 commit 4a9598b
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 13 deletions.
48 changes: 41 additions & 7 deletions pkg/middlewares/awslambda/aws_lambda.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (a *awsLambda) GetTracingInformation() (string, ext.SpanKindEnum) {
func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := log.FromContext(middlewares.GetLoggerCtx(req.Context(), a.name, typeName))

base64Encoded, reqBody, err := bodyToBase64(req)
base64Encoded, contentType, body, err := bodyToBase64(req)
if err != nil {
msg := fmt.Sprintf("Error encoding Lambda request body: %v", err)
logger.Error(msg)
Expand All @@ -149,6 +149,26 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}

// If Content-Type is set, isn't set, assume it's JSON
rCt := req.Header.Get("Content-Type")
switch rCt {
case "":
logger.Debug("Content-Type not set")
if !strings.HasPrefix(contentType, "text") {
logger.Debugf("Content-Type not like text, setting to :%s", contentType)
req.Header.Set("Content-Type", contentType)
} else {
req.Header.Set("Content-Type", "application/json")
}
case "application/x-www-form-urlencoded":
if isJSON(rCt) {
req.Header.Set("Content-Type", "application/json")
}
default:
req.Header.Set("Content-Type", "application/json")
}
logger.Debugf("Content-Type set to: %s, originally %s", req.Header.Get("Content-Type"), rCt)

// Ensure tracing headers are included in the request before copying
// them to the lambda request
tracing.InjectRequestHeaders(req)
Expand Down Expand Up @@ -205,6 +225,12 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}

for key, values := range resp.MultiValueHeaders {
// NOTE This maybe specific to Content-Type, but it's listed in
// headers and multivalue headers so it ends up getting added twice.
// Is a multivalue header with only one item really multivalue?
if len(values) < 2 {
continue
}
for _, value := range values {
rw.Header().Add(key, value)
}
Expand Down Expand Up @@ -234,7 +260,8 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}

// bodyToBase64 ensures the request body is base64 encoded.
func bodyToBase64(req *http.Request) (bool, string, error) {
func bodyToBase64(req *http.Request) (bool, string, string, error) {
contentType := ""
base64Encoded := false
body := ""
// base64 encode non-text request body
Expand All @@ -246,15 +273,15 @@ func bodyToBase64(req *http.Request) (bool, string, error) {
// Read the request body and reset it to be read again if needed
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return base64Encoded, body, err
return base64Encoded, contentType, body, err
}
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))

body = string(bodyBytes)

// Any non 'text/*' MIME types should be base64 encoded.
// DetectContentType does not check for 'application/json'
contentType := http.DetectContentType(bodyBytes)
contentType = http.DetectContentType(bodyBytes)
if !strings.HasPrefix(contentType, "text") {
base64Encoded = true
}
Expand All @@ -267,17 +294,17 @@ func bodyToBase64(req *http.Request) (bool, string, error) {

_, err := io.Copy(encoder, bytes.NewReader(bodyBytes))
if err != nil {
return base64Encoded, body, err
return base64Encoded, contentType, body, err
}
if err = encoder.Close(); err != nil {
return base64Encoded, body, err
return base64Encoded, contentType, body, err
}
// Set body to b64 encoded version
body = b64buf.String()
}
}

return base64Encoded, body, nil
return base64Encoded, contentType, body, nil
}

func (a *awsLambda) invokeFunction(ctx context.Context, request events.APIGatewayProxyRequest) (*events.APIGatewayProxyResponse, error) {
Expand Down Expand Up @@ -429,3 +456,10 @@ func valuesToMultiMap(i url.Values) map[string][]string {

return values
}

// Check if a string looks like JSON
func isJSON(s string) bool {
var js interface{}
return json.Unmarshal([]byte(s), &js) == nil

}
32 changes: 26 additions & 6 deletions pkg/middlewares/awslambda/aws_lambda_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func Test_AWSLambdaMiddleware_InvokeBasic(t *testing.T) {
assert.Equal(t, "/test/example/path", lReq.Path)
assert.Equal(t, map[string]string{"a": "1", "b": "2"}, lReq.QueryStringParameters)
assert.Equal(t, map[string][]string{"c": {"3", "4"}, "d[]": {"5", "6"}}, lReq.MultiValueQueryStringParameters)
assert.Equal(t, map[string]string{"Content-Type": "text/plain"}, lReq.Headers)
assert.Equal(t, map[string]string{"Content-Type": "application/json"}, lReq.Headers)
assert.Equal(t, map[string][]string{"X-Test": {"foo", "foobar"}}, lReq.MultiValueHeaders)
assert.Equal(t, "This is the body", lReq.Body)

Expand Down Expand Up @@ -144,7 +144,7 @@ func Test_AWSLambdaMiddleware_InvokeBasic(t *testing.T) {
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "text/plain")
req.Header.Set("Content-Type", "application/json")
req.Header.Add("X-Test", "foo")
req.Header.Add("X-Test", "foobar")

Expand Down Expand Up @@ -178,10 +178,11 @@ func Test_AWSLambdaMiddleware_GetTracingInformation(t *testing.T) {
func Test_AWSLambdaMiddleware_bodyToBase64_empty(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/", nil)
require.NoError(t, err)
isEncoded, body, err := bodyToBase64(req)
isEncoded, contentType, body, err := bodyToBase64(req)

assert.False(t, isEncoded)
assert.Equal(t, "", body)
assert.Equal(t, "", contentType)
require.NoError(t, err)
}

Expand All @@ -191,10 +192,27 @@ func Test_AWSLambdaMiddleware_bodyToBase64_notEncodedJSON(t *testing.T) {

req, err := http.NewRequest(http.MethodPost, "/", strings.NewReader(reqBody))
require.NoError(t, err)
isEncoded, body, err := bodyToBase64(req)
isEncoded, contentType, body, err := bodyToBase64(req)

assert.False(t, isEncoded)
assert.Equal(t, reqBody, body)
assert.Equal(t, "text/plain; charset=utf-8", contentType)
require.NoError(t, err)
}

func Test_AWSLambdaMiddleware_bodyToBase64_EncodedJSON(t *testing.T) {
bodyBytes, err := json.Marshal(`{"test": "encoded"}`)
if err != nil {
t.Fatal(err)
}

req, err := http.NewRequest(http.MethodPost, "/", strings.NewReader(string(bodyBytes)))
require.NoError(t, err)
isEncoded, contentType, body, err := bodyToBase64(req)

assert.False(t, isEncoded)
assert.Equal(t, string(bodyBytes), body)
assert.Equal(t, "text/plain; charset=utf-8", contentType)
require.NoError(t, err)
}

Expand All @@ -206,10 +224,11 @@ func Test_AWSLambdaMiddleware_bodyToBase64_withcontent(t *testing.T) {

req, err := http.NewRequest(http.MethodPost, "/", strings.NewReader(reqBody))
require.NoError(t, err)
isEncoded, body, err := bodyToBase64(req)
isEncoded, contentType, body, err := bodyToBase64(req)

assert.True(t, isEncoded)
assert.Equal(t, expected, body)
assert.Equal(t, "application/zip", contentType)
require.NoError(t, err)

// image/jpeg
Expand All @@ -218,9 +237,10 @@ func Test_AWSLambdaMiddleware_bodyToBase64_withcontent(t *testing.T) {

req2, err2 := http.NewRequest(http.MethodPost, "/", strings.NewReader(reqBody2))
require.NoError(t, err2)
isEncoded2, body2, err2 := bodyToBase64(req2)
isEncoded2, contentType2, body2, err2 := bodyToBase64(req2)

assert.True(t, isEncoded2)
assert.Equal(t, expected2, body2)
assert.Equal(t, "image/jpeg", contentType2)
require.NoError(t, err2)
}

0 comments on commit 4a9598b

Please sign in to comment.