diff --git a/.github/workflows/integration-tests/action.yaml b/.github/workflows/integration-tests/action.yaml
index c7492d46d..854efe238 100644
--- a/.github/workflows/integration-tests/action.yaml
+++ b/.github/workflows/integration-tests/action.yaml
@@ -66,6 +66,12 @@ runs:
         cd integration-tests
         pytest test_embeddings.py -vv --capture=tee-sys --log-cli-level=INFO

+    - name: Run Classification tests
+      shell: bash
+      run: |
+        cd integration-tests
+        pytest test_classifications.py -vv --capture=tee-sys --log-cli-level=INFO
+
     - name: Run LLM tests
       shell: bash
       run: |
diff --git a/Cargo.lock b/Cargo.lock
index 17edb92dd..0a24ded7b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -277,6 +277,7 @@ checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae"
 dependencies = [
  "async-trait",
  "axum-core 0.4.5",
+ "axum-macros",
  "bytes",
  "futures-util",
  "http 1.1.0",
@@ -341,6 +342,17 @@ dependencies = [
  "tracing",
 ]

+[[package]]
+name = "axum-macros"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "57d123550fa8d071b7255cb0cc04dc302baa6c8c4a79f55701552684d8399bce"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.87",
+]
+
 [[package]]
 name = "axum-tracing-opentelemetry"
 version = "0.16.0"
diff --git a/integration-tests/test_classifications.py b/integration-tests/test_classifications.py
new file mode 100644
index 000000000..85b23285d
--- /dev/null
+++ b/integration-tests/test_classifications.py
@@ -0,0 +1,23 @@
+import requests
+from utils.docker_runner import run_lorax_container
+
+
+def test_distilbert_ner():
+    config = {
+        "name": "distilbert-ner",
+        "model_id": "dslim/distilbert-NER",
+        "docker_args": {
+            "max_input_length": 512,
+            "max_batch_prefill_tokens": 512,
+            "max_batch_total_tokens": 512,
+            "max_total_tokens": 512,
+        },
+    }
+    with run_lorax_container(config):
+        response = requests.post(
+            "http://localhost:8080/classify",
+            json={"inputs": "Johnny supports the Golden State Warriors. He lives in London."},
+        )
+        response.raise_for_status()
+        print("RESPONSE FROM CLASSIFICATION: ", response.json())
+        assert len(response.json()["predictions"]) > 0
diff --git a/integration-tests/test_embeddings.py b/integration-tests/test_embeddings.py
index 329a8810b..71723da52 100644
--- a/integration-tests/test_embeddings.py
+++ b/integration-tests/test_embeddings.py
@@ -13,3 +13,21 @@ def test_stella_1_5b():
         response.raise_for_status()
         print("RESPONSE FROM EMBEDDING: ", response.json())
         assert len(response.json()["embeddings"]) > 0
+
+
+def test_uae_large_v1_1_5b():
+    config = {
+        "name": "UAE-Large-V1-1.5b",
+        "model_id": "WhereIsAI/UAE-Large-V1",
+        "docker_args": {
+            "max_input_length": 512,
+            "max_batch_prefill_tokens": 512,
+            "max_batch_total_tokens": 512,
+            "max_total_tokens": 512,
+        },
+    }
+    with run_lorax_container(config):
+        response = requests.post("http://localhost:8080/embed", json={"inputs": "Hello, world!"})
+        response.raise_for_status()
+        print("RESPONSE FROM EMBEDDING: ", response.json())
+        assert len(response.json()["embeddings"]) > 0
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 53929a3cb..33116d167 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -1675,13 +1675,6 @@ fn main() -> Result<(), LauncherError> {
         }
     };

