Skip to content

Commit

Permalink
Merge pull request #7 from linuxfoundation/content-type
Browse files Browse the repository at this point in the history
Fix Content-Type header, Remove MultiValueHeaders
  • Loading branch information
bramwelt authored May 2, 2024
2 parents 62ff146 + 2d7c4bd commit 5ac7182
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 28 deletions.
64 changes: 43 additions & 21 deletions pkg/middlewares/awslambda/aws_lambda.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (a *awsLambda) GetTracingInformation() (string, ext.SpanKindEnum) {
func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := log.FromContext(middlewares.GetLoggerCtx(req.Context(), a.name, typeName))

base64Encoded, reqBody, err := bodyToBase64(req)
base64Encoded, contentType, reqBody, err := bodyToBase64(req)
if err != nil {
msg := fmt.Sprintf("Error encoding Lambda request body: %v", err)
logger.Error(msg)
Expand All @@ -149,6 +149,29 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}

if req.ContentLength > 0 {
rCt := req.Header.Get("Content-Type")
switch rCt {
case "":
logger.Debug("Content-Type not set")
if !strings.HasPrefix(contentType, "text") {
logger.Debugf("Content-Type not like text, setting to :%s", contentType)
req.Header.Set("Content-Type", contentType)
} else {
req.Header.Set("Content-Type", "application/json")
}
case "application/x-www-form-urlencoded":
// If sending data through cURL on the commandline and
// the content-type header is missed, orr for
// applications that aren't explicitly setting Content-Type,
// override to 'application/json' if the body looks like JSON.
if isJSON(reqBody) {
req.Header.Set("Content-Type", "application/json")
}
}
logger.Debugf("Content-Type set to: %s, originally %s", req.Header.Get("Content-Type"), rCt)
}

// Ensure tracing headers are included in the request before copying
// them to the lambda request
tracing.InjectRequestHeaders(req)
Expand All @@ -159,7 +182,6 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
QueryStringParameters: valuesToMap(req.URL.Query()),
MultiValueQueryStringParameters: valuesToMultiMap(req.URL.Query()),
Headers: headersToMap(req.Header),
MultiValueHeaders: headersToMultiMap(req.Header),
Body: reqBody,
IsBase64Encoded: base64Encoded,
RequestContext: events.APIGatewayProxyRequestContext{
Expand Down Expand Up @@ -205,6 +227,12 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}

for key, values := range resp.MultiValueHeaders {
// NOTE This maybe specific to Content-Type, but it's listed in
// headers and multivalue headers so it ends up getting added twice.
// Is a multivalue header with only one item really multivalue?
if len(values) < 2 {
continue
}
for _, value := range values {
rw.Header().Add(key, value)
}
Expand Down Expand Up @@ -234,7 +262,8 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}

// bodyToBase64 ensures the request body is base64 encoded.
func bodyToBase64(req *http.Request) (bool, string, error) {
func bodyToBase64(req *http.Request) (bool, string, string, error) {
contentType := ""
base64Encoded := false
body := ""
// base64 encode non-text request body
Expand All @@ -246,15 +275,15 @@ func bodyToBase64(req *http.Request) (bool, string, error) {
// Read the request body and reset it to be read again if needed
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return base64Encoded, body, err
return base64Encoded, contentType, body, err
}
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))

body = string(bodyBytes)

// Any non 'text/*' MIME types should be base64 encoded.
// DetectContentType does not check for 'application/json'
contentType := http.DetectContentType(bodyBytes)
contentType = http.DetectContentType(bodyBytes)
if !strings.HasPrefix(contentType, "text") {
base64Encoded = true
}
Expand All @@ -267,17 +296,17 @@ func bodyToBase64(req *http.Request) (bool, string, error) {

_, err := io.Copy(encoder, bytes.NewReader(bodyBytes))
if err != nil {
return base64Encoded, body, err
return base64Encoded, contentType, body, err
}
if err = encoder.Close(); err != nil {
return base64Encoded, body, err
return base64Encoded, contentType, body, err
}
// Set body to b64 encoded version
body = b64buf.String()
}
}

return base64Encoded, body, nil
return base64Encoded, contentType, body, nil
}

