From c8327934a6864e950513c3a8d2ced93a00ad9067 Mon Sep 17 00:00:00 2001 From: smellthemoon <64083300+smellthemoon@users.noreply.github.com> Date: Fri, 24 Jan 2025 14:05:13 +0800 Subject: [PATCH] fix: [2.5]not enable rate limiter for restful v1 (#39555) issue: #39556 pr: #39553 Signed-off-by: lixinguo Co-authored-by: lixinguo --- internal/distributed/proxy/httpserver/handler_v1.go | 12 ++++++------ internal/distributed/proxy/httpserver/utils.go | 7 ++++++- internal/proxy/rate_limit_interceptor.go | 7 ++++++- internal/proxy/util.go | 4 +++- 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/internal/distributed/proxy/httpserver/handler_v1.go b/internal/distributed/proxy/httpserver/handler_v1.go index 22269458c15a3..2776316dccf32 100644 --- a/internal/distributed/proxy/httpserver/handler_v1.go +++ b/internal/distributed/proxy/httpserver/handler_v1.go @@ -533,7 +533,7 @@ func (h *HandlersV1) query(c *gin.Context) { username, _ := c.Get(ContextUsername) ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) response, err := h.executeRestRequestInterceptor(ctx, c, req, func(reqCtx context.Context, req any) (any, error) { - if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil { + if _, err := CheckLimiter(ctx, req, h.proxy); err != nil { c.AbortWithStatusJSON(http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error() + ", error: " + err.Error(), @@ -611,7 +611,7 @@ func (h *HandlersV1) get(c *gin.Context) { return nil, RestRequestInterceptorErr } queryReq := req.(*milvuspb.QueryRequest) - if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil { + if _, err := CheckLimiter(ctx, req, h.proxy); err != nil { c.AbortWithStatusJSON(http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error() + ", error: " + err.Error(), @@ -691,7 +691,7 @@ func (h *HandlersV1) delete(c *gin.Context) { } deleteReq.Expr = filter } - if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil { + if _, err := CheckLimiter(ctx, req, h.proxy); err != nil { c.AbortWithStatusJSON(http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error() + ", error: " + err.Error(), @@ -774,7 +774,7 @@ func (h *HandlersV1) insert(c *gin.Context) { }) return nil, RestRequestInterceptorErr } - if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil { + if _, err := CheckLimiter(ctx, req, h.proxy); err != nil { c.AbortWithStatusJSON(http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error() + ", error: " + err.Error(), @@ -880,7 +880,7 @@ func (h *HandlersV1) upsert(c *gin.Context) { }) return nil, RestRequestInterceptorErr } - if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil { + if _, err := CheckLimiter(ctx, req, h.proxy); err != nil { c.AbortWithStatusJSON(http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error() + ", error: " + err.Error(), @@ -983,7 +983,7 @@ func (h *HandlersV1) search(c *gin.Context) { username, _ := c.Get(ContextUsername) ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) response, err := h.executeRestRequestInterceptor(ctx, c, req, func(reqCtx context.Context, req any) (any, error) { - if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil { + if _, err := CheckLimiter(ctx, req, h.proxy); err != nil { c.AbortWithStatusJSON(http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error() + ", error: " + err.Error(), diff --git a/internal/distributed/proxy/httpserver/utils.go b/internal/distributed/proxy/httpserver/utils.go index f7276679b9ec3..3de7ecec3a093 100644 --- a/internal/distributed/proxy/httpserver/utils.go +++ b/internal/distributed/proxy/httpserver/utils.go @@ -1458,7 +1458,12 @@ func CheckLimiter(ctx context.Context, req interface{}, pxy types.ProxyComponent return nil, err } - dbID, collectionIDToPartIDs, rt, n, err := proxy.GetRequestInfo(ctx, req) + request, ok := req.(proto.Message) + if !ok { + return nil, merr.WrapErrParameterInvalidMsg("wrong req format when check limiter") + } + + dbID, collectionIDToPartIDs, rt, n, err := proxy.GetRequestInfo(ctx, request) if err != nil { return nil, err } diff --git a/internal/proxy/rate_limit_interceptor.go b/internal/proxy/rate_limit_interceptor.go index 828f1caac40df..339702e4c445c 100644 --- a/internal/proxy/rate_limit_interceptor.go +++ b/internal/proxy/rate_limit_interceptor.go @@ -22,6 +22,7 @@ import ( "go.uber.org/zap" "google.golang.org/grpc" + "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/types" @@ -38,7 +39,11 @@ import ( // RateLimitInterceptor returns a new unary server interceptors that performs request rate limiting. func RateLimitInterceptor(limiter types.Limiter) grpc.UnaryServerInterceptor { return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - dbID, collectionIDToPartIDs, rt, n, err := GetRequestInfo(ctx, req) + request, ok := req.(proto.Message) + if !ok { + return nil, merr.WrapErrParameterInvalidMsg("wrong req format when check limiter") + } + dbID, collectionIDToPartIDs, rt, n, err := GetRequestInfo(ctx, request) if err != nil { log.Warn("failed to get request info", zap.Error(err)) return handler(ctx, req) diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 892ef989f10b4..0dd098472df5b 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -19,6 +19,7 @@ package proxy import ( "context" "fmt" + "reflect" "strconv" "strings" "time" @@ -2107,7 +2108,7 @@ func GetCostValue(status *commonpb.Status) int { } // GetRequestInfo returns collection name and rateType of request and return tokens needed. -func GetRequestInfo(ctx context.Context, req interface{}) (int64, map[int64][]int64, internalpb.RateType, int, error) { +func GetRequestInfo(ctx context.Context, req proto.Message) (int64, map[int64][]int64, internalpb.RateType, int, error) { switch r := req.(type) { case *milvuspb.InsertRequest: dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName)) @@ -2185,6 +2186,7 @@ func GetRequestInfo(ctx context.Context, req interface{}) (int64, map[int64][]in if req == nil { return util.InvalidDBID, map[int64][]int64{}, 0, 0, fmt.Errorf("null request") } + log.RatedWarn(60, "not supported request type for rate limiter", zap.String("type", reflect.TypeOf(req).String())) return util.InvalidDBID, map[int64][]int64{}, 0, 0, nil } }