-    // Validate args
-    if max_input_tokens >= max_total_tokens {
-        return Err(LauncherError::ArgumentValidation(
-            "`max_input_length` must be < `max_total_tokens`".to_string(),
-        ));
-    }
-
     if args.validation_workers == 0 {
         return Err(LauncherError::ArgumentValidation(
             "`validation_workers` must be > 0".to_string(),
diff --git a/router/Cargo.toml b/router/Cargo.toml
index af817326e..679da0563 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -16,7 +16,7 @@ path = "src/main.rs"

 [dependencies]
 async-stream = "0.3.3"
-axum = { version = "0.7", features = ["json"] }
+axum = { version = "0.7", features = ["json", "macros"] }
 axum-tracing-opentelemetry = "0.16"
 clap = { version = "4.1.4", features = ["derive", "env"] }
 futures = "0.3.26"
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 101360ba1..d89090803 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -4,9 +4,10 @@ use crate::queue::AdapterEvent;
 use crate::scheduler::AdapterScheduler;
 use crate::validation::{Validation, ValidationError};
 use crate::{
-    AdapterParameters, AlternativeToken, BatchClassifyRequest, ChatTemplateVersions,
-    ClassifyRequest, EmbedRequest, EmbedResponse, Entity, Entry, HubTokenizerConfig, Message,
-    MessageChunk, MessageContent, TextMessage, Token, TokenizerConfigToken, Tool,
+    AdapterParameters, AlternativeToken, BatchClassifyRequest, BatchEmbedRequest,
+    ChatTemplateVersions, ClassifyRequest, EmbedRequest, EmbedResponse, Entity, Entry,
+    HubTokenizerConfig, Message, MessageChunk, MessageContent, TextMessage, Token,
+    TokenizerConfigToken, Tool,
 };
 use crate::{GenerateRequest, PrefillToken};
 use futures::future::try_join_all;
@@ -596,6 +597,7 @@ impl Infer {
                     embedding,
                     start: _,
                     queued: _,
+                    id: _,
                 } => {
                     return_embeddings = Some(embedding.values);
                 }
@@ -857,6 +859,137 @@ impl Infer {
         }
     }

+    #[instrument(skip(self))]
+    pub(crate) async fn embed_batch(
+        &self,
+        request: BatchEmbedRequest,
+    ) -> Result<Vec<EmbedResponse>, InferError> {
+        // Limit concurrent requests by acquiring a permit from the semaphore
+        let _permit = self
+            .clone()
+            .limit_concurrent_requests
+            .try_acquire_owned()
+            .map_err(|err| {
+                metrics::increment_counter!("lorax_request_failure", "err" => "overloaded");
+                tracing::error!("{err}");
+                err
+            })?;
+
+        let (adapter_source, adapter_parameters) = extract_adapter_params(
+            request.parameters.adapter_id.clone(),
+            request.parameters.adapter_source.clone(),
+            request.parameters.adapter_parameters.clone(),
+        );
+
+        let adapter_idx;
+        {
+            // TODO(travis): can optimize concurrency here using RWLock
+            let mut adapter_to_index = self.adapter_to_index.lock().await;
+            let adapter_key = adapter_parameters.clone();
+            if adapter_to_index.contains_key(&adapter_key) {
+                adapter_idx = *adapter_to_index.get(&adapter_key).unwrap();
+            } else {
+                adapter_idx = adapter_to_index.len() as u32;
+                adapter_to_index.insert(adapter_key, adapter_idx);
+            }
+        }
+
+        let api_token = request.parameters.api_token.clone();
+        let adapter = Adapter::new(
+            adapter_parameters,
+            adapter_source.unwrap(),
+            adapter_idx,
+            api_token,
+        );
+
+        // MPSC channel to communicate with the background batching task
+        let (response_tx, response_rx) = mpsc::unbounded_channel();
+
+        // Call validate_input on every input in the request and await the results
+        let futures: Vec<_> = request
+            .inputs
+            .iter()
+            .map(|input| {
+                self.validation
+                    .validate_input(input.clone(), true, None, Some(1))
+            })
+            .collect();
+
+        let all_tokenized_inputs = try_join_all(futures).await?;
+
+        for ((id, r_inputs), (tokenized_inputs, input_length)) in
+            request.inputs.iter().enumerate().zip(all_tokenized_inputs)
+        {
+            let inputs = r_inputs.to_string().clone();
+            let valid_request = ValidEmbedRequest {
+                inputs,
+                tokenized_inputs,
+                input_length: input_length as u32,
+                adapter: adapter.clone(),
+            };
+
+            // Process the request by sending it to the queue associated with `adapter`
+            self.adapter_scheduler.process(
+                adapter.clone(),
+                Entry {
+                    request: Arc::new(valid_request),
+                    response_tx: response_tx.clone(),
+                    span: Span::current(),
+                    temp_span: None,
+                    queue_time: Instant::now(),
+                    batch_time: None,
+                    block_allocation: None,
+                    id: Some(id as u64),
+                },
+            );
+        }
+
+        drop(response_tx); // Close the sending end
+
+        // Return values
+
+        let mut all_embeddings = HashMap::new();
+        let mut stream = UnboundedReceiverStream::new(response_rx);
+        while let Some(response) = stream.next().await {
+            match response? {
+                InferStreamResponse::Embed {
+                    embedding,
+                    start: _,
+                    queued: _,
+                    id,
+                } => {
+                    all_embeddings.insert(
+                        id.unwrap(),
+                        EmbedResponse {
+                            embeddings: embedding.values,
+                        },
+                    );
+                }
+                _ => {
+                    tracing::error!(
+                        "Received unexpected message type in classify_batch. This is a bug."
+                    );
+                }
+            }
+        }
+        if all_embeddings.is_empty() {
+            let err = InferError::EmbeddingFailure;
+            metrics::increment_counter!("lorax_request_failure", "err" => "embedding_failure");
+            tracing::error!("{err}");
+            Err(err)
+        } else {
+            let mut sorted_responses: Vec<_> = all_embeddings.into_iter().collect();
+            sorted_responses.sort_by_key(|&(id, _)| id);
+
+            let sorted_responses: Vec<EmbedResponse> = sorted_responses
+                .into_iter()
+                .map(|(_, response)| response)
+                .collect();
+
+            Ok(sorted_responses)
+        }
+    }
+
     /// Add best_of new requests to the queue and return a InferResponse of the sequence with
     /// the highest log probability per token
     #[instrument(skip(self))]
@@ -1478,6 +1611,7 @@ fn send_embeddings(
             embedding: embedding.clone(),
             queued: entry.queue_time,
             start: entry.batch_time.unwrap(),
+            id: entry.id,
         }))?;

     // TODO(travis): redundant as we always return true, just make it return nothing
