[API] add support for embeddings api #1208

Draft · wants to merge 10 commits into base: main
75 changes: 52 additions & 23 deletions pkg/controller/modelrouter/modelrouter_controller.go
@@ -19,7 +19,9 @@ package modelrouter
import (
"context"
"fmt"
"slices"
"strconv"
"strings"

appsv1 "k8s.io/api/apps/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
@@ -35,22 +37,34 @@ import (
modelv1alpha1 "github.com/vllm-project/aibrix/api/model/v1alpha1"
orchestrationv1alpha1 "github.com/vllm-project/aibrix/api/orchestration/v1alpha1"
"github.com/vllm-project/aibrix/pkg/config"
aibrixgateway "github.com/vllm-project/aibrix/pkg/plugins/gateway"
gatewayv1 "sigs.k8s.io/gateway-api/apis/v1"
gatewayv1beta1 "sigs.k8s.io/gateway-api/apis/v1beta1"
)

const (
// TODO (varun): cleanup model related identifiers and establish common consensus
modelHeaderIdentifier = "model"
modelIdentifier = "model.aibrix.ai/name"
modelPortIdentifier = "model.aibrix.ai/port"
modelHeaderIdentifier = "model"
modelIdentifier = "model.aibrix.ai/name"
modelPortIdentifier = "model.aibrix.ai/port"
modelSupportedRequestTypeIdentifier = "model.aibrix.ai/supported-request-types"
// TODO (varun): parameterize it or dynamically resolve it
aibrixEnvoyGateway = "aibrix-eg"
aibrixEnvoyGatewayNamespace = "aibrix-system"

defaultModelServingPort = 8000
)

var (
requestTypeIdentifierToSupportedRoutePathPrefix = map[string][]string{
string(aibrixgateway.OpenAiRequestEmbeddingsType): {string(aibrixgateway.OpenAiRequestEmbeddingsPath)},
string(aibrixgateway.OpenAiRequestChatCompletionsType): {string(aibrixgateway.OpenAiRequestCompletionsPath), string(aibrixgateway.OpenAiRequestChatCompletionsPath)},
string(aibrixgateway.OpenAiRequestCompletionsType): {string(aibrixgateway.OpenAiRequestCompletionsPath), string(aibrixgateway.OpenAiRequestChatCompletionsPath)},
}

defaultSupportedRequestType = string(aibrixgateway.OpenAiRequestChatCompletionsType)
)

//+kubebuilder:rbac:groups=apps,resources=deployments,verbs=get;list;watch;create;update;patch;delete
//+kubebuilder:rbac:groups=orchestration.aibrix.ai,resources=rayclusterfleets,verbs=get;list;watch;create;update;patch;delete
//+kubebuilder:rbac:groups=gateway.networking.k8s.io,resources=httproutes,verbs=get;list;watch;create;update;patch;delete
@@ -107,6 +121,38 @@ func Add(mgr manager.Manager, runtimeConfig config.RuntimeConfig) error {
return err
}

// getSupportedRoutesMatchFromLabelsOrDefault returns the HTTPRouteMatches for the request types listed in the model's supported-request-types label, falling back to the default request type when the label is absent or matches nothing.
func getSupportedRoutesMatchFromLabelsOrDefault(labels map[string]string, modelHeaderMatch gatewayv1.HTTPHeaderMatch) []gatewayv1.HTTPRouteMatch {
var pathPrefixes []string
if routesLabelValue, ok := labels[modelSupportedRequestTypeIdentifier]; ok {
routesIdentifier := strings.Split(routesLabelValue, ",")
for id, paths := range requestTypeIdentifierToSupportedRoutePathPrefix {
if slices.Contains(routesIdentifier, id) {
pathPrefixes = append(pathPrefixes, paths...)
}
}
}

// Fall back to the default path prefixes if the labels define no routes
if len(pathPrefixes) == 0 {
pathPrefixes = append(pathPrefixes, requestTypeIdentifierToSupportedRoutePathPrefix[defaultSupportedRequestType]...)
}

var routeMatches []gatewayv1.HTTPRouteMatch
for _, path := range pathPrefixes {
routeMatches = append(routeMatches, gatewayv1.HTTPRouteMatch{
Path: &gatewayv1.HTTPPathMatch{
Type: ptr.To(gatewayv1.PathMatchPathPrefix),
Value: ptr.To(path),
},
Headers: []gatewayv1.HTTPHeaderMatch{
modelHeaderMatch,
},
})
}
return routeMatches
}

type ModelRouter struct {
client.Client
Scheme *runtime.Scheme
@@ -192,6 +238,8 @@ func (m *ModelRouter) createHTTPRoute(namespace string, labels map[string]string
Value: modelName,
}

httpRoutesMatch := getSupportedRoutesMatchFromLabelsOrDefault(labels, modelHeaderMatch)

httpRoute := gatewayv1.HTTPRoute{
ObjectMeta: metav1.ObjectMeta{
Name: fmt.Sprintf("%s-router", modelName),
@@ -208,26 +256,7 @@
},
Rules: []gatewayv1.HTTPRouteRule{
{
Matches: []gatewayv1.HTTPRouteMatch{
{
Path: &gatewayv1.HTTPPathMatch{
Type: ptr.To(gatewayv1.PathMatchPathPrefix),
Value: ptr.To("/v1/completions"),
},
Headers: []gatewayv1.HTTPHeaderMatch{
modelHeaderMatch,
},
},
{
Path: &gatewayv1.HTTPPathMatch{
Type: ptr.To(gatewayv1.PathMatchPathPrefix),
Value: ptr.To("/v1/chat/completions"),
},
Headers: []gatewayv1.HTTPHeaderMatch{
modelHeaderMatch,
},
},
},
Matches: httpRoutesMatch,
BackendRefs: []gatewayv1.HTTPBackendRef{
{
BackendRef: gatewayv1.BackendRef{
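To illustrate the label-driven matching above, here is a minimal sketch (not part of the PR) of how a model labeled for embeddings-only traffic resolves to route matches. The helper name exampleEmbeddingsMatches, the "my-model" name, the "embeddings" label value, and the "/v1/embeddings" path are assumptions inferred from the request-type constants referenced in the diff:

package modelrouter

import (
	"fmt"

	"k8s.io/utils/ptr"
	gatewayv1 "sigs.k8s.io/gateway-api/apis/v1"
)

// exampleEmbeddingsMatches is a hypothetical helper (illustration only).
// It resolves the HTTPRouteMatches for a model whose label requests only
// embeddings routes; the default (chat) completions paths are skipped.
func exampleEmbeddingsMatches() []gatewayv1.HTTPRouteMatch {
	labels := map[string]string{
		"model.aibrix.ai/name":                    "my-model",
		"model.aibrix.ai/supported-request-types": "embeddings", // assumed label value
	}
	headerMatch := gatewayv1.HTTPHeaderMatch{
		Type:  ptr.To(gatewayv1.HeaderMatchExact),
		Name:  gatewayv1.HTTPHeaderName("model"),
		Value: "my-model",
	}
	matches := getSupportedRoutesMatchFromLabelsOrDefault(labels, headerMatch)
	for _, m := range matches {
		fmt.Println(*m.Path.Value) // expected: "/v1/embeddings" (assumed path constant)
	}
	return matches
}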
8 changes: 5 additions & 3 deletions pkg/plugins/gateway/gateway.go
@@ -82,6 +82,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
var respErrorCode int
var model string
var requestPath string
var requestType OpenAiRequestType
var routingAlgorithm types.RoutingAlgorithm
var routerCtx *types.RoutingContext
var stream, isRespError bool
@@ -113,7 +114,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
resp, user, rpm, routingAlgorithm, requestPath = s.HandleRequestHeaders(ctx, requestID, req)

case *extProcPb.ProcessingRequest_RequestBody:
resp, model, routerCtx, stream, traceTerm = s.HandleRequestBody(ctx, requestID, requestPath, req, user, routingAlgorithm)
resp, model, routerCtx, stream, requestType, traceTerm = s.HandleRequestBody(ctx, requestID, requestPath, req, user, routingAlgorithm)
if routerCtx != nil {
ctx = routerCtx
}
@@ -135,7 +136,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
resp = s.responseErrorProcessing(ctx, resp, respErrorCode, model, requestID,
string(req.Request.(*extProcPb.ProcessingRequest_ResponseBody).ResponseBody.GetBody()))
} else {
resp, completed = s.HandleResponseBody(ctx, requestID, req, user, rpm, model, stream, traceTerm, completed)
resp, completed = s.HandleResponseBody(ctx, requestID, req, requestType, user, rpm, model, stream, traceTerm, completed)
}
default:
klog.Infof("Unknown Request type %+v\n", v)
@@ -205,7 +206,8 @@ func (s *Server) validateHTTPRouteStatus(ctx context.Context, model string) error {
}

func (s *Server) responseErrorProcessing(ctx context.Context, resp *extProcPb.ProcessingResponse, respErrorCode int,
model, requestID, errMsg string) *extProcPb.ProcessingResponse {
model, requestID, errMsg string,
) *extProcPb.ProcessingResponse {
httprouteErr := s.validateHTTPRouteStatus(ctx, model)
if errMsg != "" && httprouteErr != nil {
errMsg = fmt.Sprintf("%s. %s", errMsg, httprouteErr.Error())
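The requestType threaded through Process above relies on an OpenAiRequestType whose definition sits outside the visible diff. A plausible sketch of those identifiers follows, assuming the conventional OpenAI route paths; the exact string values are not shown in the PR:

package gateway

// Sketch only: the diff references these identifiers but does not show their
// definitions. All string values below are assumed from OpenAI API conventions.
type (
	OpenAiRequestType string
	OpenAiRequestPath string
)

const (
	OpenAiRequestCompletionsType     OpenAiRequestType = "completions"
	OpenAiRequestChatCompletionsType OpenAiRequestType = "chat-completions"
	OpenAiRequestEmbeddingsType      OpenAiRequestType = "embeddings"

	OpenAiRequestCompletionsPath     OpenAiRequestPath = "/v1/completions"
	OpenAiRequestChatCompletionsPath OpenAiRequestPath = "/v1/chat/completions"
	OpenAiRequestEmbeddingsPath      OpenAiRequestPath = "/v1/embeddings"
)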
26 changes: 16 additions & 10 deletions pkg/plugins/gateway/gateway_req_body.go
@@ -31,23 +31,27 @@ import (
)

func (s *Server) HandleRequestBody(ctx context.Context, requestID string, requestPath string, req *extProcPb.ProcessingRequest,
user utils.User, routingAlgorithm types.RoutingAlgorithm) (*extProcPb.ProcessingResponse, string, *types.RoutingContext, bool, int64) {
user utils.User, routingAlgorithm types.RoutingAlgorithm,
) (*extProcPb.ProcessingResponse, string, *types.RoutingContext, bool, OpenAiRequestType, int64) {
var routingCtx *types.RoutingContext
var term int64 // Identify the trace window

requestType := NewOpenAiRequestTypeFromPath(requestPath)

body := req.Request.(*extProcPb.ProcessingRequest_RequestBody)
model, message, stream, errRes := validateRequestBody(requestID, requestPath, body.RequestBody.GetBody(), user)
model, message, stream, errRes := validateRequestBody(requestID, requestType, body.RequestBody.GetBody(), user)
if errRes != nil {
return errRes, model, routingCtx, stream, term
return errRes, model, routingCtx, stream, requestType, term
}

// early reject the request if model doesn't exist.
if !s.cache.HasModel(model) {
klog.ErrorS(nil, "model doesn't exist in cache, probably wrong model name", "requestID", requestID, "model", model)
return generateErrorResponse(envoyTypePb.StatusCode_BadRequest,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorNoModelBackends, RawValue: []byte(model)}}},
fmt.Sprintf("model %s does not exist", model)), model, routingCtx, stream, term
Key: HeaderErrorNoModelBackends, RawValue: []byte(model),
}}},
fmt.Sprintf("model %s does not exist", model)), model, routingCtx, stream, requestType, term
}

// early reject if no pods are ready to accept request for a model
@@ -56,8 +60,9 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, requestPath string, req *extProcPb.ProcessingRequest,
klog.ErrorS(err, "no ready pod available", "requestID", requestID, "model", model)
return generateErrorResponse(envoyTypePb.StatusCode_ServiceUnavailable,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorNoModelBackends, RawValue: []byte("true")}}},
fmt.Sprintf("error on getting pods for model %s", model)), model, routingCtx, stream, term
Key: HeaderErrorNoModelBackends, RawValue: []byte("true"),
}}},
fmt.Sprintf("error on getting pods for model %s", model)), model, routingCtx, stream, requestType, term
}

routingCtx = types.NewRoutingContext(ctx, routingAlgorithm, model, message, requestID, user.Name)
@@ -72,8 +77,9 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, requestPath string, req *extProcPb.ProcessingRequest,
return generateErrorResponse(
envoyTypePb.StatusCode_ServiceUnavailable,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorRouting, RawValue: []byte("true")}}},
"error on selecting target pod"), model, routingCtx, stream, term
Key: HeaderErrorRouting, RawValue: []byte("true"),
}}},
"error on selecting target pod"), model, routingCtx, stream, requestType, term
}
headers = buildEnvoyProxyHeaders(headers,
HeaderRoutingStrategy, string(routingAlgorithm),
@@ -93,5 +99,5 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, requestPath string, req *extProcPb.ProcessingRequest,
},
},
},
}, model, routingCtx, stream, term
}, model, routingCtx, stream, requestType, term
}
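HandleRequestBody now derives the request type from the request path before validating the body. NewOpenAiRequestTypeFromPath is not shown in the diff; below is a minimal sketch of what it plausibly does, assuming unknown paths fall back to chat completions (the fallback is a guess):

package gateway

import "strings"

// NewOpenAiRequestTypeFromPath — assumed implementation, not shown in the
// diff: classify a request by its URL path prefix.
func NewOpenAiRequestTypeFromPath(path string) OpenAiRequestType {
	switch {
	case strings.HasPrefix(path, string(OpenAiRequestEmbeddingsPath)):
		return OpenAiRequestEmbeddingsType
	case strings.HasPrefix(path, string(OpenAiRequestChatCompletionsPath)):
		return OpenAiRequestChatCompletionsType
	case strings.HasPrefix(path, string(OpenAiRequestCompletionsPath)):
		return OpenAiRequestCompletionsType
	default:
		// Fallback behavior is an assumption for this sketch.
		return OpenAiRequestChatCompletionsType
	}
}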
159 changes: 158 additions & 1 deletion pkg/plugins/gateway/gateway_rsp_body.go
@@ -35,9 +35,26 @@
"github.com/vllm-project/aibrix/pkg/utils"
)

func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, user utils.User, rpm int64, model string, stream bool, traceTerm int64, hasCompleted bool) (*extProcPb.ProcessingResponse, bool) {
func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, requestType OpenAiRequestType, user utils.User, rpm int64, model string, stream bool, traceTerm int64, hasCompleted bool) (*extProcPb.ProcessingResponse, bool) {
b := req.Request.(*extProcPb.ProcessingRequest_ResponseBody)

switch requestType {
case OpenAiRequestChatCompletionsType, OpenAiRequestCompletionsType:
return s.handleChatCompletionsResponseBody(ctx, requestID, b, user, rpm, model, stream, traceTerm, hasCompleted)
case OpenAiRequestEmbeddingsType:
return s.handleEmbeddingsResponseBody(ctx, requestID, b, user, rpm, model, false, traceTerm, hasCompleted)
default:
// All other OpenAI request types (e.g. audio, image) are not supported yet
return generateErrorResponse(
envoyTypePb.StatusCode_NotImplemented,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorResponseUnknown, RawValue: []byte("true"),
}}},
"request type not supported"), true
}
}

func (s *Server) handleChatCompletionsResponseBody(ctx context.Context, requestID string, b *extProcPb.ProcessingRequest_ResponseBody, user utils.User, rpm int64, model string, stream bool, traceTerm int64, hasCompleted bool) (*extProcPb.ProcessingResponse, bool) {
var res openai.ChatCompletion
var usage openai.CompletionUsage
var promptTokens, completionTokens int64
@@ -203,3 +220,143 @@ func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, requestType OpenAiRequestType, user utils.User, rpm int64, model string, stream bool, traceTerm int64, hasCompleted bool) (*extProcPb.ProcessingResponse, bool) {
},
}, complete
}

func (s *Server) handleEmbeddingsResponseBody(ctx context.Context, requestID string, b *extProcPb.ProcessingRequest_ResponseBody, user utils.User, rpm int64, model string, stream bool, traceTerm int64, hasCompleted bool) (*extProcPb.ProcessingResponse, bool) {
var res openai.CreateEmbeddingResponse
var usage openai.CreateEmbeddingResponseUsage
var promptTokens, completionTokens int64
var headers []*configPb.HeaderValueOption
complete := hasCompleted
routerCtx, _ := ctx.(*types.RoutingContext)

defer func() {
// Wrapped in a function to delay evaluation of the parameters. The complete flag ensures DoneRequestTrace is called only once per request.
if !hasCompleted && complete {
s.cache.DoneRequestTrace(routerCtx, requestID, model, promptTokens, completionTokens, traceTerm)
if routerCtx != nil {
routerCtx.Delete()
}
}
}()

// Use request ID as a key to store per-request buffer
// Retrieve or create buffer
buf, _ := requestBuffers.LoadOrStore(requestID, &bytes.Buffer{})
buffer := buf.(*bytes.Buffer)
// Append data to per-request buffer
buffer.Write(b.ResponseBody.Body)

if !b.ResponseBody.EndOfStream {
// Partial data received; wait for more chunks and return a common (empty) response for now.
return &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_ResponseBody{
ResponseBody: &extProcPb.BodyResponse{
Response: &extProcPb.CommonResponse{},
},
},
}, complete
}

// Last part received, process the full response
finalBody := buffer.Bytes()
// Clean up the buffer after final processing
requestBuffers.Delete(requestID)

if err := json.Unmarshal(finalBody, &res); err != nil {
klog.ErrorS(err, "error to unmarshal response", "requestID", requestID, "responseBody", string(b.ResponseBody.GetBody()))
complete = true
return generateErrorResponse(
envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorResponseUnmarshal, RawValue: []byte("true"),
}}},
err.Error()), complete
} else if len(res.Model) == 0 {
msg := ErrorUnknownResponse.Error()
responseBodyContent := string(b.ResponseBody.GetBody())
if len(responseBodyContent) != 0 {
msg = responseBodyContent
}
klog.ErrorS(err, "unexpected response", "requestID", requestID, "responseBody", responseBodyContent)
complete = true
return generateErrorResponse(
envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorResponseUnknown, RawValue: []byte("true"),
}}},
msg), complete
}
// Do not overwrite model; res can be empty.
usage = res.Usage

var requestEnd string
if usage.TotalTokens != 0 {
complete = true
// Update promptTokens and completionTokens
promptTokens = usage.PromptTokens
completionTokens = 0 // embeddings responses have no completion tokens
// Count token per user.
if user.Name != "" {
tpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_TPM_CURRENT", user.Name), res.Usage.TotalTokens)
if err != nil {
return generateErrorResponse(
envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorIncrTPM, RawValue: []byte("true"),
}}},
err.Error()), complete
}

headers = append(headers,
&configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: HeaderUpdateRPM,
RawValue: []byte(fmt.Sprintf("%d", rpm)),
},
},
&configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: HeaderUpdateTPM,
RawValue: []byte(fmt.Sprintf("%d", tpm)),
},
},
)
requestEnd = fmt.Sprintf(requestEnd+"rpm: %d, tpm: %d, ", rpm, tpm)
}

if routerCtx != nil && routerCtx.HasRouted() {
targetPodIP := routerCtx.TargetAddress()
headers = append(headers,
&configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: HeaderTargetPod,
RawValue: []byte(targetPodIP),
},
},
&configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: HeaderRequestID,
RawValue: []byte(requestID),
},
},
)
requestEnd = fmt.Sprintf(requestEnd+"targetPod: %s", targetPodIP)
}

klog.Infof("request end, requestID: %s - %s", requestID, requestEnd)
} else if b.ResponseBody.EndOfStream {
complete = true
}

return &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_ResponseBody{
ResponseBody: &extProcPb.BodyResponse{
Response: &extProcPb.CommonResponse{
HeaderMutation: &extProcPb.HeaderMutation{
SetHeaders: headers,
},
},
},
},
}, complete
}
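For reference, a small standalone example (illustration only; the exampleEmbeddingsUsage name and the sample body are hypothetical, and it assumes the openai-go types decode the standard response shape, as the handler's own json.Unmarshal call suggests). Embeddings responses report prompt and total tokens but no completion tokens, which is why completionTokens is pinned to 0 above:

package gateway

import (
	"encoding/json"
	"fmt"

	"github.com/openai/openai-go"
)

// exampleEmbeddingsUsage decodes a sample /v1/embeddings response body and
// prints the fields consumed by handleEmbeddingsResponseBody.
func exampleEmbeddingsUsage() {
	body := []byte(`{
		"object": "list",
		"model": "text-embedding-model",
		"data": [{"object": "embedding", "index": 0, "embedding": [0.01, 0.02]}],
		"usage": {"prompt_tokens": 8, "total_tokens": 8}
	}`)

	var res openai.CreateEmbeddingResponse
	if err := json.Unmarshal(body, &res); err != nil {
		panic(err)
	}
	// Model and token usage drive rate limiting and request tracing.
	fmt.Println(res.Model, res.Usage.PromptTokens, res.Usage.TotalTokens)
}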