Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Content-Type header, Remove MultiValueHeaders #7

Merged
merged 3 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 43 additions & 21 deletions pkg/middlewares/awslambda/aws_lambda.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (a *awsLambda) GetTracingInformation() (string, ext.SpanKindEnum) {
func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := log.FromContext(middlewares.GetLoggerCtx(req.Context(), a.name, typeName))

base64Encoded, reqBody, err := bodyToBase64(req)
base64Encoded, contentType, reqBody, err := bodyToBase64(req)
if err != nil {
msg := fmt.Sprintf("Error encoding Lambda request body: %v", err)
logger.Error(msg)
Expand All @@ -149,6 +149,29 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}

if req.ContentLength > 0 {
rCt := req.Header.Get("Content-Type")
switch rCt {
case "":
logger.Debug("Content-Type not set")
if !strings.HasPrefix(contentType, "text") {
logger.Debugf("Content-Type not like text, setting to :%s", contentType)
req.Header.Set("Content-Type", contentType)
} else {
req.Header.Set("Content-Type", "application/json")
}
case "application/x-www-form-urlencoded":
// If sending data through cURL on the commandline and
// the content-type header is missed, orr for
// applications that aren't explicitly setting Content-Type,
// override to 'application/json' if the body looks like JSON.
if isJSON(reqBody) {
req.Header.Set("Content-Type", "application/json")
}
}
logger.Debugf("Content-Type set to: %s, originally %s", req.Header.Get("Content-Type"), rCt)
}

// Ensure tracing headers are included in the request before copying
// them to the lambda request
tracing.InjectRequestHeaders(req)
Expand All @@ -159,7 +182,6 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
QueryStringParameters: valuesToMap(req.URL.Query()),
MultiValueQueryStringParameters: valuesToMultiMap(req.URL.Query()),
Headers: headersToMap(req.Header),
MultiValueHeaders: headersToMultiMap(req.Header),
Body: reqBody,
IsBase64Encoded: base64Encoded,
RequestContext: events.APIGatewayProxyRequestContext{
Expand Down Expand Up @@ -205,6 +227,12 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}

for key, values := range resp.MultiValueHeaders {
// NOTE This maybe specific to Content-Type, but it's listed in
// headers and multivalue headers so it ends up getting added twice.
// Is a multivalue header with only one item really multivalue?
if len(values) < 2 {
continue
}
for _, value := range values {
rw.Header().Add(key, value)
}
Expand Down Expand Up @@ -234,7 +262,8 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}

// bodyToBase64 ensures the request body is base64 encoded.
func bodyToBase64(req *http.Request) (bool, string, error) {
func bodyToBase64(req *http.Request) (bool, string, string, error) {
contentType := ""
base64Encoded := false
body := ""
// base64 encode non-text request body
Expand All @@ -246,15 +275,15 @@ func bodyToBase64(req *http.Request) (bool, string, error) {
// Read the request body and reset it to be read again if needed
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return base64Encoded, body, err
return base64Encoded, contentType, body, err
}
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))

body = string(bodyBytes)

// Any non 'text/*' MIME types should be base64 encoded.
// DetectContentType does not check for 'application/json'
contentType := http.DetectContentType(bodyBytes)
contentType = http.DetectContentType(bodyBytes)
if !strings.HasPrefix(contentType, "text") {
base64Encoded = true
}
Expand All @@ -267,17 +296,17 @@ func bodyToBase64(req *http.Request) (bool, string, error) {

_, err := io.Copy(encoder, bytes.NewReader(bodyBytes))
if err != nil {
return base64Encoded, body, err
return base64Encoded, contentType, body, err
}
if err = encoder.Close(); err != nil {
return base64Encoded, body, err
return base64Encoded, contentType, body, err
}
// Set body to b64 encoded version
body = b64buf.String()
}
}

return base64Encoded, body, nil
return base64Encoded, contentType, body, nil
}