@@ -1534,22 +1668,18 @@ pub(crate) enum InferStreamResponse {
     // Intermediate messages
     Token(Token),
     // Embeddings
+    // TODO: add tracing for embedding
     Embed {
         embedding: Embedding,
-        // For now allow this field even though it is unused.
-        // TODO:(magdy) enable tracing for these requests
         #[allow(dead_code)]
         start: Instant,
         #[allow(dead_code)]
         queued: Instant,
+        id: Option<u64>, // to support batching
     },
     Classify {
         predictions: ClassifyPredictionList,
-        // For now allow this field even though it is unused.
-        // TODO:(magdy) enable tracing for these requests
-        #[allow(dead_code)]
         start: Instant,
-        #[allow(dead_code)]
         queued: Instant,
         id: Option<u64>, // to support batching
     },
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 05cb07e96..fd5b1049f 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -1165,6 +1165,49 @@ struct EmbedResponse {
     embeddings: Vec<f32>,
 }

+#[derive(Debug, Clone, Deserialize)]
+#[serde(untagged)]
+enum StringOrVec {
+    String(String),
+    Vec(Vec<String>),
+}
+
+impl std::fmt::Display for StringOrVec {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        match self {
+            StringOrVec::String(s) => write!(f, "{}", s),
+            StringOrVec::Vec(v) => write!(f, "{}", v.join(", ")),
+        }
+    }
+}
+
+#[derive(Clone, Debug, Deserialize, ToSchema)]
+struct CompatEmbedRequest {
+    input: StringOrVec,
+    #[allow(dead_code)]
+    model: String,
+    #[allow(dead_code)]
+    encoding_format: Option<String>,
+    #[allow(dead_code)]
+    dimensions: Option<i32>,
+    #[allow(dead_code)]
+    user: Option<String>,
+    #[serde(default = "default_embed_parameters")]
+    parameters: EmbedParameters,
+}
+
+#[derive(Serialize, ToSchema)]
+struct CompatEmbedResponse {
+    embeddings: Vec<CompatEmbedding>,
+}
+
+#[derive(Serialize, ToSchema)]
+struct CompatEmbedding {
+    index: i32,
+    embedding: Vec<f32>,
+    object: String,
+}
+
 #[derive(Clone, Debug, Deserialize, ToSchema)]
 struct ClassifyRequest {
     inputs: String,
@@ -1175,6 +1218,13 @@ struct BatchClassifyRequest {
     inputs: Vec<String>,
 }

+#[derive(Clone, Debug, Deserialize, ToSchema)]
+struct BatchEmbedRequest {
+    inputs: Vec<String>,
+    #[serde(default = "default_embed_parameters")]
+    parameters: EmbedParameters,
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 struct Entity {
     entity_group: String,
diff --git a/router/src/main.rs b/router/src/main.rs
index 249031258..42d0c65a1 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -135,13 +135,6 @@ async fn main() -> Result<(), RouterError> {

     init_logging(otlp_endpoint, json_output);

-    // Validate args
-    if max_input_length >= max_total_tokens {
-        return Err(RouterError::ArgumentValidation(
-            "`max_input_length` must be < `max_total_tokens`".to_string(),
-        ));
-    }
-
     if validation_workers == 0 {
         return Err(RouterError::ArgumentValidation(
             "`validation_workers` must be > 0".to_string(),
diff --git a/router/src/server.rs b/router/src/server.rs
index d33412aaf..b874b6b9a 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -7,16 +7,17 @@ use crate::tool_grammar::ToolGrammar;
 use crate::validation::ValidationError;
 use crate::{json, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig};
 use crate::{
-    AdapterParameters, AlternativeToken, BatchClassifyRequest, BestOfSequence,
+    AdapterParameters, AlternativeToken, BatchClassifyRequest, BatchEmbedRequest, BestOfSequence,
     ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
     ChatCompletionStreamResponse, ChatCompletionStreamResponseChoice, ChatMessage, ClassifyRequest,
-    CompatGenerateRequest, CompletionFinishReason, CompletionRequest, CompletionResponse,
-    CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, Details,
-    EmbedParameters, EmbedRequest, EmbedResponse, Entity, ErrorResponse, FinishReason,
-    GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, JsonSchema,
-    LogProbs, Message, OpenAiResponseFormat, PrefillToken, ResponseFormat, ResponseFormatType,
-    ReturnFunctionDefinition, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeRequest,
-    TokenizeResponse, Tool, ToolCall, ToolChoice, UsageInfo, Validation,
+    CompatEmbedRequest, CompatEmbedResponse, CompatEmbedding, CompatGenerateRequest,
+    CompletionFinishReason, CompletionRequest, CompletionResponse, CompletionResponseChoice,
+    CompletionResponseStreamChoice, CompletionStreamResponse, Details, EmbedParameters,
+    EmbedRequest, EmbedResponse, Entity, ErrorResponse, FinishReason, GenerateParameters,
+    GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, JsonSchema, LogProbs, Message,
+    OpenAiResponseFormat, PrefillToken, ResponseFormat, ResponseFormatType,
+    ReturnFunctionDefinition, SimpleToken, StreamDetails, StreamResponse, StringOrVec, Token,
+    TokenizeRequest, TokenizeResponse, Tool, ToolCall, ToolChoice, UsageInfo, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -1483,6 +1484,7 @@ pub async fn run(
         .route("/classify_batch", post(classify_batch))
         .route("/generate_stream", post(generate_stream))
         .route("/v1/completions", post(completions_v1))
+        .route("/v1/embeddings", post(compat_embed))
         .route("/v1/chat/completions", post(chat_completions_v1))
         // AWS Sagemaker route
         .route("/invocations", post(compat_generate));
@@ -1625,7 +1627,7 @@ impl From for Event {
     post,
     tag = "Embedding",
     path = "/embed",
-    request_body = TokenizeRequest,
+    request_body = EmbedRequest,
     responses(
         (status = 200, description = "Embeddings ids", body = EmbedResponse),
         (status = 500, description = "Incomplete embedding", body = ErrorResponse),
@@ -1643,6 +1645,69 @@ async fn embed(
     Ok(Json(response))
 }

+/// Embed inputs
+#[utoipa::path(
+    post,
+    tag = "OpenAI Compatible",
+    path = "/v1/embeddings",
+    request_body = CompatEmbedRequest,
+    responses(
+        (status = 200, description = "Embeddings ids", body = CompatEmbedResponse),
+        (status = 500, description = "Incomplete embedding", body = ErrorResponse),
+    )
+)]
+#[instrument(skip_all)]
+#[axum::debug_handler]
+async fn compat_embed(
+    infer: Extension<Infer>,
+    Json(req): Json<CompatEmbedRequest>,
+) -> Result<Json<CompatEmbedResponse>, (StatusCode, Json<ErrorResponse>)> {
+    metrics::increment_counter!("lorax_request_count");
+    tracing::debug!("Input: {}", req.input);
+    if let StringOrVec::Vec(inputs) = req.input {
+        let batch_embed_req = BatchEmbedRequest {
+            inputs,
+            parameters: req.parameters,
+        };
+        let response = infer.embed_batch(batch_embed_req).await?;
+        let compat_embeddings = response
+            .into_iter()
+            .enumerate()
+            .map(|(i, e)| -> CompatEmbedding {
+                CompatEmbedding {
+                    index: i as i32,
+                    embedding: e.embeddings,
+                    object: "embedding".to_string(),
+                }
+            })
+            .collect();
+        Ok(Json(CompatEmbedResponse {
+            embeddings: compat_embeddings,
+        }))
+    } else if let StringOrVec::String(input) = req.input {
+        let embed_req = EmbedRequest {
+            inputs: input.to_string(),
+            parameters: req.parameters,
+        };
+        let response = infer.embed(embed_req).await?;
+        Ok(Json(CompatEmbedResponse {
+            embeddings: vec![CompatEmbedding {
+                index: 0,
+                embedding: response.embeddings,
+                object: "embedding".to_string(),
+            }],
+        }))
+    } else {
+        Err((
+            StatusCode::BAD_REQUEST,
+            Json(ErrorResponse {
+                error: "Invalid input".to_string(),
+                error_type: "invalid_input".to_string(),
+            }),
+        ))
+    }
+}
+
 #[utoipa::path(
     post,
     tag = "Classify",
@@ -1716,7 +1781,7 @@ async fn classify(
     post,
     tag = "ClassifyBatch",
     path = "/classify_batch",
-    request_body = TokenizeRequest,
+    request_body = BatchClassifyRequest,
     responses(
         (status = 200, description = "Classifications", body = BatchClassifyResponse),
         (status = 500, description = "Incomplete classification", body = ErrorResponse),