diff --git a/router/gin/endpoint.go b/router/gin/endpoint.go index 53eb1f379..5f414f730 100644 --- a/router/gin/endpoint.go +++ b/router/gin/endpoint.go @@ -21,10 +21,19 @@ const requestParamsAsterisk string = "*" // HandlerFactory creates a handler function that adapts the gin router with the injected proxy type HandlerFactory func(*config.EndpointConfig, proxy.Proxy) gin.HandlerFunc -// EndpointHandler implements the HandleFactory interface using the default ToHTTPError function +// ErrorResponseWriter writes the string representation of an error into the response body +// and sets a Content-Type header for errors that implement the encodedResponseError interface. +var ErrorResponseWriter = func(c *gin.Context, statusCode int, err error) { + if te, ok := err.(encodedResponseError); ok && te.Encoding() != "" { + c.Header("Content-Type", te.Encoding()) + } + c.Writer.WriteString(err.Error()) +} + +// EndpointHandler implements the HandlerFactory interface using the default ToHTTPError function var EndpointHandler = CustomErrorEndpointHandler(logging.NoOp, server.DefaultToHTTPError) -// CustomErrorEndpointHandler returns a HandleFactory using the injected ToHTTPError function and logger +// CustomErrorEndpointHandler returns a HandlerFactory using the injected ToHTTPError function and logger func CustomErrorEndpointHandler(logger logging.Logger, errF server.ToHTTPError) HandlerFactory { return func(configuration *config.EndpointConfig, prxy proxy.Proxy) gin.HandlerFunc { cacheControlHeaderValue := fmt.Sprintf("public, max-age=%d", int(configuration.CacheTTL.Seconds())) @@ -83,16 +92,15 @@ func CustomErrorEndpointHandler(logger logging.Logger, errF server.ToHTTPError) } if response == nil { + var statusCode int if t, ok := err.(responseError); ok { - c.Status(t.StatusCode()) + statusCode = t.StatusCode() } else { - c.Status(errF(err)) + statusCode = errF(err) } + c.Status(statusCode) if returnErrorMsg { - if te, ok := err.(encodedResponseError); ok && te.Encoding() != "" { - c.Header("Content-Type", te.Encoding()) - } - c.Writer.WriteString(err.Error()) + ErrorResponseWriter(c, statusCode, err) } cancel() return diff --git a/router/gin/engine.go b/router/gin/engine.go index 1dab75a27..61a8de9e1 100644 --- a/router/gin/engine.go +++ b/router/gin/engine.go @@ -4,6 +4,7 @@ package gin import ( "encoding/json" + "errors" "fmt" "io" "net/http" @@ -136,12 +137,12 @@ func paramChecker() gin.HandlerFunc { for _, param := range c.Params { s, err := url.PathUnescape(param.Value) if err != nil { - c.String(http.StatusBadRequest, fmt.Sprintf("error: %s", err)) + ErrorResponseWriter(c, http.StatusBadRequest, fmt.Errorf("error: %s", err)) c.AbortWithStatus(http.StatusBadRequest) return } if s != param.Value || strings.Contains(s, "?") || strings.Contains(s, "#") { - c.String(http.StatusBadRequest, "error: encoded url params") + ErrorResponseWriter(c, http.StatusBadRequest, errors.New("error: encoded url params")) c.AbortWithStatus(http.StatusBadRequest) return }