func (a *awsLambda) invokeFunction(ctx context.Context, request events.APIGatewayProxyRequest) (*events.APIGatewayProxyResponse, error) {
Expand Down Expand Up @@ -331,19 +360,6 @@ func headersToMap(h http.Header) map[string]string {
return values
}

func headersToMultiMap(h http.Header) map[string][]string {
values := map[string][]string{}
for name, headers := range h {
if len(headers) < 2 {
continue
}

values[name] = headers
}

return values
}

func valueToString(f interface{}) (string, bool) {
var v string
typeof := reflect.TypeOf(f)
Expand Down Expand Up @@ -429,3 +445,9 @@ func valuesToMultiMap(i url.Values) map[string][]string {

return values
}

// Check if a string looks like JSON.
func isJSON(s string) bool {
var js interface{}
return json.Unmarshal([]byte(s), &js) == nil
}
33 changes: 26 additions & 7 deletions pkg/middlewares/awslambda/aws_lambda_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ func Test_AWSLambdaMiddleware_InvokeBasic(t *testing.T) {
assert.Equal(t, "/test/example/path", lReq.Path)
assert.Equal(t, map[string]string{"a": "1", "b": "2"}, lReq.QueryStringParameters)
assert.Equal(t, map[string][]string{"c": {"3", "4"}, "d[]": {"5", "6"}}, lReq.MultiValueQueryStringParameters)
assert.Equal(t, map[string]string{"Content-Type": "text/plain"}, lReq.Headers)
assert.Equal(t, map[string][]string{"X-Test": {"foo", "foobar"}}, lReq.MultiValueHeaders)
assert.Equal(t, map[string]string{"Content-Type": "application/json"}, lReq.Headers)
assert.Equal(t, "This is the body", lReq.Body)

res.WriteHeader(http.StatusOK)
Expand Down Expand Up @@ -144,7 +143,7 @@ func Test_AWSLambdaMiddleware_InvokeBasic(t *testing.T) {
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "text/plain")
req.Header.Set("Content-Type", "application/json")
req.Header.Add("X-Test", "foo")
req.Header.Add("X-Test", "foobar")

Expand Down Expand Up @@ -178,10 +177,11 @@ func Test_AWSLambdaMiddleware_GetTracingInformation(t *testing.T) {
func Test_AWSLambdaMiddleware_bodyToBase64_empty(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/", nil)
require.NoError(t, err)
isEncoded, body, err := bodyToBase64(req)
isEncoded, contentType, body, err := bodyToBase64(req)

assert.False(t, isEncoded)
assert.Equal(t, "", body)
assert.Equal(t, "", contentType)
require.NoError(t, err)
}

Expand All @@ -191,10 +191,27 @@ func Test_AWSLambdaMiddleware_bodyToBase64_notEncodedJSON(t *testing.T) {

req, err := http.NewRequest(http.MethodPost, "/", strings.NewReader(reqBody))
require.NoError(t, err)
isEncoded, body, err := bodyToBase64(req)
isEncoded, contentType, body, err := bodyToBase64(req)

assert.False(t, isEncoded)
assert.Equal(t, reqBody, body)
assert.Equal(t, "text/plain; charset=utf-8", contentType)
require.NoError(t, err)
}

func Test_AWSLambdaMiddleware_bodyToBase64_EncodedJSON(t *testing.T) {
bodyBytes, err := json.Marshal(`{"test": "encoded"}`)
if err != nil {
t.Fatal(err)
}

req, err := http.NewRequest(http.MethodPost, "/", strings.NewReader(string(bodyBytes)))
require.NoError(t, err)
isEncoded, contentType, body, err := bodyToBase64(req)

assert.False(t, isEncoded)
assert.Equal(t, string(bodyBytes), body)
assert.Equal(t, "text/plain; charset=utf-8", contentType)
require.NoError(t, err)
}

Expand All @@ -206,10 +223,11 @@ func Test_AWSLambdaMiddleware_bodyToBase64_withcontent(t *testing.T) {

req, err := http.NewRequest(http.MethodPost, "/", strings.NewReader(reqBody))
require.NoError(t, err)
isEncoded, body, err := bodyToBase64(req)
isEncoded, contentType, body, err := bodyToBase64(req)

assert.True(t, isEncoded)
assert.Equal(t, expected, body)
assert.Equal(t, "application/zip", contentType)
require.NoError(t, err)

// image/jpeg
Expand All @@ -218,9 +236,10 @@ func Test_AWSLambdaMiddleware_bodyToBase64_withcontent(t *testing.T) {

req2, err2 := http.NewRequest(http.MethodPost, "/", strings.NewReader(reqBody2))
require.NoError(t, err2)
isEncoded2, body2, err2 := bodyToBase64(req2)
isEncoded2, contentType2, body2, err2 := bodyToBase64(req2)

assert.True(t, isEncoded2)
assert.Equal(t, expected2, body2)
assert.Equal(t, "image/jpeg", contentType2)
require.NoError(t, err2)
}
Loading