Skip to content

Commit

Permalink
Update error handling to return APIError (#295)
Browse files Browse the repository at this point in the history
  • Loading branch information
juancwu authored Dec 26, 2024
2 parents 52f0d7a + 4a22068 commit 19f6523
Show file tree
Hide file tree
Showing 9 changed files with 210 additions and 192 deletions.
61 changes: 32 additions & 29 deletions backend/internal/middleware/file_check.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"KonferCA/SPUR/internal/v1/v1_common"
"fmt"
"mime/multipart"
"net/http"
Expand All @@ -15,14 +16,15 @@ FileConfig holds configuration for the file validation middleware.
MinSize and MaxSize are in bytes.
Example:
1MB = 1 * 1024 * 1024 bytes
10MB = 10 * 1024 * 1024 bytes
1MB = 1 * 1024 * 1024 bytes
10MB = 10 * 1024 * 1024 bytes
*/
type FileConfig struct {
MinSize int64
MaxSize int64
AllowedTypes []string // ex. ["image/jpeg", "image/png", "application/pdf"]
StrictValidation bool // If true, always verify content type matches header
StrictValidation bool // If true, always verify content type matches header
}

/*
Expand All @@ -31,15 +33,16 @@ FileCheck middleware ensures uploaded files meet specified criteria:
- MIME type validation
Usage:
e.POST("/upload", handler, middleware.FileCheck(middleware.FileConfig{
MinSize: 1024, // 1KB minimum
MaxSize: 10485760, // 10MB maximum
AllowedTypes: []string{
"image/jpeg",
"image/png",
"application/pdf",
},
}))
e.POST("/upload", handler, middleware.FileCheck(middleware.FileConfig{
MinSize: 1024, // 1KB minimum
MaxSize: 10485760, // 10MB maximum
AllowedTypes: []string{
"image/jpeg",
"image/png",
"application/pdf",
},
}))
*/
func FileCheck(config FileConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
Expand All @@ -52,17 +55,17 @@ func FileCheck(config FileConfig) echo.MiddlewareFunc {
// first check content-length as early rejection
contentLength := c.Request().ContentLength
if contentLength == -1 {
return echo.NewHTTPError(http.StatusBadRequest, "content length required")
return v1_common.Fail(c, http.StatusBadRequest, "content length required", nil)
}

if contentLength > config.MaxSize {
return echo.NewHTTPError(http.StatusRequestEntityTooLarge,
fmt.Sprintf("file size %d exceeds maximum allowed size of %d", contentLength, config.MaxSize))
return v1_common.Fail(c, http.StatusRequestEntityTooLarge,
fmt.Sprintf("file size %d exceeds maximum allowed size of %d", contentLength, config.MaxSize), nil)
}

// parse multipart form with max size limit to prevent memory exhaustion
if err := c.Request().ParseMultipartForm(config.MaxSize); err != nil {
return echo.NewHTTPError(http.StatusRequestEntityTooLarge, "file too large")
return v1_common.Fail(c, http.StatusRequestEntityTooLarge, "file too large", err)
}

// check actual file sizes and MIME types
Expand All @@ -87,12 +90,12 @@ func validateFile(file *multipart.FileHeader, config FileConfig) error {
// Check file size
size := file.Size
if size > config.MaxSize {
return echo.NewHTTPError(http.StatusRequestEntityTooLarge,
fmt.Sprintf("file %s size %d exceeds maximum allowed size of %d", file.Filename, size, config.MaxSize))
return v1_common.NewError(v1_common.ErrorTypeValidation, http.StatusRequestEntityTooLarge,
fmt.Sprintf("file %s size %d exceeds maximum allowed size of %d", file.Filename, size, config.MaxSize), "")
}
if size < config.MinSize {
return echo.NewHTTPError(http.StatusBadRequest,
fmt.Sprintf("file %s size %d below minimum required size of %d", file.Filename, size, config.MinSize))
return v1_common.NewError(v1_common.ErrorTypeValidation, http.StatusBadRequest,
fmt.Sprintf("file %s size %d below minimum required size of %d", file.Filename, size, config.MinSize), "")
}

// Check MIME type if restrictions are specified
Expand All @@ -104,22 +107,22 @@ func validateFile(file *multipart.FileHeader, config FileConfig) error {
if declaredType == "" || config.StrictValidation {
f, err := file.Open()
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "could not read file")
return v1_common.NewError(v1_common.ErrorTypeValidation, http.StatusBadRequest, "could not read file", "")
}
defer f.Close()

mime, err := mimetype.DetectReader(f)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "could not detect file type")
return v1_common.NewError(v1_common.ErrorTypeValidation, http.StatusBadRequest, "could not detect file type", "")
}

actualType := mime.String()

// If we have both types, verify they match (when strict validation is enabled)
if declaredType != "" && config.StrictValidation && !strings.EqualFold(declaredType, actualType) {
return echo.NewHTTPError(http.StatusBadRequest,
fmt.Sprintf("declared Content-Type (%s) doesn't match actual content type (%s)",
declaredType, actualType))
return v1_common.NewError(v1_common.ErrorTypeValidation, http.StatusBadRequest,
fmt.Sprintf("declared Content-Type (%s) doesn't match actual content type (%s)",
declaredType, actualType), "")
}

// Use actual type if no declared type, otherwise use declared type
Expand All @@ -137,11 +140,11 @@ func validateFile(file *multipart.FileHeader, config FileConfig) error {
}

if !isAllowed {
return echo.NewHTTPError(http.StatusBadRequest,
fmt.Sprintf("file type %s not allowed for %s. Allowed types: %v",
declaredType, file.Filename, config.AllowedTypes))
return v1_common.NewError(v1_common.ErrorTypeValidation, http.StatusBadRequest,
fmt.Sprintf("file type %s not allowed for %s. Allowed types: %v",
declaredType, file.Filename, config.AllowedTypes), "")
}
}

return nil
}
}
34 changes: 18 additions & 16 deletions backend/internal/middleware/file_check_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"KonferCA/SPUR/internal/v1/v1_common"
"bytes"
"fmt"
"mime/multipart"
Expand All @@ -15,7 +16,7 @@ import (

func TestFileCheck(t *testing.T) {
e := echo.New()

handler := func(c echo.Context) error {
return c.String(http.StatusOK, "success")
}
Expand All @@ -24,19 +25,19 @@ func TestFileCheck(t *testing.T) {
createMultipartRequest := func(filename string, content []byte, contentType string) (*http.Request, error) {
body := new(bytes.Buffer)
writer := multipart.NewWriter(body)

// Create form file with headers
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, "file", filename))
if contentType != "" {
h.Set("Content-Type", contentType)
}

part, err := writer.CreatePart(h)
if err != nil {
return nil, err
}

part.Write(content)
writer.Close()

Expand All @@ -48,7 +49,7 @@ func TestFileCheck(t *testing.T) {

// Sample file contents with proper headers
jpegHeader := []byte{
0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46,
0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46,
0x49, 0x46, 0x00, 0x01,
}
pngHeader := []byte{
Expand Down Expand Up @@ -78,7 +79,7 @@ func TestFileCheck(t *testing.T) {
StrictValidation: true,
},
filename: "test.jpg",
content: append(jpegHeader, []byte("dummy content")...),
content: append(jpegHeader, []byte("dummy content")...),
contentType: "image/jpeg",
expectedStatus: http.StatusOK,
},
Expand All @@ -91,7 +92,7 @@ func TestFileCheck(t *testing.T) {
StrictValidation: false,
},
filename: "test.png",
content: append(pngHeader, []byte("dummy content")...),
content: append(pngHeader, []byte("dummy content")...),
expectedStatus: http.StatusOK,
},
{
Expand All @@ -103,7 +104,7 @@ func TestFileCheck(t *testing.T) {
StrictValidation: true,
},
filename: "test.jpg",
content: append(pngHeader, []byte("dummy content")...),
content: append(pngHeader, []byte("dummy content")...),
contentType: "image/jpeg",
expectedStatus: http.StatusBadRequest,
expectedError: "doesn't match actual content type",
Expand All @@ -116,7 +117,7 @@ func TestFileCheck(t *testing.T) {
AllowedTypes: []string{"image/jpeg"},
},
filename: "large.jpg",
content: append(jpegHeader, bytes.Repeat([]byte("a"), 150)...),
content: append(jpegHeader, bytes.Repeat([]byte("a"), 150)...),
contentType: "image/jpeg",
expectedStatus: http.StatusRequestEntityTooLarge,
expectedError: "file size",
Expand All @@ -129,7 +130,7 @@ func TestFileCheck(t *testing.T) {
AllowedTypes: []string{"image/jpeg"},
},
filename: "small.jpg",
content: append(jpegHeader, []byte("tiny")...),
content: append(jpegHeader, []byte("tiny")...),
contentType: "image/jpeg",
expectedStatus: http.StatusBadRequest,
expectedError: "below minimum required size",
Expand All @@ -142,7 +143,7 @@ func TestFileCheck(t *testing.T) {
AllowedTypes: []string{"image/jpeg", "image/png"},
},
filename: "document.pdf",
content: append(pdfHeader, []byte("dummy content")...),
content: append(pdfHeader, []byte("dummy content")...),
contentType: "application/pdf",
expectedStatus: http.StatusBadRequest,
expectedError: "file type",
Expand All @@ -155,7 +156,7 @@ func TestFileCheck(t *testing.T) {
AllowedTypes: []string{"image/jpeg", "image/png", "application/pdf"},
},
filename: "document.pdf",
content: append(pdfHeader, []byte("dummy content")...),
content: append(pdfHeader, []byte("dummy content")...),
contentType: "application/pdf",
expectedStatus: http.StatusOK,
},
Expand All @@ -168,7 +169,7 @@ func TestFileCheck(t *testing.T) {
StrictValidation: true,
},
filename: "document.pdf",
content: append(pdfHeader, []byte("dummy content")...),
content: append(pdfHeader, []byte("dummy content")...),
contentType: "application/pdf",
expectedStatus: http.StatusOK,
},
Expand All @@ -186,7 +187,7 @@ func TestFileCheck(t *testing.T) {
err = h(c)

if tt.expectedStatus != http.StatusOK {
he, ok := err.(*echo.HTTPError)
he, ok := err.(*v1_common.APIError)
assert.True(t, ok)
assert.Equal(t, tt.expectedStatus, he.Code)
if tt.expectedError != "" {
Expand All @@ -208,9 +209,10 @@ func TestFileCheck(t *testing.T) {
MinSize: 5,
MaxSize: 100,
})(handler)

err := h(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
})
}
}