func (a *awsLambda) invokeFunction(ctx context.Context, request events.APIGatewayProxyRequest) (*events.APIGatewayProxyResponse, error) {
Expand Down Expand Up @@ -331,19 +360,6 @@ func headersToMap(h http.Header) map[string]string {
return values
}

func headersToMultiMap(h http.Header) map[string][]string {
values := map[string][]string{}
for name, headers := range h {
if len(headers) < 2 {
continue
}

values[name] = headers
}

return values
}

func valueToString(f interface{}) (string, bool) {
var v string
typeof := reflect.TypeOf(f)
Expand Down Expand Up @@ -429,3 +445,9 @@ func valuesToMultiMap(i url.Values) map[string][]string {

return values
}

// Check if a string looks like JSON.
func isJSON(s string) bool {
var js interface{}
return json.Unmarshal([]byte(s), &js) == nil
}
33 changes: 26 additions & 7 deletions pkg/middlewares/awslambda/aws_lambda_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ func Test_AWSLambdaMiddleware_InvokeBasic(t *testing.T) {
assert.Equal(t, "/test/example/path", lReq.Path)
assert.Equal(t, map[string]string{"a": "1", "b": "2"}, lReq.QueryStringParameters)
assert.Equal(t, map[string][]string{"c": {"3", "4"}, "d[]": {"5", "6"}}, lReq.MultiValueQueryStringParameters)
assert.Equal(t, map[string]string{"Content-Type": "text/plain"}, lReq.Headers)
assert.Equal(t, map[string][]string{"X-Test": {"foo", "foobar"}}, lReq.MultiValueHeaders)
assert.Equal(t, map[string]string{"Content-Type": "application/json"}, lReq.Headers)
assert.Equal(t, "This is the body", lReq.Body)

res.WriteHeader(http.StatusOK)
Expand Down Expand Up @@ -144,7 +143,7 @@ func Test_AWSLambdaMiddleware_InvokeBasic(t *testing.T) {
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "text/plain")
req.Header.Set("Content-Type", "application/json")
req.Header.Add("X-Test", "foo")
req.Header.Add("X-Test", "foobar")

Expand Down Expand Up @@ -178,10 +177,11 @@ func Test_AWSLambdaMiddleware_GetTracingInformation(t *testing.T) {
func Test_AWSLambdaMiddleware_bodyToBase64_empty(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/", nil)
require.NoError(t, err)
isEncoded, body, err := bodyToBase64(req)
isEncoded, contentType, body, err := bodyToBase64(req)

assert.False(t, isEncoded)
assert.Equal(t, "", body)
assert.Equal(t, "", contentType)
require.NoError(t, err)
}

Expand All @@ -191,10 +191,27 @@ func Test_AWSLambdaMiddleware_bodyToBase64_notEncodedJSON(t *testing.T) {

req, err := http.NewRequest(http.MethodPost, "/", strings.NewReader(reqBody))
require.NoError(t, err)
isEncoded, body, err := bodyToBase64(req)
isEncoded, contentType, body, err := bodyToBase64(req)

assert.False(t, isEncoded)
assert.Equal(t, reqBody, body)
assert.Equal(t, "text/plain; charset=utf-8", contentType)
require.NoError(t, err)
}

func Test_AWSLambdaMiddleware_bodyToBase64_EncodedJSON(t *testing.T) {
bodyBytes, err := json.Marshal(`{"test": "encoded"}`)
if err != nil {
t.Fatal(err)
}

req, err := http.NewRequest(http.MethodPost, "/", strings.NewReader(string(bodyBytes)))
require.NoError(t, err)
isEncoded, contentType, body, err := bodyToBase64(req)

assert.False(t, isEncoded)
assert.Equal(t, string(bodyBytes), body)
assert.Equal(t, "text/plain; charset=utf-8", contentType)
require.NoError(t, err)
}

Expand All @@ -206,10 +223,11 @@ func Test_AWSLambdaMiddleware_bodyToBase64_withcontent(t *testing.T) {

req, err := http.NewRequest(http.MethodPost, "/", strings.NewReader(reqBody))
require.NoError(t, err)
isEncoded, body, err := bodyToBase64(req)
isEncoded, contentType, body, err := bodyToBase64(req)

assert.True(t, isEncoded)
assert.Equal(t, expected, body)
assert.Equal(t, "application/zip", contentType)
require.NoError(t, err)

// image/jpeg
Expand All @@ -218,9 +236,10 @@ func Test_AWSLambdaMiddleware_bodyToBase64_withcontent(t *testing.T) {

req2, err2 := http.NewRequest(http.MethodPost, "/", strings.NewReader(reqBody2))
require.NoError(t, err2)
isEncoded2, body2, err2 := bodyToBase64(req2)
isEncoded2, contentType2, body2, err2 := bodyToBase64(req2)

assert.True(t, isEncoded2)
assert.Equal(t, expected2, body2)
assert.Equal(t, "image/jpeg", contentType2)
require.NoError(t, err2)
}

0 comments on commit 5ac7182

Please sign in to comment.