diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 812613c00..ac4cd2ce4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,3 +16,5 @@ repos: rev: 7.0.0 hooks: - id: flake8 + name: flake8 + args: ['--max-line-length=120'] diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py index 340a9b5b4..a57955524 100644 --- a/clients/python/lorax/client.py +++ b/clients/python/lorax/client.py @@ -10,7 +10,9 @@ Response, Request, Parameters, - MergedAdapters, ResponseFormat, + MergedAdapters, + ResponseFormat, + EmbedResponse ) from lorax.errors import parse_error @@ -58,6 +60,7 @@ def __init__( HTTP requests session object to reuse """ self.base_url = base_url + self.embed_endpoint = f"{base_url}/embed" self.headers = headers self.cookies = cookies self.timeout = timeout @@ -345,6 +348,34 @@ def generate_stream( raise parse_error(resp.status_code, json_payload) yield response + + def embed(self, inputs: str) -> EmbedResponse: + """ + Given inputs, embed the text using the model + + Args: + inputs (`str`): + Input text + + Returns: + Embeddings: computed embeddings + """ + request = Request(inputs=inputs) + + resp = requests.post( + self.embed_endpoint, + json=request.dict(by_alias=True), + headers=self.headers, + cookies=self.cookies, + timeout=self.timeout, + ) + + payload = resp.json() + if resp.status_code != 200: + raise parse_error(resp.status_code, resp.json()) + + return EmbedResponse(**payload) + class AsyncClient: """Asynchronous Client to make calls to a LoRAX instance @@ -387,6 +418,7 @@ def __init__( Timeout in seconds """ self.base_url = base_url + self.embed_endpoint = f"{base_url}/embed" self.headers = headers self.cookies = cookies self.timeout = ClientTimeout(timeout * 60) @@ -661,3 +693,26 @@ async def generate_stream( # If we failed to parse the payload, then it is an error payload raise parse_error(resp.status, json_payload) yield response + + + async def embed(self, inputs: str) -> EmbedResponse: + """ + Given inputs, embed the text using the model + + Args: + inputs (`str`): + Input text + + Returns: + Embeddings: computed embeddings + """ + request = Request(inputs=inputs) + async with ClientSession( + headers=self.headers, cookies=self.cookies, timeout=self.timeout + ) as session: + async with session.post(self.embed_endpoint, json=request.dict(by_alias=True)) as resp: + payload = await resp.json() + + if resp.status != 200: + raise parse_error(resp.status, payload) + return EmbedResponse(**payload) \ No newline at end of file diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py index bd6f5ff49..e662a5403 100644 --- a/clients/python/lorax/types.py +++ b/clients/python/lorax/types.py @@ -336,6 +336,10 @@ class StreamResponse(BaseModel): class DeployedModel(BaseModel): model_id: str sha: str - - # Suppress pydantic warning over `model_id` field. + # Suppress pydantic warning over `model_id` field model_config = ConfigDict(protected_namespaces=()) + + +class EmbedResponse(BaseModel): + # Embeddings + embeddings: Optional[List[float]] diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 4073f72c4..0626fd669 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -152,6 +152,12 @@ struct Args { #[clap(long, env)] sharded: Option, + /// Whether this model is mean for embeddings or text generation. + /// By default models are for text generation. + /// Setting it to `true` will enable the embedding endpoints and disable the generation ones. 
+ #[clap(long, env)] + embedding_model: Option, + /// The number of shards to use if you don't want to use all GPUs on a given machine. /// You can use `CUDA_VISIBLE_DEVICES=0,1 lorax-launcher... --num_shard 2` /// and `CUDA_VISIBLE_DEVICES=2,3 lorax-launcher... --num_shard 2` to @@ -1119,6 +1125,10 @@ fn spawn_webserver( router_args.push(origin.to_string()); } + if args.embedding_model.unwrap_or(false) { + router_args.push("--embedding-model".to_string()); + } + // Ngrok if args.ngrok { router_args.push("--ngrok".to_string()); diff --git a/proto/generate.proto b/proto/generate.proto index 2470b065b..b47f2cd29 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -15,6 +15,8 @@ service LoraxService { rpc Warmup (WarmupRequest) returns (WarmupResponse); /// Prefill batch and decode first token rpc Prefill (PrefillRequest) returns (PrefillResponse); + /// Embed + rpc Embed (EmbedRequest) returns (EmbedResponse); /// Decode token for a list of prefilled batches rpc Decode (DecodeRequest) returns (DecodeResponse); /// Health check @@ -230,6 +232,19 @@ message DecodeResponse { optional CachedBatch batch = 2; } +message EmbedRequest { + string inputs = 1; +} + +message Embedding { + repeated float values = 1; +} + +message EmbedResponse { + Embedding embeddings = 1; + string errorMsg = 2; +} + message WarmupRequest { /// Batch to warmup on Batch batch = 1; @@ -309,7 +324,7 @@ message DownloadAdapterResponse { /// Fraction of the adapter memory limit consumed by the adapter. /// If no limit is set, will return 0. - /// When the total across all loaded adapters exceeds + /// When the total across all loaded adapters exceeds /// the adapter_memory_fraction limit, no more adapters /// will be loaded to GPU and LoRAX will begin swapping. float memory_fraction = 2; diff --git a/router/client/src/client.rs b/router/client/src/client.rs index ee68c6d96..6b7157885 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -64,6 +64,14 @@ impl Client { Ok(response) } + /// Embed + #[instrument(skip(self))] + pub async fn embed(&mut self, inputs: String) -> Result { + let request = tonic::Request::new(EmbedRequest { inputs }).inject_context(); + let response = self.stub.embed(request).await?.into_inner(); + Ok(response) + } + /// Get model health #[instrument(skip(self))] pub async fn health(&mut self) -> Result { diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 7427114f4..1a6f885fb 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -1,3 +1,4 @@ +use crate::pb::generate::v1::EmbedResponse; /// Multi shard Client use crate::{ AdapterParameters, Batch, CachedBatch, Client, DownloadAdapterResponse, Generation, @@ -153,6 +154,17 @@ impl ShardedClient { merge_generations(results?) 
} + /// Get the model info + #[instrument(skip(self))] + pub async fn embed(&mut self, inputs: String) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.embed(inputs.clone()))) + .collect(); + join_all(futures).await.into_iter().collect() + } + pub async fn download_adapter( &mut self, adapter_parameters: AdapterParameters, diff --git a/router/src/lib.rs b/router/src/lib.rs index 643948acf..f5ced9b63 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -68,6 +68,8 @@ pub struct Info { pub docker_label: Option<&'static str>, #[schema(nullable = true, example = "http://localhost:8899")] pub request_logger_url: Option, + #[schema(example = false)] + pub embedding_model: bool, } #[derive(Clone, Debug, Deserialize, ToSchema, Default)] @@ -633,6 +635,16 @@ pub(crate) enum CompletionFinishReason { ToolCalls, } +#[derive(Clone, Debug, Deserialize, ToSchema)] +struct EmbedRequest { + inputs: String, +} + +#[derive(Serialize, ToSchema)] +struct EmbedResponse { + embeddings: Vec, +} + impl From for CompatGenerateRequest { fn from(req: CompletionRequest) -> Self { CompatGenerateRequest { diff --git a/router/src/main.rs b/router/src/main.rs index 7f3c4b3e7..c133b4be0 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -64,6 +64,8 @@ struct Args { #[clap(long, env)] json_output: bool, #[clap(long, env)] + embedding_model: bool, + #[clap(long, env)] otlp_endpoint: Option, #[clap(long, env)] cors_allow_origin: Option>, @@ -109,6 +111,7 @@ async fn main() -> Result<(), RouterError> { revision, validation_workers, json_output, + embedding_model, otlp_endpoint, cors_allow_origin, cors_allow_method, @@ -372,6 +375,7 @@ async fn main() -> Result<(), RouterError> { ngrok_authtoken, ngrok_edge, adapter_source, + embedding_model, ) .await?; Ok(()) diff --git a/router/src/server.rs b/router/src/server.rs index 43e460f59..98f13b48b 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -9,10 +9,11 @@ use crate::{ ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionStreamResponse, ChatCompletionStreamResponseChoice, ChatMessage, CompatGenerateRequest, CompletionFinishReason, CompletionRequest, CompletionResponse, CompletionResponseChoice, - CompletionResponseStreamChoice, CompletionStreamResponse, Details, ErrorResponse, FinishReason, - GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, LogProbs, - PrefillToken, ResponseFormat, ResponseFormatType, SimpleToken, StreamDetails, StreamResponse, - Token, TokenizeRequest, TokenizeResponse, UsageInfo, Validation, + CompletionResponseStreamChoice, CompletionStreamResponse, Details, EmbedRequest, EmbedResponse, + ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, + HubModelInfo, Infer, Info, LogProbs, PrefillToken, ResponseFormat, ResponseFormatType, + SimpleToken, StreamDetails, StreamResponse, Token, TokenizeRequest, TokenizeResponse, + UsageInfo, Validation, }; use axum::extract::Extension; use axum::http::{request, HeaderMap, Method, StatusCode}; @@ -21,6 +22,7 @@ use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; use axum::{http, Json, Router}; use axum_tracing_opentelemetry::opentelemetry_tracing_layer; +use clap::error; use futures::stream::StreamExt; use futures::Stream; use lorax_client::{ShardInfo, ShardedClient}; @@ -80,6 +82,16 @@ async fn compat_generate( ) -> Result)> { let mut req = req.0; + if info.embedding_model { + metrics::increment_counter!("lorax_request_failure", "err" => 
"bad_request"); + tracing::error!("Embedding model doesn't support generation"); + let err = ErrorResponse { + error: "Embedding model doesn't support generation".to_string(), + error_type: "bad_request".to_string(), + }; + return Err((StatusCode::BAD_REQUEST, Json(err))); + } + // default return_full_text given the pipeline_tag if req.parameters.return_full_text.is_none() { req.parameters.return_full_text = Some(default_return_full_text.0) @@ -149,6 +161,16 @@ async fn completions_v1( } let mut gen_req = CompatGenerateRequest::from(req); + if info.embedding_model { + metrics::increment_counter!("lorax_request_failure", "err" => "bad_request"); + tracing::error!("Embedding model doesn't support generation"); + let err = ErrorResponse { + error: "Embedding model doesn't support generation".to_string(), + error_type: "bad_request".to_string(), + }; + return Err((StatusCode::BAD_REQUEST, Json(err))); + } + // default return_full_text given the pipeline_tag if gen_req.parameters.return_full_text.is_none() { gen_req.parameters.return_full_text = Some(default_return_full_text.0) @@ -230,6 +252,16 @@ async fn chat_completions_v1( req.model = "".to_string(); } + if info.embedding_model { + metrics::increment_counter!("lorax_request_failure", "err" => "bad_request"); + tracing::error!("Embedding model doesn't support generation"); + let err = ErrorResponse { + error: "Embedding model doesn't support generation".to_string(), + error_type: "bad_request".to_string(), + }; + return Err((StatusCode::BAD_REQUEST, Json(err))); + } + let mut gen_req = CompatGenerateRequest::from(req); // default return_full_text given the pipeline_tag @@ -355,6 +387,16 @@ async fn generate( tracing::debug!("Input: {}", req.0.inputs); + if info.embedding_model { + metrics::increment_counter!("lorax_request_failure", "err" => "bad_request"); + tracing::error!("Embedding model doesn't support generation"); + let err = ErrorResponse { + error: "Embedding model doesn't support generation".to_string(), + error_type: "bad_request".to_string(), + }; + return Err((StatusCode::BAD_REQUEST, Json(err))); + } + let compute_characters = req.0.inputs.chars().count(); let mut add_prompt = None; if req.0.parameters.return_full_text.unwrap_or(false) { @@ -672,152 +714,160 @@ async fn generate_stream_with_callback( headers.insert("x-model-id", info.model_id.parse().unwrap()); let stream = async_stream::stream! 
{ - // Inference - let mut end_reached = false; - let mut error = false; - - let mut prefill_tokens_length = 0; - - let mut add_prompt = None; - if req.0.parameters.return_full_text.unwrap_or(false) { - add_prompt = Some(req.0.inputs.clone()); - } - let details = req.0.parameters.details; - - let best_of = req.0.parameters.best_of.unwrap_or(1); - if best_of != 1 { - let err = InferError::from(ValidationError::BestOfStream); - metrics::increment_counter!("lorax_request_failure", "err" => "validation"); - tracing::error!("{err}"); - yield Ok(Event::from(err)); - } else if req.0.parameters.decoder_input_details { - let err = InferError::from(ValidationError::PrefillDetailsStream); - metrics::increment_counter!("lorax_request_failure", "err" => "validation"); + if info.embedding_model { + let err = InferError::from(ValidationError::EmbeddingModel); + metrics::increment_counter!("lorax_request_failure", "err" => "bad_request"); tracing::error!("{err}"); yield Ok(Event::from(err)); } else { - match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await { - // Keep permit as long as generate_stream lives - Ok((_permit, mut response_stream)) => { - // Server-Sent Event stream - while let Some(response) = response_stream.next().await { - match response { - Ok(response) => { - match response { - // Prefill is ignored - InferStreamResponse::Prefill { - tokens_length, - .. - } => { - prefill_tokens_length = tokens_length; - } - // Yield event for every new token - InferStreamResponse::Token(token) => { - tracing::debug!(parent: &span, "Token: {:?}", token); - // StreamResponse - let stream_token = StreamResponse { - token, - generated_text: None, - details: None, - }; + // Inference + let mut end_reached = false; + let mut error = false; - yield Ok(callback(stream_token)) - } - // Yield event for last token and compute timings - InferStreamResponse::End { - token, - generated_text, - start, - queued, - } => { - // Token details - let details = match details { - true => Some(StreamDetails { - finish_reason: FinishReason::from(generated_text.finish_reason), - prompt_tokens: prefill_tokens_length, - generated_tokens: generated_text.generated_tokens, - seed: generated_text.seed, - }), - false => None, - }; - - // Timings - let total_time = start_time.elapsed(); - let validation_time = queued - start_time; - let queue_time = start - queued; - let inference_time = Instant::now() - start; - let time_per_token = inference_time / generated_text.generated_tokens; - - // Tracing metadata - span.record("total_time", format!("{total_time:?}")); - span.record("validation_time", format!("{validation_time:?}")); - span.record("queue_time", format!("{queue_time:?}")); - span.record("inference_time", format!("{inference_time:?}")); - span.record("time_per_token", format!("{time_per_token:?}")); - span.record("seed", format!("{:?}", generated_text.seed)); - - // Metrics - metrics::increment_counter!("lorax_request_success"); - metrics::histogram!("lorax_request_duration", total_time.as_secs_f64()); - metrics::histogram!("lorax_request_validation_duration", validation_time.as_secs_f64()); - metrics::histogram!("lorax_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!("lorax_request_inference_duration", inference_time.as_secs_f64()); - metrics::histogram!("lorax_request_mean_time_per_token_duration", time_per_token.as_secs_f64()); - metrics::histogram!("lorax_request_generated_tokens", generated_text.generated_tokens as f64); - - - - // StreamResponse - end_reached 
= true; - - let mut output_text = generated_text.text; - if let Some(prompt) = add_prompt { - output_text = prompt + &output_text; - } + let mut prefill_tokens_length = 0; - tracing::debug!(parent: &span, "Output: {}", output_text); - tracing::info!(parent: &span, "Success"); + let mut add_prompt = None; + if req.0.parameters.return_full_text.unwrap_or(false) { + add_prompt = Some(req.0.inputs.clone()); + } + let details = req.0.parameters.details; - let total_tokens = generated_text.generated_tokens + prefill_tokens_length; - if info.request_logger_url.is_some() { - let _ = request_logger_sender.send((total_tokens as i64, api_token.unwrap_or("".to_string()), info.model_id.clone())).await; + let best_of = req.0.parameters.best_of.unwrap_or(1); + if best_of != 1 { + let err = InferError::from(ValidationError::BestOfStream); + metrics::increment_counter!("lorax_request_failure", "err" => "validation"); + tracing::error!("{err}"); + yield Ok(Event::from(err)); + } else if req.0.parameters.decoder_input_details { + let err = InferError::from(ValidationError::PrefillDetailsStream); + metrics::increment_counter!("lorax_request_failure", "err" => "validation"); + tracing::error!("{err}"); + yield Ok(Event::from(err)); + } else { + match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await { + // Keep permit as long as generate_stream lives + Ok((_permit, mut response_stream)) => { + // Server-Sent Event stream + while let Some(response) = response_stream.next().await { + match response { + Ok(response) => { + match response { + // Prefill is ignored + InferStreamResponse::Prefill { + tokens_length, + .. + } => { + prefill_tokens_length = tokens_length; } - - let stream_token = StreamResponse { + // Yield event for every new token + InferStreamResponse::Token(token) => { + tracing::debug!(parent: &span, "Token: {:?}", token); + + // StreamResponse + let stream_token = StreamResponse { + token, + generated_text: None, + details: None, + }; + + yield Ok(callback(stream_token)) + } + // Yield event for last token and compute timings + InferStreamResponse::End { token, - generated_text: Some(output_text), - details - }; - - yield Ok(callback(stream_token)); - break; + generated_text, + start, + queued, + } => { + // Token details + let details = match details { + true => Some(StreamDetails { + finish_reason: FinishReason::from(generated_text.finish_reason), + prompt_tokens: prefill_tokens_length, + generated_tokens: generated_text.generated_tokens, + seed: generated_text.seed, + }), + false => None, + }; + + // Timings + let total_time = start_time.elapsed(); + let validation_time = queued - start_time; + let queue_time = start - queued; + let inference_time = Instant::now() - start; + let time_per_token = inference_time / generated_text.generated_tokens; + + // Tracing metadata + span.record("total_time", format!("{total_time:?}")); + span.record("validation_time", format!("{validation_time:?}")); + span.record("queue_time", format!("{queue_time:?}")); + span.record("inference_time", format!("{inference_time:?}")); + span.record("time_per_token", format!("{time_per_token:?}")); + span.record("seed", format!("{:?}", generated_text.seed)); + + // Metrics + metrics::increment_counter!("lorax_request_success"); + metrics::histogram!("lorax_request_duration", total_time.as_secs_f64()); + metrics::histogram!("lorax_request_validation_duration", validation_time.as_secs_f64()); + metrics::histogram!("lorax_request_queue_duration", queue_time.as_secs_f64()); + 
metrics::histogram!("lorax_request_inference_duration", inference_time.as_secs_f64()); + metrics::histogram!("lorax_request_mean_time_per_token_duration", time_per_token.as_secs_f64()); + metrics::histogram!("lorax_request_generated_tokens", generated_text.generated_tokens as f64); + + + + // StreamResponse + end_reached = true; + + let mut output_text = generated_text.text; + if let Some(prompt) = add_prompt { + output_text = prompt + &output_text; + } + + tracing::debug!(parent: &span, "Output: {}", output_text); + tracing::info!(parent: &span, "Success"); + + let total_tokens = generated_text.generated_tokens + prefill_tokens_length; + if info.request_logger_url.is_some() { + let _ = request_logger_sender.send((total_tokens as i64, api_token.unwrap_or("".to_string()), info.model_id.clone())).await; + } + + let stream_token = StreamResponse { + token, + generated_text: Some(output_text), + details + }; + + yield Ok(callback(stream_token)); + break; + } } } - } - // yield error - Err(err) => { - error = true; - yield Ok(Event::from(err)); - break; + // yield error + Err(err) => { + error = true; + yield Ok(Event::from(err)); + break; + } } } + }, + // yield error + Err(err) => { + error = true; + yield Ok(Event::from(err)); } - }, - // yield error - Err(err) => { - error = true; + } + // Check if generation reached the end + // Skip if we already sent an error + if !end_reached && !error { + let err = InferError::IncompleteGeneration; + metrics::increment_counter!("lorax_request_failure", "err" => "incomplete"); + tracing::error!("{err}"); yield Ok(Event::from(err)); } } - // Check if generation reached the end - // Skip if we already sent an error - if !end_reached && !error { - let err = InferError::IncompleteGeneration; - metrics::increment_counter!("lorax_request_failure", "err" => "incomplete"); - tracing::error!("{err}"); - yield Ok(Event::from(err)); - } } }; @@ -895,6 +945,7 @@ pub async fn run( ngrok_authtoken: Option, ngrok_edge: Option, adapter_source: String, + embedding_model: bool, ) -> Result<(), axum::BoxError> { // OpenAPI documentation #[derive(OpenApi)] @@ -977,7 +1028,7 @@ pub async fn run( let generation_health = Arc::new(AtomicBool::new(false)); let health_ext = Health::new(client.clone(), generation_health.clone()); let infer = Infer::new( - client, + client.clone(), validation, waiting_served_ratio, max_batch_prefill_tokens, @@ -1084,6 +1135,7 @@ pub async fn run( sha: option_env!("VERGEN_GIT_SHA"), docker_label: option_env!("DOCKER_LABEL"), request_logger_url: std::env::var("REQUEST_LOGGER_URL").ok(), + embedding_model, }; DEFAULT_ADAPTER_SOURCE @@ -1108,6 +1160,7 @@ pub async fn run( .route("/", post(compat_generate)) .route("/info", get(get_model_info)) .route("/generate", post(generate)) + .route("/embed", post(embed)) .route("/generate_stream", post(generate_stream)) .route("/v1/completions", post(completions_v1)) .route("/v1/chat/completions", post(chat_completions_v1)) @@ -1123,6 +1176,7 @@ pub async fn run( .route("/metrics", get(metrics)) .route("/tokenize", post(tokenize)) .layer(Extension(info)) + .layer(Extension(client.clone())) .layer(Extension(request_logger_sender.clone())) .layer(Extension(health_ext.clone())) .layer(Extension(compat_return_full_text)) @@ -1266,6 +1320,46 @@ impl From for Event { } } +/// Embed inputs +#[utoipa::path( + post, + tag = "Embedding", + path = "/embed", + request_body = TokenizeRequest, + responses( + (status = 200, description = "Embeddings ids", body = EmbedResponse), + (status = 500, description = "Incomplete 
embedding", body = ErrorResponse), + ) +)] +#[instrument(skip_all)] +async fn embed( + mut client: Extension, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let input = req.inputs.clone(); + let embeddings = client.embed(input).await.unwrap(); + let embeddings = embeddings.get(0); + + // TODO: better error enums + if (!embeddings.unwrap().error_msg.is_empty()) { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorResponse { + error: embeddings.unwrap().error_msg.clone(), + error_type: "model doesn't support embeddings".to_string(), + }), + )); + } + + let values = embeddings + .map(|emb| emb.embeddings.as_ref().map(|emb| emb.values.clone())) + .flatten() + .unwrap_or_default(); + Ok(Json(EmbedResponse { + embeddings: values.clone(), + })) +} + /// Tokenize inputs #[utoipa::path( post, diff --git a/router/src/validation.rs b/router/src/validation.rs index 33514b84e..6e2c5c70d 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -456,6 +456,8 @@ pub enum ValidationError { AdapterIdMissing, #[error("number of adapter IDs must match number of adapter weights")] AdapterWeightMismatch, + #[error("Embedding models don't support text generation")] + EmbeddingModel, } #[cfg(test)] diff --git a/server/lorax_server/models/__init__.py b/server/lorax_server/models/__init__.py index 8e143dec9..9ecd1c856 100644 --- a/server/lorax_server/models/__init__.py +++ b/server/lorax_server/models/__init__.py @@ -92,6 +92,10 @@ def get_model( dtypetrust_remote_code=trust_remote_code, ) + if model_type == "bert": + from lorax_server.models.flash_bert import FlashBert + return FlashBert(model_id, revision=revision, dtype=dtype) + if model_id.startswith("bigcode/") or model_type == "gpt_bigcode": from lorax_server.models.flash_santacoder import FlashSantacoderSharded diff --git a/server/lorax_server/models/flash_bert.py b/server/lorax_server/models/flash_bert.py new file mode 100644 index 000000000..86904117d --- /dev/null +++ b/server/lorax_server/models/flash_bert.py @@ -0,0 +1,251 @@ +from typing import Optional, Type + +import torch +from opentelemetry import trace +from torch import nn +from transformers import AutoTokenizer +from transformers.activations import ACT2FN +from transformers.models.bert import BertConfig + +from lorax_server.models import Model +from lorax_server.models.types import FlashEmbeddingBatch +from lorax_server.pb.generate_pb2 import Embedding +from lorax_server.utils import ( + Weights, + initialize_torch_distributed, + weight_files, +) +from lorax_server.utils.flash_attn import attention +from lorax_server.utils.layers import FastLayerNorm + +tracer = trace.get_tracer(__name__) + + +# NOTE: This implementation of flashbert was based on the +# huggingface/text-embeddings-inference implementation of flashbert here: +# https://github.com/huggingface/text-embeddings-inference/blob/cb802a25d43fe6078c715b49652a3bc8a7d5aac8/backends/python/server/text_embeddings_server/models/flash_bert.py + + +class BertEmbeddings: + def __init__(self, prefix, weights, device, dtype, config: BertConfig): + self.word_embeddings_weight = weights.get_tensor(f"{prefix}.word_embeddings.weight").to(dtype).to(device) + self.token_type_embeddings_weight = ( + weights.get_tensor(f"{prefix}.token_type_embeddings.weight").to(dtype).to(device) + ) + + if config.position_embedding_type == "absolute": + self.position_embeddings_weight = ( + weights.get_tensor(f"{prefix}.position_embeddings.weight").to(dtype).to(device) + ) + else: + raise NotImplementedError("FlashBert only supports absolute 
position embeddings") + + self.layer_norm = FastLayerNorm.load(prefix=f"{prefix}.LayerNorm", weights=weights, eps=config.layer_norm_eps) + + def forward(self, input_ids, token_type_ids, position_ids): + inputs_embeds = nn.functional.embedding(input_ids, self.word_embeddings_weight) + token_type_embeds = nn.functional.embedding(token_type_ids, self.token_type_embeddings_weight) + position_embeds = nn.functional.embedding(position_ids, self.position_embeddings_weight) + + inputs_embeds += position_embeds + + embeddings, _ = self.layer_norm.forward(inputs_embeds, token_type_embeds) + return embeddings + + +class BertAttention: + def __init__(self, prefix, weights, device, dtype, config: BertConfig): + query_weight = weights.get_tensor(f"{prefix}.self.query.weight") + query_bias = weights.get_tensor(f"{prefix}.self.query.bias") + key_weight = weights.get_tensor(f"{prefix}.self.key.weight") + key_bias = weights.get_tensor(f"{prefix}.self.key.bias") + value_weight = weights.get_tensor(f"{prefix}.self.value.weight") + value_bias = weights.get_tensor(f"{prefix}.self.value.bias") + + self.qkv_weight = torch.cat([query_weight, key_weight, value_weight]).T.to(dtype).to(device) + self.qkv_bias = torch.cat([query_bias, key_bias, value_bias]).to(dtype).to(device) + + self.dense_weight = weights.get_tensor(f"{prefix}.output.dense.weight").T.to(dtype).to(device) + self.dense_bias = weights.get_tensor(f"{prefix}.output.dense.bias").to(dtype).to(device) + + self.layer_norm = FastLayerNorm.load( + prefix=f"{prefix}.output.LayerNorm", weights=weights, eps=config.layer_norm_eps + ) + + self.head_size = config.hidden_size // config.num_attention_heads + self.softmax_scale = self.head_size**-0.5 + self.num_heads = config.num_attention_heads + + def forward(self, hidden_states, cu_seqlens, max_s): + residual = hidden_states + + qkv = torch.addmm(self.qkv_bias, hidden_states, self.qkv_weight) + q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split(self.num_heads, dim=1) + + attn_output = torch.empty_like(q) + attention(q, k, v, attn_output, cu_seqlens, max_s, self.softmax_scale) + + hidden_states = torch.addmm( + self.dense_bias, + attn_output.view(-1, self.num_heads * self.head_size), + self.dense_weight, + ) + hidden_states, _ = self.layer_norm.forward(hidden_states, residual) + + return hidden_states + + +class BertLayer: + def __init__(self, prefix, weights, device, dtype, config: BertConfig): + self.attention = BertAttention(f"{prefix}.attention", weights, device, dtype, config) + + self.intermediate_weight = weights.get_tensor(f"{prefix}.intermediate.dense.weight").T.to(dtype).to(device) + self.intermediate_bias = weights.get_tensor(f"{prefix}.intermediate.dense.bias").to(dtype).to(device) + + act = config.hidden_act + self.intermediate_act_fn = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate="tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none", + ) + ) + + self.output_weight = weights.get_tensor(f"{prefix}.output.dense.weight").T.to(dtype).to(device) + self.output_bias = weights.get_tensor(f"{prefix}.output.dense.bias").to(dtype).to(device) + self.layer_norm = FastLayerNorm.load( + prefix=f"{prefix}.output.LayerNorm", weights=weights, eps=config.layer_norm_eps + ) + + def forward(self, hidden_states, cu_seqlens, max_s): + hidden_states = self.attention.forward(hidden_states, cu_seqlens, max_s) + residual = hidden_states + + hidden_states = torch.addmm(self.intermediate_bias, hidden_states, self.intermediate_weight) + hidden_states = 
self.intermediate_act_fn(hidden_states) + hidden_states = torch.addmm( + self.output_bias, + hidden_states, + self.output_weight, + ) + hidden_states, _ = self.layer_norm.forward(hidden_states, residual) + return hidden_states + + +class BertEncoder: + def __init__(self, prefix, weights, device, dtype, config: BertConfig): + self.layers = [ + BertLayer(f"{prefix}.layer.{i}", weights, device, dtype, config) for i in range(config.num_hidden_layers) + ] + + def forward(self, hidden_states, cu_seqlens, max_s): + for layer in self.layers: + hidden_states = layer.forward(hidden_states, cu_seqlens, max_s) + return hidden_states + + +class FlashBertModel(torch.nn.Module): + def __init__(self, weights, device, dtype, config: BertConfig): + super().__init__() + self.embeddings = BertEmbeddings("embeddings", weights, device, dtype, config) + self.encoder = BertEncoder("encoder", weights, device, dtype, config) + + def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s): + embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids) + encoder_outputs = self.encoder.forward(embeddings, cu_seqlens, max_s) + + return encoder_outputs[cu_seqlens[:-1]] + + +class FlashBert(Model): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + raise NotImplementedError("FlashSantacoderSharded is only available on GPU") + + self.device = device + self.dtype = dtype + + tokenizer = AutoTokenizer.from_pretrained(model_id) + self.tokenizer = tokenizer + + config = BertConfig.from_pretrained(model_id) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, + device, + dtype, + process_group=self.process_group, + ) + model = FlashBertModel(weights, device, dtype, config) + + self.hidden_size = config.hidden_size + + super(FlashBert, self).__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + requires_padding=False, + ) + + @property + def batch_type(self) -> Type[FlashEmbeddingBatch]: + return FlashEmbeddingBatch + + @property + def supports_embeddings(self) -> bool: + return True + + @property + def supports_text_generation(self) -> bool: + return False + + def warmup(self, batch: FlashEmbeddingBatch, max_new_tokens: int) -> int | None: + return 42 # no-op for now + + def generate_token(self, batch: FlashEmbeddingBatch) -> None: + if not self.supports_text_generation: + raise NotImplementedError("This model does not support text generation") + return None + + def forward(self, batch: FlashEmbeddingBatch): + return self.embed(batch) + + @tracer.start_as_current_span("embed") + def embed(self, batch: FlashEmbeddingBatch) -> Embedding: + embedding = self.model.forward( + input_ids=batch.input_ids, + token_type_ids=batch.token_type_ids, + position_ids=batch.position_ids, + cu_seqlens=batch.cu_seqlens, + max_s=batch.max_s, + ) + cpu_results = embedding.view(-1).tolist() + + return Embedding(values=cpu_results[: self.hidden_size]) + + def tokenize_to_batch(self, inputs) -> FlashEmbeddingBatch: + tokens = self.tokenizer(inputs, return_token_type_ids=True) + num_tokens = len(tokens["input_ids"]) + position_ids = range(num_tokens) + return FlashEmbeddingBatch( + 
            input_ids=torch.tensor(tokens["input_ids"], dtype=torch.int32, device=self.device),
+            token_type_ids=torch.tensor(tokens["token_type_ids"], dtype=torch.int32, device=self.device),
+            position_ids=torch.tensor(position_ids, dtype=torch.int32, device=self.device),
+            cu_seqlens=torch.tensor([0, num_tokens], dtype=torch.int32, device=self.device),
+            max_s=num_tokens,
+            size=1,
+        )
diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py
index 3b4bda89a..0f956d036 100644
--- a/server/lorax_server/models/model.py
+++ b/server/lorax_server/models/model.py
@@ -95,6 +95,14 @@ def info(self) -> InfoResponse:
     @property
     def sliding_window_blocks(self) -> Optional[int]:
         return None
+
+    @property
+    def supports_embeddings(self) -> bool:
+        return False
+
+    @property
+    def supports_text_generation(self) -> bool:
+        return True

     @property
     @abstractmethod
diff --git a/server/lorax_server/models/types.py b/server/lorax_server/models/types.py
index 497c65f79..cf6addb5e 100644
--- a/server/lorax_server/models/types.py
+++ b/server/lorax_server/models/types.py
@@ -124,3 +124,19 @@ def to_pb(self) -> generate_pb2.Generation:
             next_tokens=self.next_tokens.to_pb(),
             generated_text=self.generated_text.to_pb() if self.generated_text is not None else None,
         )
+
+@dataclass
+class FlashEmbeddingBatch(ABC):
+    input_ids: torch.Tensor
+    token_type_ids: torch.Tensor
+    position_ids: torch.Tensor
+
+    cu_seqlens: torch.Tensor
+    max_s: int
+    size: int
+
+    def __len__(self):
+        return self.size
+
+    def from_pb(self, *args, **kwargs):
+        return None
\ No newline at end of file
diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py
index 869b7e02f..e23ec51cb 100644
--- a/server/lorax_server/server.py
+++ b/server/lorax_server/server.py
@@ -95,6 +95,15 @@ async def Prefill(self, request: generate_pb2.PrefillRequest, context):
             batch=next_batch.to_pb() if next_batch else None,
         )

+    async def Embed(self, request: generate_pb2.EmbedRequest, context):
+        if not self.model.supports_embeddings:
+            logger.error("Model does not support embeddings")
+            return generate_pb2.EmbedResponse(embeddings=generate_pb2.Embedding(), errorMsg="Model does not support embeddings")
+        batch = request.inputs
+        tokenised_batch = self.model.tokenize_to_batch(batch)
+        embeddings = self.model.embed(tokenised_batch)
+        return generate_pb2.EmbedResponse(embeddings=embeddings)
+
     async def Decode(self, request: generate_pb2.DecodeRequest, context):
         if len(request.batches) == 0:
             raise ValueError("Must provide at least one batch")
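
For reference, a minimal usage sketch of the `Client.embed` / `AsyncClient.embed` methods added in `clients/python/lorax/client.py`. It assumes the client package is installed as `lorax` and that a LoRAX deployment serving an embedding model is reachable at `http://127.0.0.1:8080`; the URL is illustrative only.

```python
import asyncio

from lorax import AsyncClient, Client

# Assumed local endpoint; replace with your LoRAX deployment URL.
endpoint = "http://127.0.0.1:8080"

# Synchronous client: sends the input text to {endpoint}/embed and returns
# an EmbedResponse whose `embeddings` field is a list of floats (or None).
client = Client(endpoint)
response = client.embed("Hello, world!")
print(len(response.embeddings or []))


# Asynchronous variant of the same call.
async def main():
    async_client = AsyncClient(endpoint)
    response = await async_client.embed("Hello, world!")
    print(len(response.embeddings or []))


asyncio.run(main())
```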
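At the HTTP level, the router change registers `/embed` with the new `EmbedRequest` / `EmbedResponse` types, and `lorax-launcher --embedding-model` forwards `--embedding-model` to the router, which then rejects the generation routes. A rough sketch of the raw request/response shapes using `requests` (the endpoint URL is an assumption, and the generation call is only expected to fail like this when the router runs in embedding mode):

```python
import requests

base_url = "http://127.0.0.1:8080"  # assumed deployment URL

# /embed expects {"inputs": "<text>"} and returns {"embeddings": [<float>, ...]}.
resp = requests.post(f"{base_url}/embed", json={"inputs": "Hello, world!"}, timeout=60)
resp.raise_for_status()
print(len(resp.json()["embeddings"]))

# With --embedding-model set, generation routes answer with a 400 and an
# ErrorResponse body instead of streaming tokens.
resp = requests.post(f"{base_url}/generate", json={"inputs": "Hello"}, timeout=60)
print(resp.status_code, resp.json().get("error"))
# expected: 400 "Embedding model doesn't support generation"
```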
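On the server side, the `Embed` RPC above simply chains `tokenize_to_batch` and `embed` on the loaded model. A standalone sketch of that path using the new `FlashBert` class, assuming a CUDA device, flash attention installed, and a locally cached BERT-architecture checkpoint (the model id below is illustrative, not part of this change):

```python
import torch

from lorax_server.models.flash_bert import FlashBert

# Illustrative checkpoint; any BERT-architecture embedding model whose weight
# names match what FlashBert loads should work here.
model = FlashBert("BAAI/bge-base-en-v1.5", dtype=torch.float16)

# Mirrors the Embed RPC: tokenize a single input into a FlashEmbeddingBatch,
# run the flash BERT forward pass, and keep the first hidden_size values.
batch = model.tokenize_to_batch("Hello, world!")
embedding = model.embed(batch)
print(len(embedding.values))  # == model.hidden_size
```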