diff --git a/Cargo.lock b/Cargo.lock index c3b488bdb..eb05e5b62 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -147,9 +147,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.71" +version = "0.1.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a564d521dd56509c4c47480d00b80ee55f7e385ae48db5744c67ad50c92d2ebf" +checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", @@ -1271,6 +1271,7 @@ name = "lorax-router" version = "0.1.0" dependencies = [ "async-stream", + "async-trait", "axum", "axum-tracing-opentelemetry", "clap", diff --git a/proto/generate.proto b/proto/generate.proto index b47f2cd29..ac48cf15b 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -232,16 +232,24 @@ message DecodeResponse { optional CachedBatch batch = 2; } -message EmbedRequest { - string inputs = 1; +message Embedding { + /// Request ID + uint64 request_id = 1; + + /// Embedding values + repeated float values = 2; } -message Embedding { - repeated float values = 1; +message EmbedRequest { + /// Batch + Batch batch = 1; } message EmbedResponse { - Embedding embeddings = 1; + /// Embeddings + repeated Embedding embeddings = 1; + + /// Error message on failure string errorMsg = 2; } diff --git a/router/Cargo.toml b/router/Cargo.toml index 4ae55744f..5a925f312 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -46,10 +46,11 @@ utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] } ngrok = { version = "0.12.3", features = ["axum"], optional = true } once_cell = "1.19.0" itertools = "0.12.1" +async-trait = "0.1.80" [build-dependencies] vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] } [features] default = ["ngrok"] -ngrok = ["dep:ngrok"] \ No newline at end of file +ngrok = ["dep:ngrok"] diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 6b7157885..7a5f2f1e6 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -64,14 +64,6 @@ impl Client { Ok(response) } - /// Embed - #[instrument(skip(self))] - pub async fn embed(&mut self, inputs: String) -> Result { - let request = tonic::Request::new(EmbedRequest { inputs }).inject_context(); - let response = self.stub.embed(request).await?.into_inner(); - Ok(response) - } - /// Get model health #[instrument(skip(self))] pub async fn health(&mut self) -> Result { @@ -196,6 +188,14 @@ impl Client { Ok((response.generations, response.batch)) } + /// Embed + #[instrument(skip(self))] + pub async fn embed(&mut self, batch: Batch) -> Result> { + let request = tonic::Request::new(EmbedRequest { batch: Some(batch) }).inject_context(); + let response = self.stub.embed(request).await?.into_inner(); + Ok(response.embeddings) + } + /// Downloads the weights for an adapter. 
pub async fn download_adapter( &mut self, diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index 3acdb5f68..0b7242104 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -9,7 +9,7 @@ pub use client::Client; pub use pb::generate::v1::HealthResponse; pub use pb::generate::v1::InfoResponse as ShardInfo; pub use pb::generate::v1::{ - AdapterParameters, AlternativeTokens, Batch, CachedBatch, DownloadAdapterResponse, + AdapterParameters, AlternativeTokens, Batch, CachedBatch, DownloadAdapterResponse, Embedding, FinishReason, GeneratedText, Generation, MajoritySignMethod, MergeStrategy, NextTokenChooserParameters, NextTokens, PrefillTokens, Request, StoppingCriteriaParameters, }; diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 1a6f885fb..c973f9251 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -1,4 +1,4 @@ -use crate::pb::generate::v1::EmbedResponse; +use crate::pb::generate::v1::{EmbedResponse, Embedding}; /// Multi shard Client use crate::{ AdapterParameters, Batch, CachedBatch, Client, DownloadAdapterResponse, Generation, @@ -154,15 +154,16 @@ impl ShardedClient { merge_generations(results?) } - /// Get the model info + /// Embed the given batch #[instrument(skip(self))] - pub async fn embed(&mut self, inputs: String) -> Result> { + pub async fn embed(&mut self, batch: Batch) -> Result> { let futures: Vec<_> = self .clients .iter_mut() - .map(|client| Box::pin(client.embed(inputs.clone()))) + .map(|client| Box::pin(client.embed(batch.clone()))) .collect(); - join_all(futures).await.into_iter().collect() + let results: Result>> = join_all(futures).await.into_iter().collect(); + Ok(results?.into_iter().flatten().collect()) } pub async fn download_adapter( diff --git a/router/src/batch.rs b/router/src/batch.rs new file mode 100644 index 000000000..f47e65b78 --- /dev/null +++ b/router/src/batch.rs @@ -0,0 +1,445 @@ +use core::fmt::Debug; +use std::{ + any::Any, + collections::{HashMap, HashSet}, + sync::{atomic::AtomicBool, Arc}, +}; + +use async_trait::async_trait; + +use lorax_client::{ + Batch, CachedBatch, NextTokenChooserParameters, Request, ShardedClient, + StoppingCriteriaParameters, +}; +use nohash_hasher::{BuildNoHashHasher, IntMap}; +use tokio::time::Instant; +use tracing::{info_span, span, Instrument, Span}; + +use crate::{ + adapter::Adapter, + infer::{decode, embed, prefill, InferError, InferStreamResponse}, +}; + +pub(crate) trait ValidRequest: Sync + Send + Debug + Any { + fn input_length(&self) -> u32; + fn max_new_tokens(&self) -> u32; + fn adapter(&self) -> Adapter; + fn to_batch(&self, num_entries: usize, queue_len: usize) -> Box; + fn as_any(&self) -> &dyn Any; +} + +impl ValidRequest for ValidGenerateRequest { + fn input_length(&self) -> u32 { + self.input_length + } + + fn max_new_tokens(&self) -> u32 { + self.stopping_parameters.max_new_tokens + } + + fn adapter(&self) -> Adapter { + self.adapter.clone() + } + + fn to_batch(&self, num_entries: usize, queue_len: usize) -> Box { + Box::new(GenerateBatchEntries::new(num_entries, queue_len)) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +#[derive(Debug)] +pub(crate) struct ValidEmbedRequest { + pub inputs: String, + pub input_length: u32, + pub adapter: Adapter, +} + +impl ValidRequest for ValidEmbedRequest { + fn input_length(&self) -> u32 { + self.input_length + } + + fn max_new_tokens(&self) -> u32 { + 1 + } + + fn adapter(&self) -> Adapter { + self.adapter.clone() + } + + fn to_batch(&self, 
num_entries: usize, queue_len: usize) -> Box { + Box::new(EmbedBatchEntries::new(num_entries, queue_len)) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +#[derive(Debug)] +pub(crate) struct ValidGenerateRequest { + pub inputs: String, + pub input_length: u32, + pub truncate: u32, + pub decoder_input_details: bool, + pub parameters: NextTokenChooserParameters, + pub stopping_parameters: StoppingCriteriaParameters, + pub adapter: Adapter, + pub apply_chat_template: bool, +} + +/// AdapterLoader entry +#[derive(Debug)] +pub(crate) struct Entry { + /// Request + pub request: Arc, + /// Response sender to communicate between the Infer struct and the batching_task + pub response_tx: flume::Sender>, + /// Span that will live as long as entry + pub span: Span, + /// Temporary span used as a guard when logging inference, wait times... + pub temp_span: Option, + /// Instant when this entry was queued + pub queue_time: Instant, + /// Instant when this entry was added to a batch + pub batch_time: Option, +} + +#[derive(Debug)] +pub(crate) struct BatchEntriesState { + pub(crate) batch_requests: Vec, + pub(crate) batch_entries: HashMap>, + pub(crate) index_to_adapter: HashMap, +} + +impl BatchEntriesState { + pub(crate) fn new(num_entries: usize, queue_len: usize) -> Self { + let batch_requests = Vec::with_capacity(num_entries); + let batch_entries = + IntMap::with_capacity_and_hasher(num_entries, BuildNoHashHasher::default()); + + let index_to_adapter = HashMap::with_capacity(queue_len); + + Self { + batch_requests, + batch_entries, + index_to_adapter, + } + } + + fn add(&mut self, id: u64, mut entry: Entry, adapter: Adapter, request: Request) { + self.batch_requests.push(request); + + // Set batch_time + entry.batch_time = Some(Instant::now()); + // Insert in batch_entries IntMap + self.batch_entries.insert(id, entry); + // Map from adapter index back to queue in case we need to add back entries below + // let queue = queue_map.get_mut(&adapter).unwrap(); + self.index_to_adapter.insert(adapter.index(), adapter); + } + + fn drain(&mut self) -> Vec<(Adapter, u64, Entry)> { + let mut entries = Vec::with_capacity(self.batch_requests.len()); + + // TODO(travis): clone is not ideal, find a way to do this cleanly in place + for r in self.batch_requests.clone().into_iter().rev() { + let id = r.id; + let entry = self.batch_entries.remove(&id).unwrap(); + let adapter_index = r.adapter_index; + let adapter = self.index_to_adapter.get_mut(&adapter_index).unwrap(); + entries.push((adapter.clone(), id, entry)); + } + + entries + } + + fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch { + // Final batch size + let size = self.len() as u32; + + // TODO(travis): clone is not ideal, find a way to do this cleanly in place + Batch { + id: batch_id, + requests: self.batch_requests.clone(), + size, + max_tokens, + } + } + + fn adapters_in_use(&self) -> HashSet { + self.batch_entries + .iter() + .map(|(_, entry)| entry.request.adapter()) + .collect::>() + } + + fn is_empty(&self) -> bool { + self.batch_requests.is_empty() + } + + fn len(&self) -> usize { + self.batch_requests.len() + } +} + +#[async_trait] +pub(crate) trait BatchEntries: Sync + Send + Debug { + fn can_add(&self, entry: &Entry) -> bool; + fn add(&mut self, id: u64, entry: Entry, adapter: Adapter); + fn extend(&mut self, entries: Box); + fn drain(&mut self) -> Vec<(Adapter, u64, Entry)>; + fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch; + fn adapters_in_use(&self) -> HashSet; + fn is_empty(&self) -> bool; + fn 
len(&self) -> usize; + fn state(&self) -> &BatchEntriesState; + fn mut_state(&mut self) -> &mut BatchEntriesState; + + async fn process_first( + &mut self, + client: &mut ShardedClient, + batch: Batch, + span: Span, + generation_health: &Arc, + ) -> Option; + + async fn process_next( + &mut self, + client: &mut ShardedClient, + batches: Vec, + span: Span, + generation_health: &Arc, + ) -> Option; +} + +#[derive(Debug)] +pub(crate) struct GenerateBatchEntries { + pub(crate) state: BatchEntriesState, +} + +impl GenerateBatchEntries { + pub(crate) fn new(num_entries: usize, queue_len: usize) -> Self { + Self { + state: BatchEntriesState::new(num_entries, queue_len), + } + } +} + +#[async_trait] +impl BatchEntries for GenerateBatchEntries { + fn can_add(&self, entry: &Entry) -> bool { + // return false if the entry.request is not of type ValidGenerateRequest + let valid_request = entry + .request + .as_ref() + .as_any() + .downcast_ref::(); + + let result = valid_request.is_some(); + result + } + + fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) { + let valid_request = entry + .request + .as_ref() + .as_any() + .downcast_ref::(); + + let request = valid_request.unwrap(); + let request_proto = Request { + id, + prefill_logprobs: request.decoder_input_details, + inputs: request.inputs.clone(), + truncate: request.truncate, + parameters: Some(request.parameters.clone()), + stopping_parameters: Some(request.stopping_parameters.clone()), + adapter_index: adapter.index(), + apply_chat_template: request.apply_chat_template, + }; + + self.state.add(id, entry, adapter, request_proto); + } + + fn extend(&mut self, mut entries: Box) { + let new_batch_entries = std::mem::take(&mut entries.mut_state().batch_entries); + self.state.batch_entries.extend(new_batch_entries); + } + + fn drain(&mut self) -> Vec<(Adapter, u64, Entry)> { + self.state.drain() + } + + fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch { + self.state.create_batch_data(batch_id, max_tokens) + } + + fn adapters_in_use(&self) -> HashSet { + self.state.adapters_in_use() + } + + fn is_empty(&self) -> bool { + self.state.is_empty() + } + + fn len(&self) -> usize { + self.state.len() + } + + fn state(&self) -> &BatchEntriesState { + &self.state + } + + fn mut_state(&mut self) -> &mut BatchEntriesState { + &mut self.state + } + + async fn process_first( + &mut self, + client: &mut ShardedClient, + batch: Batch, + span: Span, + generation_health: &Arc, + ) -> Option { + prefill( + client, + batch, + &mut self.state.batch_entries, + &generation_health, + ) + .instrument(span) + .await + } + + async fn process_next( + &mut self, + client: &mut ShardedClient, + batches: Vec, + span: Span, + generation_health: &Arc, + ) -> Option { + decode( + client, + batches, + &mut self.state.batch_entries, + &generation_health, + ) + .instrument(span) + .await + } +} + +#[derive(Debug)] +pub(crate) struct EmbedBatchEntries { + pub(crate) state: BatchEntriesState, +} + +impl EmbedBatchEntries { + pub(crate) fn new(num_entries: usize, queue_len: usize) -> Self { + Self { + state: BatchEntriesState::new(num_entries, queue_len), + } + } +} + +#[async_trait] +impl BatchEntries for EmbedBatchEntries { + fn can_add(&self, entry: &Entry) -> bool { + // return false if the entry.request is not of type ValidEmbedRequest + let valid_request = entry + .request + .as_ref() + .as_any() + .downcast_ref::(); + + let result = valid_request.is_some(); + result + } + + fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) { + let valid_request 
= entry + .request + .as_ref() + .as_any() + .downcast_ref::(); + + let request = valid_request.unwrap(); + let request_proto = Request { + id, + prefill_logprobs: false, + inputs: request.inputs.clone(), + truncate: 0, + parameters: None, + stopping_parameters: None, + adapter_index: adapter.index(), + apply_chat_template: false, + }; + + self.state.add(id, entry, adapter, request_proto); + } + + fn extend(&mut self, mut entries: Box) { + let new_batch_entries = std::mem::take(&mut entries.mut_state().batch_entries); + self.state.batch_entries.extend(new_batch_entries); + } + + fn drain(&mut self) -> Vec<(Adapter, u64, Entry)> { + self.state.drain() + } + + fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch { + self.state.create_batch_data(batch_id, max_tokens) + } + + fn adapters_in_use(&self) -> HashSet { + self.state.adapters_in_use() + } + + fn is_empty(&self) -> bool { + self.state.is_empty() + } + + fn len(&self) -> usize { + self.state.len() + } + + fn state(&self) -> &BatchEntriesState { + &self.state + } + + fn mut_state(&mut self) -> &mut BatchEntriesState { + &mut self.state + } + + async fn process_first( + &mut self, + client: &mut ShardedClient, + batch: Batch, + span: Span, + generation_health: &Arc, + ) -> Option { + embed( + client, + batch, + &mut self.state.batch_entries, + &generation_health, + ) + .instrument(span) + .await + } + + async fn process_next( + &mut self, + client: &mut ShardedClient, + batches: Vec, + span: Span, + generation_health: &Arc, + ) -> Option { + // TODO(travis): send error (programming eroor) if we get here + None + } +} diff --git a/router/src/infer.rs b/router/src/infer.rs index 99aaebe9c..7184b1f77 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -1,9 +1,10 @@ /// Batching and inference logic use crate::adapter::{extract_adapter_params, Adapter, BASE_MODEL_ADAPTER_ID}; +use crate::batch::ValidEmbedRequest; use crate::queue::AdapterEvent; use crate::scheduler::AdapterScheduler; use crate::validation::{Validation, ValidationError}; -use crate::{AdapterParameters, AlternativeToken, Entry, Token}; +use crate::{AdapterParameters, AlternativeToken, EmbedRequest, EmbedResponse, Entry, Token}; use crate::{GenerateRequest, PrefillToken}; use flume::r#async::RecvStream; use flume::SendTimeoutError; @@ -11,7 +12,8 @@ use futures::future::try_join_all; use futures::stream::StreamExt; use itertools::multizip; use lorax_client::{ - Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient, + Batch, CachedBatch, ClientError, Embedding, GeneratedText, Generation, PrefillTokens, + ShardedClient, }; use nohash_hasher::IntMap; use std::collections::{HashMap, HashSet}; @@ -23,7 +25,7 @@ use std::time::Duration; use thiserror::Error; use tokio::sync::{Mutex, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError}; use tokio::time::Instant; -use tracing::{info_span, instrument, Instrument, Span}; +use tracing::{info_span, instrument, Span}; /// Inference struct #[derive(Clone)] @@ -169,7 +171,7 @@ impl Infer { self.adapter_scheduler.process( adapter.clone(), Entry { - request: valid_request, + request: Arc::new(valid_request), response_tx, span: Span::current(), temp_span: None, @@ -256,6 +258,10 @@ impl Infer { result_start = Some(start); result_queued = Some(queued) } + InferStreamResponse::Embed { .. } => { + // This should not happen + tracing::error!("Received an Embed message in generate. 
This is a bug."); + } } } @@ -278,6 +284,130 @@ impl Infer { Err(err) } } + + #[instrument(skip(self))] + pub(crate) async fn embed(&self, request: EmbedRequest) -> Result { + // Limit concurrent requests by acquiring a permit from the semaphore + let permit = self + .clone() + .limit_concurrent_requests + .try_acquire_owned() + .map_err(|err| { + metrics::increment_counter!("lorax_request_failure", "err" => "overloaded"); + tracing::error!("{err}"); + err + })?; + + // TODO(travis): support adapters + // let (adapter_source, adapter_parameters) = extract_adapter_params( + // request.parameters.adapter_id.clone(), + // request.parameters.adapter_source.clone(), + // request.parameters.adapter_parameters.clone(), + // ); + + // let adapter_idx; + // { + // // TODO(travis): can optimize concurrency here using RWLock + // let mut adapter_to_index = self.adapter_to_index.lock().await; + // let adapter_key = adapter_parameters.clone(); + // if adapter_to_index.contains_key(&adapter_key) { + // adapter_idx = *adapter_to_index.get(&adapter_key).unwrap(); + // } else { + // adapter_idx = adapter_to_index.len() as u32; + // adapter_to_index.insert(adapter_key, adapter_idx); + // } + // } + + let adapter = Adapter::new( + AdapterParameters { + adapter_ids: vec![BASE_MODEL_ADAPTER_ID.to_string()], + ..Default::default() + }, + "hub".to_string(), + 0, + None, + ); + + // TODO(travis): robust validation + // Validate request + // let valid_request = self + // .validation + // .validate(request, adapter.clone()) + // .await + // .map_err(|err| { + // metrics::increment_counter!("lorax_request_failure", "err" => "validation"); + // tracing::error!("{err}"); + // err + // })?; + + let (inputs, input_length) = self + .validation + .validate_input(request.inputs, None, Some(1)) + .await?; + + let valid_request = ValidEmbedRequest { + inputs: inputs, + input_length: input_length as u32, + adapter: adapter.clone(), + }; + + // MPSC channel to communicate with the background batching task + let (response_tx, response_rx) = flume::unbounded(); + + // Process the request by sending it to the queue associated with `adapter` + self.adapter_scheduler.process( + adapter.clone(), + Entry { + request: Arc::new(valid_request), + response_tx, + span: Span::current(), + temp_span: None, + queue_time: Instant::now(), + batch_time: None, + }, + ); + + // Return values + let mut return_embeddings = None; + + let mut stream = response_rx.into_stream(); + while let Some(response) = stream.next().await { + match response? { + // Add prefill tokens + InferStreamResponse::Prefill { .. } => { + tracing::error!("Received a Prefill message in embed. This is a bug."); + } + // Push last token + InferStreamResponse::Token(..) => { + tracing::error!("Received a Token message in embed. This is a bug."); + } + // Final message + // Set return values + InferStreamResponse::End { .. } => { + tracing::error!("Received an End message in embed. 
This is a bug."); + } + InferStreamResponse::Embed { + embedding, + start, + queued, + } => { + return_embeddings = Some(embedding.values); + } + } + } + + if let Some(return_embeddings) = return_embeddings { + Ok(EmbedResponse { + embeddings: return_embeddings, + }) + } else { + let err = InferError::IncompleteGeneration; + metrics::increment_counter!("lorax_request_failure", "err" => "incomplete"); + tracing::error!("{err}"); + Err(err) + } + } + /// Add best_of new requests to the queue and return a InferResponse of the sequence with /// the highest log probability per token #[instrument(skip(self))] @@ -341,7 +471,7 @@ async fn batching_task( // Get the next batch from the queue // This batch might be smaller than the maximum batch size if there are not enough requests // waiting in the queue - while let Some((mut entries, batch, span)) = adapter_scheduler + while let Some((mut batch_entries, batch, span)) = adapter_scheduler .next_batch( HashSet::new(), None, @@ -350,8 +480,8 @@ async fn batching_task( ) .await { - let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health) - .instrument(span) + let mut cached_batch = batch_entries + .process_first(&mut client, batch, span, &generation_health) .await; let mut waiting_tokens = 1; @@ -379,14 +509,10 @@ async fn batching_task( }; let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - - let adapters_in_use = entries - .iter() - .map(|(_, entry)| entry.request.adapter.clone()) - .collect::>(); + let adapters_in_use = batch_entries.adapters_in_use(); // Try to get a new batch - if let Some((mut new_entries, new_batch, span)) = adapter_scheduler + if let Some((new_entries, new_batch, span)) = adapter_scheduler .next_batch( adapters_in_use, min_size, @@ -402,47 +528,55 @@ async fn batching_task( metrics::increment_counter!("lorax_batch_concat", "reason" => "wait_exceeded"); } - entries.iter_mut().for_each(|(_, entry)| { - // Create a new span to add the info that this entry is waiting - // because a new batch is being computed - let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); - // Add relationships - span.follows_from(&entry_waiting_span); - entry_waiting_span.follows_from(&span); - // Update entry - entry.temp_span = Some(entry_waiting_span); - }); + batch_entries + .mut_state() + .batch_entries + .iter_mut() + .for_each(|(_, entry)| { + // Create a new span to add the info that this entry is waiting + // because a new batch is being computed + let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); + // Add relationships + span.follows_from(&entry_waiting_span); + entry_waiting_span.follows_from(&span); + // Update entry + entry.temp_span = Some(entry_waiting_span); + }); // Generate one token for this new batch to have the attention past in cache - let new_cached_batch = - prefill(&mut client, new_batch, &mut new_entries, &generation_health) - .instrument(span) - .await; + let new_cached_batch = batch_entries + .process_first(&mut client, new_batch, span, &generation_health) + .await; // Reset waiting counter waiting_tokens = 1; // Extend current batch with the new batch if let Some(new_cached_batch) = new_cached_batch { - entries.extend(new_entries); + batch_entries.extend(new_entries); batches.push(new_cached_batch); } } // Create span for this batch to add context to inference calls - let next_batch_size = entries.len(); + let next_batch_size = batch_entries.len(); let next_batch_span = info_span!(parent: None, "batch", batch_size = next_batch_size); - 
entries.iter_mut().for_each(|(_, entry)| { - // Create a new span to link the batch back to this entry - let entry_batch_span = info_span!(parent: &entry.span, "infer"); - // Add relationships - next_batch_span.follows_from(&entry_batch_span); - entry_batch_span.follows_from(&next_batch_span); - // Update entry - entry.temp_span = Some(entry_batch_span); - }); - - cached_batch = decode(&mut client, batches, &mut entries, &generation_health) - .instrument(next_batch_span) + + batch_entries + .mut_state() + .batch_entries + .iter_mut() + .for_each(|(_, entry)| { + // Create a new span to link the batch back to this entry + let entry_batch_span = info_span!(parent: &entry.span, "infer"); + // Add relationships + next_batch_span.follows_from(&entry_batch_span); + entry_batch_span.follows_from(&next_batch_span); + // Update entry + entry.temp_span = Some(entry_batch_span); + }); + + cached_batch = batch_entries + .process_next(&mut client, batches, next_batch_span, &generation_health) .await; waiting_tokens += 1; } @@ -453,7 +587,7 @@ async fn batching_task( } #[instrument(skip_all)] -async fn prefill( +pub(crate) async fn prefill( client: &mut ShardedClient, batch: Batch, entries: &mut IntMap, @@ -490,7 +624,7 @@ async fn prefill( } #[instrument(skip_all)] -async fn decode( +pub(crate) async fn decode( client: &mut ShardedClient, batches: Vec, entries: &mut IntMap, @@ -527,6 +661,66 @@ async fn decode( } } +#[instrument(skip_all)] +pub(crate) async fn embed( + client: &mut ShardedClient, + batch: Batch, + entries: &mut IntMap, + generation_health: &Arc, +) -> Option { + let start_time = Instant::now(); + let batch_id = batch.id; + metrics::increment_counter!("lorax_batch_inference_count", "method" => "prefill"); + + match client.embed(batch).await { + Ok(results) => { + // Update health + generation_health.store(true, Ordering::SeqCst); + // Send generated tokens and filter stopped entries + results.into_iter().for_each(|embedding| { + let id = embedding.request_id; + // Get entry + // We can `expect` here as the request id should always be in the entries + let entry = entries + .get(&id) + .expect("ID not found in entries. This is a bug."); + + // Send generation responses back to the infer task + // If the receive an error from the Flume channel, it means that the client dropped the + // request and we need to stop generating hence why we unwrap_or(true) + let stopped = send_embeddings(embedding, entry) + .map_err(|err| { + if let SendTimeoutError::Timeout(_) = *err { + tracing::error!("Entry response channel timed out.") + } + + metrics::increment_counter!("lorax_request_failure", "err" => "dropped"); + err + }) + .unwrap_or(true); + if stopped { + entries + .remove(&id) + .expect("ID not found in entries. 
This is a bug."); + } + }); + + metrics::histogram!("lorax_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "embed"); + metrics::increment_counter!("lorax_batch_inference_success", "method" => "embed"); + None + } + // If we have an error, we discard the whole batch + Err(err) => { + // Update health + generation_health.store(false, Ordering::SeqCst); + let _ = client.clear_cache(Some(batch_id)).await; + send_errors(err, entries); + metrics::increment_counter!("lorax_batch_inference_failure", "method" => "embed"); + None + } + } +} + /// Filter a `batch` and remove all requests not present in `entries` #[instrument(skip_all)] async fn filter_batch( @@ -685,6 +879,29 @@ fn send_responses( Ok(stopped) } +/// Send responses through the `entry` response channel +fn send_embeddings( + embedding: Embedding, + entry: &Entry, +) -> Result>>> { + // Return directly if the channel is disconnected + if entry.response_tx.is_disconnected() { + return Ok(true); + } + + entry.response_tx.send_timeout( + Ok(InferStreamResponse::Embed { + embedding: embedding.clone(), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + }), + Duration::from_millis(10), + )?; + + // TODO(travis): redundant as we always return true, just make it return nothing + Ok(true) +} + /// Send errors to Infer for all `entries` #[instrument(skip_all)] fn send_errors(error: ClientError, entries: &mut IntMap) { @@ -712,6 +929,12 @@ pub(crate) enum InferStreamResponse { }, // Intermediate messages Token(Token), + // Embeddings + Embed { + embedding: Embedding, + start: Instant, + queued: Instant, + }, // Last message End { token: Token, diff --git a/router/src/lib.rs b/router/src/lib.rs index f5ced9b63..541387800 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1,5 +1,6 @@ /// LoRAX Webserver mod adapter; +mod batch; mod health; mod infer; mod loader; @@ -10,9 +11,9 @@ mod validation; use lorax_client::{AdapterParameters as AdapterParametersMessage, AlternativeTokens}; use lorax_client::{MajoritySignMethod, MergeStrategy}; +use batch::Entry; use infer::Infer; use loader::AdapterLoader; -use queue::Entry; use serde::{Deserialize, Serialize}; use serde_json::json; use utoipa::ToSchema; diff --git a/router/src/queue.rs b/router/src/queue.rs index b8e344375..8a51f240d 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -11,27 +11,10 @@ use tracing::{info_span, Span}; use crate::{ adapter::Adapter, + batch::Entry, infer::{InferError, InferStreamResponse}, - validation::ValidGenerateRequest, }; -/// AdapterLoader entry -#[derive(Debug)] -pub(crate) struct Entry { - /// Request - pub request: ValidGenerateRequest, - /// Response sender to communicate between the Infer struct and the batching_task - pub response_tx: flume::Sender>, - /// Span that will live as long as entry - pub span: Span, - /// Temporary span used as a guard when logging inference, wait times... 
- pub temp_span: Option, - /// Instant when this entry was queued - pub queue_time: Instant, - /// Instant when this entry was added to a batch - pub batch_time: Option, -} - #[derive(Debug, PartialEq)] pub(crate) enum AdapterStatus { Downloading, diff --git a/router/src/scheduler.rs b/router/src/scheduler.rs index 18b7d8663..5f790b880 100644 --- a/router/src/scheduler.rs +++ b/router/src/scheduler.rs @@ -1,7 +1,8 @@ use crate::{ adapter::Adapter, + batch::{self, BatchEntries, Entry}, queue::{AdapterEvent, AdapterQueuesState}, - AdapterLoader, Entry, + AdapterLoader, }; use lorax_client::{Batch, Request, ShardedClient}; use nohash_hasher::{BuildNoHashHasher, IntMap}; @@ -102,7 +103,7 @@ impl AdapterScheduler { } } -type NextBatch = (IntMap, Batch, Span); +type NextBatch = (Box, Batch, Span); /// Background task that manages the queues of the various adapters /// TODO(geoffrey): add tracing (span object) to the various commands @@ -253,12 +254,6 @@ impl AdapterSchedulerState { let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); next_batch_span.follows_from(&Span::current()); - let mut batch_requests = Vec::with_capacity(num_entries); - let mut batch_entries = - IntMap::with_capacity_and_hasher(num_entries, BuildNoHashHasher::default()); - - let mut index_to_adapter = HashMap::with_capacity(queues_state.active_len()); - let mut max_input_length = 0; let mut prefill_tokens: u32 = 0; let mut decode_tokens: u32 = 0; @@ -273,6 +268,7 @@ impl AdapterSchedulerState { ); // Pop entries starting from the front of the queue + let mut batch_entries: Option> = None; while let Some((id, mut entry, adapter)) = queues_state.next_entry() { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) @@ -282,25 +278,30 @@ impl AdapterSchedulerState { } if self.requires_padding { + let mut batch_requests_len = 0; + if let Some(batch_entries) = batch_entries.as_ref() { + batch_requests_len = batch_entries.len(); + } + // We pad to max input length in the Python shards // We need to take these padding tokens into the equation - max_input_length = max_input_length.max(entry.request.input_length); - prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length + max_input_length = max_input_length.max(entry.request.input_length()); + prefill_tokens = (batch_requests_len + 1) as u32 * max_input_length } else { // pad to block size - prefill_tokens += ((entry.request.input_length + self.block_size - 1) + prefill_tokens += ((entry.request.input_length() + self.block_size - 1) / self.block_size) * self.block_size; } if self.requires_padding { - decode_tokens += entry.request.stopping_parameters.max_new_tokens; + decode_tokens += entry.request.max_new_tokens(); } else { let max_new_tokens = match self.window_size { - None => entry.request.stopping_parameters.max_new_tokens, + None => entry.request.max_new_tokens(), Some(window_size) => min( - window_size.saturating_sub(entry.request.input_length), - entry.request.stopping_parameters.max_new_tokens, + window_size.saturating_sub(entry.request.input_length()), + entry.request.max_new_tokens(), ), }; @@ -318,6 +319,20 @@ impl AdapterSchedulerState { break; } + if batch_entries.is_none() { + batch_entries = Some( + entry + .request + .to_batch(num_entries, queues_state.active_len()), + ); + } + + if !batch_entries.as_ref().unwrap().can_add(&entry) { + // Incompatible entry for this batch. 
Reinsert and break + queues_state.push_front(&adapter, id, entry); + break; + } + // Create a new span to link the batch back to this entry let entry_batch_span = info_span!(parent: &entry.span, "infer"); // Add relationships @@ -326,61 +341,41 @@ impl AdapterSchedulerState { // Update entry entry.temp_span = Some(entry_batch_span); - batch_requests.push(Request { - id, - prefill_logprobs: entry.request.decoder_input_details, - inputs: entry.request.inputs.clone(), - truncate: entry.request.truncate, - parameters: Some(entry.request.parameters.clone()), - stopping_parameters: Some(entry.request.stopping_parameters.clone()), - adapter_index: adapter.index(), - apply_chat_template: entry.request.apply_chat_template, - }); - // Set batch_time - entry.batch_time = Some(Instant::now()); - // Insert in batch_entries IntMap - batch_entries.insert(id, entry); - // Map from adapter index back to queue in case we need to add back entries below - // let queue = queue_map.get_mut(&adapter).unwrap(); - index_to_adapter.insert(adapter.index(), adapter); + batch_entries.as_mut().unwrap().add(id, entry, adapter) } + if batch_entries.is_none() { + return None; + } + + let mut batch_entries = batch_entries.unwrap(); + // Empty batch - if batch_requests.is_empty() { + if batch_entries.is_empty() { return None; } // Check if our batch is big enough if let Some(min_size) = min_size { // Batch is too small - if batch_requests.len() < min_size { + if batch_entries.len() < min_size { // Add back entries to the queue in the correct order - for r in batch_requests.into_iter().rev() { - let id = r.id; - let entry = batch_entries.remove(&id).unwrap(); - let adapter_index = r.adapter_index; - let adapter = index_to_adapter.get_mut(&adapter_index).unwrap(); - queues_state.push_front(adapter, id, entry); + for (adapter, id, entry) in batch_entries.drain() { + queues_state.push_front(&adapter, id, entry); } return None; } } - // Final batch size - let size = batch_requests.len() as u32; - next_batch_span.record("batch_size", size); + next_batch_span.record("batch_size", batch_entries.len() as u32); + let max_tokens = prefill_tokens + decode_tokens; + let batch = batch_entries.create_batch_data(self.next_batch_id, max_tokens); - let batch = Batch { - id: self.next_batch_id, - requests: batch_requests, - size, - max_tokens: (prefill_tokens + decode_tokens), - }; // Increment batch id self.next_batch_id += 1; - metrics::histogram!("lorax_batch_next_size", batch.size as f64); + metrics::histogram!("lorax_batch_next_size", batch_entries.len() as f64); Some((batch_entries, batch, next_batch_span)) } diff --git a/router/src/server.rs b/router/src/server.rs index 98f13b48b..4a4eb9a12 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -841,6 +841,15 @@ async fn generate_stream_with_callback( yield Ok(callback(stream_token)); break; + }, + InferStreamResponse::Embed { + .. 
+ } => { + let err = InferError::from(ValidationError::EmbeddingModel); + metrics::increment_counter!("lorax_request_failure", "err" => "bad_request"); + tracing::error!("{err}"); + yield Ok(Event::from(err)); + break; } } } @@ -1333,31 +1342,20 @@ impl From for Event { )] #[instrument(skip_all)] async fn embed( + infer: Extension, mut client: Extension, Json(req): Json, ) -> Result, (StatusCode, Json)> { - let input = req.inputs.clone(); - let embeddings = client.embed(input).await.unwrap(); - let embeddings = embeddings.get(0); - - // TODO: better error enums - if (!embeddings.unwrap().error_msg.is_empty()) { - return Err(( - StatusCode::BAD_REQUEST, - Json(ErrorResponse { - error: embeddings.unwrap().error_msg.clone(), - error_type: "model doesn't support embeddings".to_string(), - }), - )); - } + let span = tracing::Span::current(); + let start_time = Instant::now(); + metrics::increment_counter!("lorax_request_count"); + + tracing::debug!("Input: {}", req.inputs); + + // Inference + let response = infer.embed(req).await?; - let values = embeddings - .map(|emb| emb.embeddings.as_ref().map(|emb| emb.values.clone())) - .flatten() - .unwrap_or_default(); - Ok(Json(EmbedResponse { - embeddings: values.clone(), - })) + Ok(Json(response)) } /// Tokenize inputs diff --git a/router/src/validation.rs b/router/src/validation.rs index 6e2c5c70d..c1596ca19 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -1,4 +1,5 @@ use crate::adapter::Adapter; +use crate::batch::ValidGenerateRequest; /// Payload validation logic use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; use crate::{GenerateParameters, GenerateRequest}; @@ -86,7 +87,7 @@ impl Validation { } #[instrument(skip(self, inputs))] - async fn validate_input( + pub(crate) async fn validate_input( &self, inputs: String, truncate: Option, @@ -396,18 +397,6 @@ type TokenizerRequest = ( Span, ); -#[derive(Debug)] -pub(crate) struct ValidGenerateRequest { - pub inputs: String, - pub input_length: u32, - pub truncate: u32, - pub decoder_input_details: bool, - pub parameters: NextTokenChooserParameters, - pub stopping_parameters: StoppingCriteriaParameters, - pub adapter: Adapter, - pub apply_chat_template: bool, -} - #[derive(Error, Debug)] pub enum ValidationError { #[error("`best_of` must be > 0 and <= {0}. 
Given: {1}")] diff --git a/server/lorax_server/models/flash_bert.py b/server/lorax_server/models/flash_bert.py index 86904117d..beaa7f904 100644 --- a/server/lorax_server/models/flash_bert.py +++ b/server/lorax_server/models/flash_bert.py @@ -172,7 +172,8 @@ def __init__( else: raise NotImplementedError("FlashSantacoderSharded is only available on GPU") - self.device = device + # self.device = device + self.device = "cpu" self.dtype = dtype tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -214,7 +215,10 @@ def supports_text_generation(self) -> bool: return False def warmup(self, batch: FlashEmbeddingBatch, max_new_tokens: int) -> int | None: - return 42 # no-op for now + # Note: This is meant to 1) preallocate the memory by doing a forward pass + # and then just returning the max seqlen since for embeddings we are never generating + _ = self.embed(batch) + return batch.max_s def generate_token(self, batch: FlashEmbeddingBatch) -> None: if not self.supports_text_generation: @@ -226,26 +230,14 @@ def forward(self, batch: FlashEmbeddingBatch): @tracer.start_as_current_span("embed") def embed(self, batch: FlashEmbeddingBatch) -> Embedding: - embedding = self.model.forward( + embedding: torch.Tensor = self.model.forward( input_ids=batch.input_ids, token_type_ids=batch.token_type_ids, position_ids=batch.position_ids, cu_seqlens=batch.cu_seqlens, max_s=batch.max_s, ) - cpu_results = embedding.view(-1).tolist() - - return Embedding(values=cpu_results[: self.hidden_size]) - - def tokenize_to_batch(self, inputs) -> FlashEmbeddingBatch: - tokens = self.tokenizer(inputs, return_token_type_ids=True) - num_tokens = len(tokens["input_ids"]) - position_ids = range(num_tokens) - return FlashEmbeddingBatch( - input_ids=torch.tensor(tokens["input_ids"], dtype=torch.int32, device=self.device), - token_type_ids=torch.tensor(tokens["token_type_ids"], dtype=torch.int32, device=self.device), - position_ids=torch.tensor(position_ids, dtype=torch.int32, device=self.device), - cu_seqlens=torch.tensor([0, num_tokens], dtype=torch.int32, device=self.device), - max_s=num_tokens, - size=1, - ) + embedding = embedding.reshape(embedding.shape[0], -1)[:, : self.hidden_size] + + cpu_results = embedding.cpu().tolist() + return cpu_results diff --git a/server/lorax_server/models/types.py b/server/lorax_server/models/types.py index c77f784e6..18accb3ff 100644 --- a/server/lorax_server/models/types.py +++ b/server/lorax_server/models/types.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from typing import List, Optional +import numpy as np import torch from transformers import PreTrainedTokenizerBase @@ -56,7 +57,6 @@ def to_pb(self) -> generate_pb2.GeneratedText: seed=self.seed, ) - @dataclass class PrefillTokens: token_ids: List[int] @@ -128,6 +128,7 @@ def to_pb(self) -> generate_pb2.Generation: @dataclass class FlashEmbeddingBatch(ABC): + request_ids: List[int] input_ids: torch.Tensor token_type_ids: torch.Tensor position_ids: torch.Tensor @@ -136,8 +137,78 @@ class FlashEmbeddingBatch(ABC): max_s: int size: int - def __len__(self): + def __len__(self) -> int: return self.size - def from_pb(self, *args, **kwargs): - return None + @classmethod + def from_pb( + self, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + tokenizers: TokenizerManager, + dtype: torch.dtype, + device: torch.device, + ) -> "FlashEmbeddingBatch": + batch_inputs = [] + max_truncation = 0 + for r in pb.requests: + inputs = tokenizers.get_inputs(r, tokenizer) + batch_inputs.append(inputs) + max_truncation = max(max_truncation, 
r.truncate) + + batch_inputs = tokenizer( + batch_inputs, + return_token_type_ids=True, + truncation=True, + max_length=max_truncation, + ) + + batch_tokenized_inputs = batch_inputs["input_ids"] + batch_token_type_ids = batch_inputs["token_type_ids"] + + all_input_ids = [] + position_ids = [] + all_token_type_ids = [] + cu_seqlens = [0] + + max_s = 0 + cumulative_length = 0 + + for i, (r, tokenized_input, token_type_ids) in enumerate(zip(pb.requests, batch_tokenized_inputs, batch_token_type_ids)): + tokenized_input = tokenized_input[-r.truncate :] + token_type_ids = token_type_ids[-r.truncate :] + all_input_ids.append(tokenized_input) + all_token_type_ids.append(token_type_ids) + + input_length = len(tokenized_input) + max_s = max(max_s, input_length) + cu_seqlens.append(cumulative_length + input_length) + + # Position ids + request_position_ids = torch.arange(0, input_length, dtype=torch.int32) + position_ids.append(request_position_ids) + + cumulative_length += input_length + + if len(pb.requests) > 1: + input_ids = np.concatenate(all_input_ids, dtype=np.int64) + final_token_type_ids = np.concatenate(all_token_type_ids, dtype=np.int64) + position_ids = torch.cat(position_ids) + else: + input_ids = all_input_ids[0] + final_token_type_ids = all_token_type_ids[0] + position_ids = position_ids[0] + + input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + final_token_type_ids = torch.tensor(final_token_type_ids, dtype=torch.int64, device=device) + position_ids = position_ids.to(device) + + return FlashEmbeddingBatch( + request_ids=[r.id for r in pb.requests], + input_ids=input_ids, + token_type_ids=final_token_type_ids, + position_ids=position_ids, + cu_seqlens=torch.tensor(cu_seqlens, dtype=torch.int32, device=device), + max_s=max_s, + size=len(batch_inputs), + ) diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 5ea9bd6cf..63cc39c77 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -97,14 +97,20 @@ async def Prefill(self, request: generate_pb2.PrefillRequest, context): async def Embed(self, request: generate_pb2.EmbedRequest, context): if not self.model.supports_embeddings: - logger.error("Model does not support embeddings") - return generate_pb2.EmbedResponse( - embeddings=generate_pb2.Embedding(), errorMsg="Model does not support embeddings" - ) - batch = request.inputs - tokenised_batch = self.model.tokenize_to_batch(batch) - embeddings = self.model.embed(tokenised_batch) - return generate_pb2.EmbedResponse(embeddings=embeddings) + raise ValueError("Model does not support embeddings") + + batch = self.model.batch_type.from_pb( + request.batch, + self.model.tokenizer, + self.model.tokenizers, + self.model.dtype, + self.model.device, + ) + embeddings = self.model.embed(batch) + embeddings_proto = [] + for i, embedding in enumerate(embeddings): + embeddings_proto.append(generate_pb2.Embedding(request_id=batch.request_ids[i], values=embedding)) + return generate_pb2.EmbedResponse(embeddings=embeddings_proto) async def Decode(self, request: generate_pb2.DecodeRequest, context): if len(request.batches) == 0:
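
The `FlashEmbeddingBatch.from_pb` change above concatenates every tokenized request in the batch into one flat tensor and records the sequence boundaries in `cu_seqlens`, which is the layout the flash-attention kernels (and the `max_s` warmup path) consume. The following is a minimal, self-contained sketch of that packing scheme, not code from this patch: the helper name `pack_sequences` and the token ids are invented for illustration, and the real `from_pb` additionally handles truncation, token type ids, and position ids.

# Sketch only (assumption: not part of this patch) of the cu_seqlens packing used by
# FlashEmbeddingBatch.from_pb: sequences are concatenated into a single flat id tensor,
# and cu_seqlens stores cumulative offsets marking where each request starts and ends.
import torch

def pack_sequences(token_id_lists):
    flat_ids = []
    cu_seqlens = [0]
    max_s = 0
    for ids in token_id_lists:
        flat_ids.extend(ids)
        cu_seqlens.append(cu_seqlens[-1] + len(ids))
        max_s = max(max_s, len(ids))
    return (
        torch.tensor(flat_ids, dtype=torch.int64),
        torch.tensor(cu_seqlens, dtype=torch.int32),
        max_s,
    )

if __name__ == "__main__":
    # Two "requests" of length 3 and 4; the token ids are made up for the example.
    input_ids, cu_seqlens, max_s = pack_sequences([[101, 7592, 102], [101, 2088, 999, 102]])
    print(input_ids.shape)  # torch.Size([7])
    print(cu_seqlens)       # tensor([0, 3, 7], dtype=torch.int32)
    print(max_s)            # 4

Because each `Embedding` in the new `EmbedResponse` carries a `request_id`, the router-side `embed` task can map every returned vector back to the matching `Entry` even though the shard processed the whole packed batch at once.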