18 changes: 10 additions & 8 deletions backend/internal/middleware/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"KonferCA/SPUR/db"
"KonferCA/SPUR/internal/jwt"
"KonferCA/SPUR/internal/v1/v1_common"

"github.com/jackc/pgx/v5/pgxpool"
"github.com/labstack/echo/v4"
Expand All @@ -27,24 +28,24 @@ func AuthWithConfig(config AuthConfig, dbPool *pgxpool.Pool) echo.MiddlewareFunc
// get the authorization header
auth := c.Request().Header.Get(echo.HeaderAuthorization)
if auth == "" {
return echo.NewHTTPError(http.StatusUnauthorized, "missing authorization header")
return v1_common.Fail(c, http.StatusUnauthorized, "missing authorization header", nil)
}

// check bearer format
parts := strings.Split(auth, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
return echo.NewHTTPError(http.StatusUnauthorized, "invalid authorization format")
return v1_common.Fail(c, http.StatusUnauthorized, "invalid authorization format", nil)
}

// get user salt from db using claims
claims, err := jwt.ParseUnverifiedClaims(parts[1])
if err != nil {
return echo.NewHTTPError(http.StatusUnauthorized, "invalid token")
return v1_common.Fail(c, http.StatusUnauthorized, "invalid token", err)
}

// validate token type
if claims.TokenType != config.AcceptTokenType {
return echo.NewHTTPError(http.StatusUnauthorized, "invalid token type")
return v1_common.Fail(c, http.StatusUnauthorized, "invalid token type", nil)
}

// check if user role is allowed
Expand All @@ -56,19 +57,19 @@ func AuthWithConfig(config AuthConfig, dbPool *pgxpool.Pool) echo.MiddlewareFunc
}
}
if !roleValid {
return echo.NewHTTPError(http.StatusForbidden, "insufficient permissions")
return v1_common.Fail(c, http.StatusForbidden, "insufficient permissions", nil)
}

// get user's token salt and user data from db
user, err := queries.GetUserByID(c.Request().Context(), claims.UserID)
if err != nil {
return echo.NewHTTPError(http.StatusUnauthorized, "invalid token")
return v1_common.Fail(c, http.StatusUnauthorized, "invalid token", nil)
}

// verify token with user's salt
claims, err = jwt.VerifyTokenWithSalt(parts[1], user.TokenSalt)
if err != nil {
return echo.NewHTTPError(http.StatusUnauthorized, "invalid token")
return v1_common.Fail(c, http.StatusUnauthorized, "invalid token", nil)
}

// store claims and user in context for handlers
Expand All @@ -81,4 +82,5 @@ func AuthWithConfig(config AuthConfig, dbPool *pgxpool.Pool) echo.MiddlewareFunc
return next(c)
}
}
}
}

9 changes: 7 additions & 2 deletions backend/internal/middleware/rate_limit.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"KonferCA/SPUR/internal/v1/v1_common"
"net/http"
"sync"
"time"
Expand Down Expand Up @@ -122,9 +123,11 @@ func (rl *RateLimiter) RateLimit() echo.MiddlewareFunc {
Dur("remaining", remaining).
Msg("request blocked: rate limit exceeded")

return echo.NewHTTPError(
return v1_common.Fail(
c,
http.StatusTooManyRequests,
"too many requests, please try again in "+remaining.Round(time.Second).String(),
nil,
)
}

Expand All @@ -151,9 +154,11 @@ func (rl *RateLimiter) RateLimit() echo.MiddlewareFunc {
Dur("block_duration", blockDuration).
Msg("IP blocked: rate limit exceeded")

return echo.NewHTTPError(
return v1_common.Fail(
c,
http.StatusTooManyRequests,
"too many requests, please try again in "+blockDuration.Round(time.Second).String(),
nil,
)
}

Expand Down
Loading

0 comments on commit 19f6523

Please sign in to comment.