From f1269f0466fd5e1484170563522890bb3304cbe7 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Tue, 28 May 2024 22:42:54 -0700 Subject: [PATCH 01/31] Added batch without generics --- router/src/batch.rs | 117 +++++++++++++++++++++++++++++++++++++++ router/src/lib.rs | 1 + router/src/queue.rs | 17 ------ router/src/validation.rs | 12 ---- 4 files changed, 118 insertions(+), 29 deletions(-) create mode 100644 router/src/batch.rs diff --git a/router/src/batch.rs b/router/src/batch.rs new file mode 100644 index 000000000..4c82eff5f --- /dev/null +++ b/router/src/batch.rs @@ -0,0 +1,117 @@ +use std::sync::Arc; + +use lorax_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; +use tokio::time::Instant; +use tracing::Span; + +use crate::{ + adapter::Adapter, + infer::{InferError, InferStreamResponse}, +}; + +pub(crate) trait ValidRequest { + fn input_length(&self) -> u32; + fn max_new_tokens(&self) -> u32; + fn to_batch(&self) -> Arc; +} + +impl ValidRequest for ValidGenerateRequest { + fn input_length(&self) -> u32 { + self.input_length + } + + fn max_new_tokens(&self) -> u32 { + self.stopping_parameters.max_new_tokens + } + + fn to_batch(&self) -> Arc { + Arc::new(GenerateBatchEntries::new()) + } +} + +#[derive(Debug)] +pub(crate) struct ValidEmbedRequest { + pub inputs: String, + pub input_length: u32, +} + +impl ValidRequest for ValidEmbedRequest { + fn input_length(&self) -> u32 { + self.input_length + } + + fn max_new_tokens(&self) -> u32 { + 1 + } + + fn to_batch(&self) -> Arc { + Arc::new(EmbedBatchEntries::new()) + } +} + +#[derive(Debug)] +pub(crate) struct ValidGenerateRequest { + pub inputs: String, + pub input_length: u32, + pub truncate: u32, + pub decoder_input_details: bool, + pub parameters: NextTokenChooserParameters, + pub stopping_parameters: StoppingCriteriaParameters, + pub adapter: Adapter, + pub apply_chat_template: bool, +} + +/// AdapterLoader entry +#[derive(Debug)] +pub(crate) struct Entry { + /// Request + pub request: Arc, + /// Response sender to communicate between the Infer struct and the batching_task + pub response_tx: flume::Sender>, + /// Span that will live as long as entry + pub span: Span, + /// Temporary span used as a guard when logging inference, wait times... + pub temp_span: Option, + /// Instant when this entry was queued + pub queue_time: Instant, + /// Instant when this entry was added to a batch + pub batch_time: Option, +} + +pub(crate) trait BatchEntries { + fn add(&mut self, entry: Entry); +} + +#[derive(Debug)] +pub(crate) struct GenerateBatchEntries { + pub(crate) entries: Vec, +} + +impl GenerateBatchEntries { + pub(crate) fn new() -> Self { + Self { entries: vec![] } + } +} + +impl BatchEntries for GenerateBatchEntries { + fn add(&mut self, entry: Entry) { + self.entries.push(entry); + } +} + +#[derive(Debug)] +pub(crate) struct EmbedBatchEntries { + pub(crate) entries: Vec, +} + +impl EmbedBatchEntries { + pub(crate) fn new() -> Self { + Self { entries: vec![] } + } +} + +impl BatchEntries for EmbedBatchEntries { + fn add(&mut self, entry: Entry) { + self.entries.push(entry); + } +} diff --git a/router/src/lib.rs b/router/src/lib.rs index f5ced9b63..248b0b0f5 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1,5 +1,6 @@ /// LoRAX Webserver mod adapter; +mod batch; mod health; mod infer; mod loader; diff --git a/router/src/queue.rs b/router/src/queue.rs index b8e344375..00cc310e3 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -15,23 +15,6 @@ use crate::{ validation::ValidGenerateRequest, }; -/// AdapterLoader entry -#[derive(Debug)] -pub(crate) struct Entry { - /// Request - pub request: ValidGenerateRequest, - /// Response sender to communicate between the Infer struct and the batching_task - pub response_tx: flume::Sender>, - /// Span that will live as long as entry - pub span: Span, - /// Temporary span used as a guard when logging inference, wait times... - pub temp_span: Option, - /// Instant when this entry was queued - pub queue_time: Instant, - /// Instant when this entry was added to a batch - pub batch_time: Option, -} - #[derive(Debug, PartialEq)] pub(crate) enum AdapterStatus { Downloading, diff --git a/router/src/validation.rs b/router/src/validation.rs index 6e2c5c70d..21313a639 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -396,18 +396,6 @@ type TokenizerRequest = ( Span, ); -#[derive(Debug)] -pub(crate) struct ValidGenerateRequest { - pub inputs: String, - pub input_length: u32, - pub truncate: u32, - pub decoder_input_details: bool, - pub parameters: NextTokenChooserParameters, - pub stopping_parameters: StoppingCriteriaParameters, - pub adapter: Adapter, - pub apply_chat_template: bool, -} - #[derive(Error, Debug)] pub enum ValidationError { #[error("`best_of` must be > 0 and <= {0}. Given: {1}")] From d7d28559428618bf6710ab6b464b5232f66b3ae4 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Tue, 28 May 2024 22:46:44 -0700 Subject: [PATCH 02/31] Fixed --- router/src/batch.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/router/src/batch.rs b/router/src/batch.rs index 4c82eff5f..65f2e1f62 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -1,3 +1,4 @@ +use core::fmt::Debug; use std::sync::Arc; use lorax_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; @@ -15,6 +16,15 @@ pub(crate) trait ValidRequest { fn to_batch(&self) -> Arc; } +impl Debug for dyn ValidRequest { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("ValidRequest") + .field("input_length", &self.input_length()) + .field("max_new_tokens", &self.max_new_tokens()) + .finish() + } +} + impl ValidRequest for ValidGenerateRequest { fn input_length(&self) -> u32 { self.input_length From 9a20d3dd1426d4f4b93495bee88b12f039a9ce8f Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Tue, 28 May 2024 22:50:19 -0700 Subject: [PATCH 03/31] Fixed imports --- router/src/lib.rs | 2 +- router/src/queue.rs | 2 +- router/src/scheduler.rs | 3 ++- router/src/validation.rs | 1 + 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 248b0b0f5..541387800 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -11,9 +11,9 @@ mod validation; use lorax_client::{AdapterParameters as AdapterParametersMessage, AlternativeTokens}; use lorax_client::{MajoritySignMethod, MergeStrategy}; +use batch::Entry; use infer::Infer; use loader::AdapterLoader; -use queue::Entry; use serde::{Deserialize, Serialize}; use serde_json::json; use utoipa::ToSchema; diff --git a/router/src/queue.rs b/router/src/queue.rs index 00cc310e3..8a51f240d 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -11,8 +11,8 @@ use tracing::{info_span, Span}; use crate::{ adapter::Adapter, + batch::Entry, infer::{InferError, InferStreamResponse}, - validation::ValidGenerateRequest, }; #[derive(Debug, PartialEq)] diff --git a/router/src/scheduler.rs b/router/src/scheduler.rs index 18b7d8663..a8566c530 100644 --- a/router/src/scheduler.rs +++ b/router/src/scheduler.rs @@ -1,7 +1,8 @@ use crate::{ adapter::Adapter, + batch::Entry, queue::{AdapterEvent, AdapterQueuesState}, - AdapterLoader, Entry, + AdapterLoader, }; use lorax_client::{Batch, Request, ShardedClient}; use nohash_hasher::{BuildNoHashHasher, IntMap}; diff --git a/router/src/validation.rs b/router/src/validation.rs index 21313a639..5a8f24850 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -1,4 +1,5 @@ use crate::adapter::Adapter; +use crate::batch::ValidGenerateRequest; /// Payload validation logic use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; use crate::{GenerateParameters, GenerateRequest}; From 6f6889df79f7e6e280416dcafe346748511a06c0 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Tue, 28 May 2024 22:53:45 -0700 Subject: [PATCH 04/31] Cleanup --- router/src/batch.rs | 11 +---------- router/src/scheduler.rs | 12 ++++++------ 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/router/src/batch.rs b/router/src/batch.rs index 65f2e1f62..9a674f920 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -10,21 +10,12 @@ use crate::{ infer::{InferError, InferStreamResponse}, }; -pub(crate) trait ValidRequest { +pub(crate) trait ValidRequest: Sync + Send + Debug { fn input_length(&self) -> u32; fn max_new_tokens(&self) -> u32; fn to_batch(&self) -> Arc; } -impl Debug for dyn ValidRequest { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_struct("ValidRequest") - .field("input_length", &self.input_length()) - .field("max_new_tokens", &self.max_new_tokens()) - .finish() - } -} - impl ValidRequest for ValidGenerateRequest { fn input_length(&self) -> u32 { self.input_length diff --git a/router/src/scheduler.rs b/router/src/scheduler.rs index a8566c530..b7f493ffd 100644 --- a/router/src/scheduler.rs +++ b/router/src/scheduler.rs @@ -285,23 +285,23 @@ impl AdapterSchedulerState { if self.requires_padding { // We pad to max input length in the Python shards // We need to take these padding tokens into the equation - max_input_length = max_input_length.max(entry.request.input_length); + max_input_length = max_input_length.max(entry.request.input_length()); prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length } else { // pad to block size - prefill_tokens += ((entry.request.input_length + self.block_size - 1) + prefill_tokens += ((entry.request.input_length() + self.block_size - 1) / self.block_size) * self.block_size; } if self.requires_padding { - decode_tokens += entry.request.stopping_parameters.max_new_tokens; + decode_tokens += entry.request.max_new_tokens(); } else { let max_new_tokens = match self.window_size { - None => entry.request.stopping_parameters.max_new_tokens, + None => entry.request.max_new_tokens(), Some(window_size) => min( - window_size.saturating_sub(entry.request.input_length), - entry.request.stopping_parameters.max_new_tokens, + window_size.saturating_sub(entry.request.input_length()), + entry.request.max_new_tokens(), ), }; From a7ae51f229576467df15e66ecace0ccaf94d1d73 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 29 May 2024 10:39:31 -0700 Subject: [PATCH 05/31] Fix --- router/src/batch.rs | 110 ++++++++++++++++++++++++++++++++++------ router/src/infer.rs | 4 +- router/src/scheduler.rs | 51 ++++++------------- 3 files changed, 111 insertions(+), 54 deletions(-) diff --git a/router/src/batch.rs b/router/src/batch.rs index 9a674f920..75ad4f99c 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -1,19 +1,22 @@ use core::fmt::Debug; -use std::sync::Arc; +use std::{any::Any, collections::HashMap, sync::Arc}; -use lorax_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; +use lorax_client::{NextTokenChooserParameters, Request, StoppingCriteriaParameters}; +use nohash_hasher::{BuildNoHashHasher, IntMap}; use tokio::time::Instant; -use tracing::Span; +use tracing::{info_span, Span}; use crate::{ adapter::Adapter, infer::{InferError, InferStreamResponse}, }; -pub(crate) trait ValidRequest: Sync + Send + Debug { +pub(crate) trait ValidRequest: Sync + Send + Debug + Any { fn input_length(&self) -> u32; fn max_new_tokens(&self) -> u32; - fn to_batch(&self) -> Arc; + fn adapter(&self) -> Adapter; + fn to_batch(&self, num_entries: usize, queue_len: usize) -> Arc; + fn as_any(&self) -> &dyn Any; } impl ValidRequest for ValidGenerateRequest { @@ -25,8 +28,16 @@ impl ValidRequest for ValidGenerateRequest { self.stopping_parameters.max_new_tokens } - fn to_batch(&self) -> Arc { - Arc::new(GenerateBatchEntries::new()) + fn adapter(&self) -> Adapter { + self.adapter + } + + fn to_batch(&self, num_entries: usize, queue_len: usize) -> Arc { + Arc::new(GenerateBatchEntries::new(num_entries, queue_len)) + } + + fn as_any(&self) -> &dyn Any { + self } } @@ -34,6 +45,7 @@ impl ValidRequest for ValidGenerateRequest { pub(crate) struct ValidEmbedRequest { pub inputs: String, pub input_length: u32, + pub adapter: Adapter, } impl ValidRequest for ValidEmbedRequest { @@ -45,9 +57,17 @@ impl ValidRequest for ValidEmbedRequest { 1 } - fn to_batch(&self) -> Arc { + fn adapter(&self) -> Adapter { + self.adapter + } + + fn to_batch(&self, num_entries: usize, queue_len: usize) -> Arc { Arc::new(EmbedBatchEntries::new()) } + + fn as_any(&self) -> &dyn Any { + self + } } #[derive(Debug)] @@ -80,23 +100,81 @@ pub(crate) struct Entry { } pub(crate) trait BatchEntries { - fn add(&mut self, entry: Entry); + fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) -> bool; } #[derive(Debug)] pub(crate) struct GenerateBatchEntries { - pub(crate) entries: Vec, + pub(crate) batch_requests: Vec, + pub(crate) batch_entries: HashMap>, + pub(crate) index_to_adapter: HashMap, + next_batch_span: Span, } impl GenerateBatchEntries { - pub(crate) fn new() -> Self { - Self { entries: vec![] } + pub(crate) fn new(num_entries: usize, queue_len: usize) -> Self { + // Create span for this batch to add context to inference calls + let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); + next_batch_span.follows_from(&Span::current()); + + let mut batch_requests = Vec::with_capacity(num_entries); + let mut batch_entries = + IntMap::with_capacity_and_hasher(num_entries, BuildNoHashHasher::default()); + + let mut index_to_adapter = HashMap::with_capacity(queue_len); + + Self { + batch_requests, + batch_entries, + index_to_adapter, + next_batch_span, + } } } impl BatchEntries for GenerateBatchEntries { - fn add(&mut self, entry: Entry) { - self.entries.push(entry); + fn add(&mut self, id: u64, mut entry: Entry, adapter: Adapter) -> bool { + // return false if the entry.request is not of type ValidGenerateRequest + let valid_request = entry + .request + .as_ref() + .as_any() + .downcast_ref::(); + + if valid_request.is_none() { + return false; + } + + let request = valid_request.unwrap(); + + // self.entries.push(entry); + + // Create a new span to link the batch back to this entry + let entry_batch_span = info_span!(parent: &entry.span, "infer"); + // Add relationships + self.next_batch_span.follows_from(&entry_batch_span); + entry_batch_span.follows_from(&self.next_batch_span); + // Update entry + entry.temp_span = Some(entry_batch_span); + + self.batch_requests.push(Request { + id, + prefill_logprobs: request.decoder_input_details, + inputs: request.inputs.clone(), + truncate: request.truncate, + parameters: Some(request.parameters.clone()), + stopping_parameters: Some(request.stopping_parameters.clone()), + adapter_index: adapter.index(), + apply_chat_template: request.apply_chat_template, + }); + // Set batch_time + entry.batch_time = Some(Instant::now()); + // Insert in batch_entries IntMap + self.batch_entries.insert(id, entry); + // Map from adapter index back to queue in case we need to add back entries below + // let queue = queue_map.get_mut(&adapter).unwrap(); + self.index_to_adapter.insert(adapter.index(), adapter); + return true; } } @@ -112,7 +190,7 @@ impl EmbedBatchEntries { } impl BatchEntries for EmbedBatchEntries { - fn add(&mut self, entry: Entry) { - self.entries.push(entry); + fn add(&mut self, id: u64, mut entry: Entry, adapter: Adapter) -> bool { + return false; } } diff --git a/router/src/infer.rs b/router/src/infer.rs index 99aaebe9c..8818a5de2 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -169,7 +169,7 @@ impl Infer { self.adapter_scheduler.process( adapter.clone(), Entry { - request: valid_request, + request: Arc::new(valid_request), response_tx, span: Span::current(), temp_span: None, @@ -382,7 +382,7 @@ async fn batching_task( let adapters_in_use = entries .iter() - .map(|(_, entry)| entry.request.adapter.clone()) + .map(|(_, entry)| entry.request.adapter()) .collect::>(); // Try to get a new batch diff --git a/router/src/scheduler.rs b/router/src/scheduler.rs index b7f493ffd..003f599aa 100644 --- a/router/src/scheduler.rs +++ b/router/src/scheduler.rs @@ -1,6 +1,6 @@ use crate::{ adapter::Adapter, - batch::Entry, + batch::{BatchEntries, Entry}, queue::{AdapterEvent, AdapterQueuesState}, AdapterLoader, }; @@ -250,16 +250,6 @@ impl AdapterSchedulerState { } } - // Create span for this batch to add context to inference calls - let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); - next_batch_span.follows_from(&Span::current()); - - let mut batch_requests = Vec::with_capacity(num_entries); - let mut batch_entries = - IntMap::with_capacity_and_hasher(num_entries, BuildNoHashHasher::default()); - - let mut index_to_adapter = HashMap::with_capacity(queues_state.active_len()); - let mut max_input_length = 0; let mut prefill_tokens: u32 = 0; let mut decode_tokens: u32 = 0; @@ -274,6 +264,7 @@ impl AdapterSchedulerState { ); // Pop entries starting from the front of the queue + let batch: Option> = None; while let Some((id, mut entry, adapter)) = queues_state.next_entry() { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) @@ -319,31 +310,19 @@ impl AdapterSchedulerState { break; } - // Create a new span to link the batch back to this entry - let entry_batch_span = info_span!(parent: &entry.span, "infer"); - // Add relationships - next_batch_span.follows_from(&entry_batch_span); - entry_batch_span.follows_from(&next_batch_span); - // Update entry - entry.temp_span = Some(entry_batch_span); - - batch_requests.push(Request { - id, - prefill_logprobs: entry.request.decoder_input_details, - inputs: entry.request.inputs.clone(), - truncate: entry.request.truncate, - parameters: Some(entry.request.parameters.clone()), - stopping_parameters: Some(entry.request.stopping_parameters.clone()), - adapter_index: adapter.index(), - apply_chat_template: entry.request.apply_chat_template, - }); - // Set batch_time - entry.batch_time = Some(Instant::now()); - // Insert in batch_entries IntMap - batch_entries.insert(id, entry); - // Map from adapter index back to queue in case we need to add back entries below - // let queue = queue_map.get_mut(&adapter).unwrap(); - index_to_adapter.insert(adapter.index(), adapter); + if batch.is_none() { + batch = Some( + entry + .request + .to_batch(num_entries, queues_state.active_len()), + ); + } + + if !batch.as_ref().unwrap().add(id, entry, adapter) { + // Incompatible entry for this batch. Reinsert and break + queues_state.push_front(&adapter, id, entry); + break; + } } // Empty batch From 378fb3cee8df2340f378d540dcc31f078a8219fe Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 29 May 2024 15:07:55 -0700 Subject: [PATCH 06/31] Refactor into state --- router/src/batch.rs | 170 +++++++++++++++++++++++++++++++++------- router/src/scheduler.rs | 36 +++++---- 2 files changed, 161 insertions(+), 45 deletions(-) diff --git a/router/src/batch.rs b/router/src/batch.rs index 75ad4f99c..1748828c1 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -1,7 +1,7 @@ use core::fmt::Debug; use std::{any::Any, collections::HashMap, sync::Arc}; -use lorax_client::{NextTokenChooserParameters, Request, StoppingCriteriaParameters}; +use lorax_client::{Batch, NextTokenChooserParameters, Request, StoppingCriteriaParameters}; use nohash_hasher::{BuildNoHashHasher, IntMap}; use tokio::time::Instant; use tracing::{info_span, Span}; @@ -101,17 +101,21 @@ pub(crate) struct Entry { pub(crate) trait BatchEntries { fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) -> bool; + fn drain(&mut self) -> Vec<(Adapter, u64, Entry)>; + fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch; + fn is_empty(&self) -> bool; + fn len(&self) -> usize; } #[derive(Debug)] -pub(crate) struct GenerateBatchEntries { +pub(crate) struct BatchEntriesState { pub(crate) batch_requests: Vec, pub(crate) batch_entries: HashMap>, pub(crate) index_to_adapter: HashMap, next_batch_span: Span, } -impl GenerateBatchEntries { +impl BatchEntriesState { pub(crate) fn new(num_entries: usize, queue_len: usize) -> Self { // Create span for this batch to add context to inference calls let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); @@ -130,10 +134,78 @@ impl GenerateBatchEntries { next_batch_span, } } + + fn add(&mut self, id: u64, mut entry: Entry, adapter: Adapter, request: Request) { + // Create a new span to link the batch back to this entry + let entry_batch_span = info_span!(parent: &entry.span, "infer"); + // Add relationships + self.next_batch_span.follows_from(&entry_batch_span); + entry_batch_span.follows_from(&self.next_batch_span); + // Update entry + entry.temp_span = Some(entry_batch_span); + + self.batch_requests.push(request); + + // Set batch_time + entry.batch_time = Some(Instant::now()); + // Insert in batch_entries IntMap + self.batch_entries.insert(id, entry); + // Map from adapter index back to queue in case we need to add back entries below + // let queue = queue_map.get_mut(&adapter).unwrap(); + self.index_to_adapter.insert(adapter.index(), adapter); + } + + fn drain(&mut self) -> Vec<(Adapter, u64, Entry)> { + let mut entries = Vec::with_capacity(self.batch_requests.len()); + + for r in self.batch_requests.into_iter().rev() { + let id = r.id; + let entry = self.batch_entries.remove(&id).unwrap(); + let adapter_index = r.adapter_index; + let adapter = self.index_to_adapter.get_mut(&adapter_index).unwrap(); + entries.push((adapter.clone(), id, entry)); + } + + entries + } + + fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch { + // Final batch size + let size = self.len() as u32; + self.next_batch_span.record("batch_size", size); + + Batch { + id: batch_id, + requests: self.batch_requests, + size, + max_tokens, + } + } + + fn is_empty(&self) -> bool { + self.batch_requests.is_empty() + } + + fn len(&self) -> usize { + self.batch_requests.len() + } +} + +#[derive(Debug)] +pub(crate) struct GenerateBatchEntries { + pub(crate) state: BatchEntriesState, +} + +impl GenerateBatchEntries { + pub(crate) fn new(num_entries: usize, queue_len: usize) -> Self { + Self { + state: BatchEntriesState::new(num_entries, queue_len), + } + } } impl BatchEntries for GenerateBatchEntries { - fn add(&mut self, id: u64, mut entry: Entry, adapter: Adapter) -> bool { + fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) -> bool { // return false if the entry.request is not of type ValidGenerateRequest let valid_request = entry .request @@ -146,18 +218,7 @@ impl BatchEntries for GenerateBatchEntries { } let request = valid_request.unwrap(); - - // self.entries.push(entry); - - // Create a new span to link the batch back to this entry - let entry_batch_span = info_span!(parent: &entry.span, "infer"); - // Add relationships - self.next_batch_span.follows_from(&entry_batch_span); - entry_batch_span.follows_from(&self.next_batch_span); - // Update entry - entry.temp_span = Some(entry_batch_span); - - self.batch_requests.push(Request { + let request_proto = Request { id, prefill_logprobs: request.decoder_input_details, inputs: request.inputs.clone(), @@ -166,31 +227,84 @@ impl BatchEntries for GenerateBatchEntries { stopping_parameters: Some(request.stopping_parameters.clone()), adapter_index: adapter.index(), apply_chat_template: request.apply_chat_template, - }); - // Set batch_time - entry.batch_time = Some(Instant::now()); - // Insert in batch_entries IntMap - self.batch_entries.insert(id, entry); - // Map from adapter index back to queue in case we need to add back entries below - // let queue = queue_map.get_mut(&adapter).unwrap(); - self.index_to_adapter.insert(adapter.index(), adapter); + }; + + self.state.add(id, entry, adapter, request_proto); return true; } + + fn drain(&mut self) -> Vec<(Adapter, u64, Entry)> { + self.state.drain() + } + + fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch { + self.state.create_batch_data(batch_id, max_tokens) + } + + fn is_empty(&self) -> bool { + self.state.is_empty() + } + + fn len(&self) -> usize { + self.state.len() + } } #[derive(Debug)] pub(crate) struct EmbedBatchEntries { - pub(crate) entries: Vec, + pub(crate) state: BatchEntriesState, } impl EmbedBatchEntries { pub(crate) fn new() -> Self { - Self { entries: vec![] } + Self { + state: BatchEntriesState::new(0, 0), + } } } impl BatchEntries for EmbedBatchEntries { - fn add(&mut self, id: u64, mut entry: Entry, adapter: Adapter) -> bool { - return false; + fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) -> bool { + // return false if the entry.request is not of type ValidGenerateRequest + let valid_request = entry + .request + .as_ref() + .as_any() + .downcast_ref::(); + + if valid_request.is_none() { + return false; + } + + let request = valid_request.unwrap(); + let request_proto = Request { + id, + prefill_logprobs: false, + inputs: request.inputs.clone(), + truncate: 0, + parameters: None, + stopping_parameters: None, + adapter_index: adapter.index(), + apply_chat_template: false, + }; + + self.state.add(id, entry, adapter, request_proto); + return true; + } + + fn drain(&mut self) -> Vec<(Adapter, u64, Entry)> { + self.state.drain() + } + + fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch { + self.state.create_batch_data(batch_id, max_tokens) + } + + fn is_empty(&self) -> bool { + self.state.is_empty() + } + + fn len(&self) -> usize { + self.state.len() } } diff --git a/router/src/scheduler.rs b/router/src/scheduler.rs index 003f599aa..9f7adade4 100644 --- a/router/src/scheduler.rs +++ b/router/src/scheduler.rs @@ -1,6 +1,6 @@ use crate::{ adapter::Adapter, - batch::{BatchEntries, Entry}, + batch::{self, BatchEntries, Entry}, queue::{AdapterEvent, AdapterQueuesState}, AdapterLoader, }; @@ -103,7 +103,7 @@ impl AdapterScheduler { } } -type NextBatch = (IntMap, Batch, Span); +type NextBatch = (Arc, Batch); /// Background task that manages the queues of the various adapters /// TODO(geoffrey): add tracing (span object) to the various commands @@ -264,7 +264,7 @@ impl AdapterSchedulerState { ); // Pop entries starting from the front of the queue - let batch: Option> = None; + let batch_entries: Option> = None; while let Some((id, mut entry, adapter)) = queues_state.next_entry() { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) @@ -310,37 +310,39 @@ impl AdapterSchedulerState { break; } - if batch.is_none() { - batch = Some( + if batch_entries.is_none() { + batch_entries = Some( entry .request .to_batch(num_entries, queues_state.active_len()), ); } - if !batch.as_ref().unwrap().add(id, entry, adapter) { + if !batch_entries.as_ref().unwrap().add(id, entry, adapter) { // Incompatible entry for this batch. Reinsert and break queues_state.push_front(&adapter, id, entry); break; } } + if batch_entries.is_none() { + return None; + } + + let batch_entries = batch_entries.unwrap(); + // Empty batch - if batch_requests.is_empty() { + if batch_entries.is_empty() { return None; } // Check if our batch is big enough if let Some(min_size) = min_size { // Batch is too small - if batch_requests.len() < min_size { + if batch_entries.len() < min_size { // Add back entries to the queue in the correct order - for r in batch_requests.into_iter().rev() { - let id = r.id; - let entry = batch_entries.remove(&id).unwrap(); - let adapter_index = r.adapter_index; - let adapter = index_to_adapter.get_mut(&adapter_index).unwrap(); - queues_state.push_front(adapter, id, entry); + for (adapter, id, entry) in batch_entries.drain() { + queues_state.push_front(&adapter, id, entry); } return None; @@ -348,7 +350,7 @@ impl AdapterSchedulerState { } // Final batch size - let size = batch_requests.len() as u32; + let size = batch.len() as u32; next_batch_span.record("batch_size", size); let batch = Batch { @@ -360,9 +362,9 @@ impl AdapterSchedulerState { // Increment batch id self.next_batch_id += 1; - metrics::histogram!("lorax_batch_next_size", batch.size as f64); + metrics::histogram!("lorax_batch_next_size", batch.len() as f64); - Some((batch_entries, batch, next_batch_span)) + Some((batch_entries, batch)) } } From fdb3198e88512ccb8ed4451e940a3cbe098dbdee Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 29 May 2024 15:14:24 -0700 Subject: [PATCH 07/31] Fixes --- router/src/batch.rs | 2 +- router/src/scheduler.rs | 24 +++++++++++------------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/router/src/batch.rs b/router/src/batch.rs index 1748828c1..f856c9d66 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -99,7 +99,7 @@ pub(crate) struct Entry { pub batch_time: Option, } -pub(crate) trait BatchEntries { +pub(crate) trait BatchEntries: Sync + Send + Debug { fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) -> bool; fn drain(&mut self) -> Vec<(Adapter, u64, Entry)>; fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch; diff --git a/router/src/scheduler.rs b/router/src/scheduler.rs index 9f7adade4..86c2ebe04 100644 --- a/router/src/scheduler.rs +++ b/router/src/scheduler.rs @@ -264,7 +264,7 @@ impl AdapterSchedulerState { ); // Pop entries starting from the front of the queue - let batch_entries: Option> = None; + let mut batch_entries: Option> = None; while let Some((id, mut entry, adapter)) = queues_state.next_entry() { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) @@ -274,10 +274,15 @@ impl AdapterSchedulerState { } if self.requires_padding { + let mut batch_requests_len = 0; + if batch_entries.is_some() { + batch_requests_len = batch_entries.as_ref().unwrap().len(); + } + // We pad to max input length in the Python shards // We need to take these padding tokens into the equation max_input_length = max_input_length.max(entry.request.input_length()); - prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length + prefill_tokens = (batch_requests_len + 1) as u32 * max_input_length } else { // pad to block size prefill_tokens += ((entry.request.input_length() + self.block_size - 1) @@ -341,7 +346,7 @@ impl AdapterSchedulerState { // Batch is too small if batch_entries.len() < min_size { // Add back entries to the queue in the correct order - for (adapter, id, entry) in batch_entries.drain() { + for (adapter, id, entry) in batch_entries.as_ref().drain() { queues_state.push_front(&adapter, id, entry); } @@ -349,20 +354,13 @@ impl AdapterSchedulerState { } } - // Final batch size - let size = batch.len() as u32; - next_batch_span.record("batch_size", size); + let max_tokens = prefill_tokens + decode_tokens; + let batch = batch_entries.create_batch_data(self.next_batch_id, max_tokens); - let batch = Batch { - id: self.next_batch_id, - requests: batch_requests, - size, - max_tokens: (prefill_tokens + decode_tokens), - }; // Increment batch id self.next_batch_id += 1; - metrics::histogram!("lorax_batch_next_size", batch.len() as f64); + metrics::histogram!("lorax_batch_next_size", batch_entries.len() as f64); Some((batch_entries, batch)) } From 024b90aac2a0cb0a7c58fb765959eabd90d5918a Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 29 May 2024 16:56:29 -0700 Subject: [PATCH 08/31] Refactor --- router/src/batch.rs | 22 +++++++++++++++++++++- router/src/infer.rs | 10 +++------- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/router/src/batch.rs b/router/src/batch.rs index f856c9d66..6f8991f89 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -1,5 +1,9 @@ use core::fmt::Debug; -use std::{any::Any, collections::HashMap, sync::Arc}; +use std::{ + any::Any, + collections::{HashMap, HashSet}, + sync::Arc, +}; use lorax_client::{Batch, NextTokenChooserParameters, Request, StoppingCriteriaParameters}; use nohash_hasher::{BuildNoHashHasher, IntMap}; @@ -103,6 +107,7 @@ pub(crate) trait BatchEntries: Sync + Send + Debug { fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) -> bool; fn drain(&mut self) -> Vec<(Adapter, u64, Entry)>; fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch; + fn adapters_in_use(&self) -> HashSet; fn is_empty(&self) -> bool; fn len(&self) -> usize; } @@ -182,6 +187,13 @@ impl BatchEntriesState { } } + fn adapters_in_use(&self) -> HashSet { + self.batch_entries + .iter() + .map(|(_, entry)| entry.request.adapter()) + .collect::>() + } + fn is_empty(&self) -> bool { self.batch_requests.is_empty() } @@ -241,6 +253,10 @@ impl BatchEntries for GenerateBatchEntries { self.state.create_batch_data(batch_id, max_tokens) } + fn adapters_in_use(&self) -> HashSet { + self.state.adapters_in_use() + } + fn is_empty(&self) -> bool { self.state.is_empty() } @@ -300,6 +316,10 @@ impl BatchEntries for EmbedBatchEntries { self.state.create_batch_data(batch_id, max_tokens) } + fn adapters_in_use(&self) -> HashSet { + self.state.adapters_in_use() + } + fn is_empty(&self) -> bool { self.state.is_empty() } diff --git a/router/src/infer.rs b/router/src/infer.rs index 8818a5de2..bb4b1b9e0 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -341,7 +341,7 @@ async fn batching_task( // Get the next batch from the queue // This batch might be smaller than the maximum batch size if there are not enough requests // waiting in the queue - while let Some((mut entries, batch, span)) = adapter_scheduler + while let Some((mut batch_entries, batch)) = adapter_scheduler .next_batch( HashSet::new(), None, @@ -379,14 +379,10 @@ async fn batching_task( }; let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - - let adapters_in_use = entries - .iter() - .map(|(_, entry)| entry.request.adapter()) - .collect::>(); + let adapters_in_use = batch_entries.adapters_in_use(); // Try to get a new batch - if let Some((mut new_entries, new_batch, span)) = adapter_scheduler + if let Some((mut batch_entries, new_batch)) = adapter_scheduler .next_batch( adapters_in_use, min_size, From 9820cb9c745362ac750b65e82e920baf8bf8e112 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 29 May 2024 22:25:30 -0700 Subject: [PATCH 09/31] Callbacks --- Cargo.lock | 5 ++- router/Cargo.toml | 3 +- router/src/batch.rs | 106 ++++++++++++++++++++++++++++++++++++++------ router/src/infer.rs | 41 ++++++++++++++++- 4 files changed, 137 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c3b488bdb..eb05e5b62 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -147,9 +147,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.71" +version = "0.1.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a564d521dd56509c4c47480d00b80ee55f7e385ae48db5744c67ad50c92d2ebf" +checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", @@ -1271,6 +1271,7 @@ name = "lorax-router" version = "0.1.0" dependencies = [ "async-stream", + "async-trait", "axum", "axum-tracing-opentelemetry", "clap", diff --git a/router/Cargo.toml b/router/Cargo.toml index 4ae55744f..5a925f312 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -46,10 +46,11 @@ utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] } ngrok = { version = "0.12.3", features = ["axum"], optional = true } once_cell = "1.19.0" itertools = "0.12.1" +async-trait = "0.1.80" [build-dependencies] vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] } [features] default = ["ngrok"] -ngrok = ["dep:ngrok"] \ No newline at end of file +ngrok = ["dep:ngrok"] diff --git a/router/src/batch.rs b/router/src/batch.rs index 6f8991f89..3fa492fff 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -2,17 +2,22 @@ use core::fmt::Debug; use std::{ any::Any, collections::{HashMap, HashSet}, - sync::Arc, + sync::{atomic::AtomicBool, Arc}, }; -use lorax_client::{Batch, NextTokenChooserParameters, Request, StoppingCriteriaParameters}; +use async_trait::async_trait; + +use lorax_client::{ + Batch, CachedBatch, NextTokenChooserParameters, Request, ShardedClient, + StoppingCriteriaParameters, +}; use nohash_hasher::{BuildNoHashHasher, IntMap}; use tokio::time::Instant; -use tracing::{info_span, Span}; +use tracing::{info_span, Instrument, Span}; use crate::{ adapter::Adapter, - infer::{InferError, InferStreamResponse}, + infer::{decode, embed, prefill, InferError, InferStreamResponse}, }; pub(crate) trait ValidRequest: Sync + Send + Debug + Any { @@ -103,15 +108,6 @@ pub(crate) struct Entry { pub batch_time: Option, } -pub(crate) trait BatchEntries: Sync + Send + Debug { - fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) -> bool; - fn drain(&mut self) -> Vec<(Adapter, u64, Entry)>; - fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch; - fn adapters_in_use(&self) -> HashSet; - fn is_empty(&self) -> bool; - fn len(&self) -> usize; -} - #[derive(Debug)] pub(crate) struct BatchEntriesState { pub(crate) batch_requests: Vec, @@ -203,6 +199,30 @@ impl BatchEntriesState { } } +#[async_trait] +pub(crate) trait BatchEntries: Sync + Send + Debug { + fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) -> bool; + fn drain(&mut self) -> Vec<(Adapter, u64, Entry)>; + fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch; + fn adapters_in_use(&self) -> HashSet; + fn is_empty(&self) -> bool; + fn len(&self) -> usize; + + async fn process_first( + &mut self, + client: &mut ShardedClient, + batch: Batch, + generation_health: &Arc, + ) -> Option; + + async fn process_next( + &mut self, + client: &mut ShardedClient, + batches: Vec, + generation_health: &Arc, + ) -> Option; +} + #[derive(Debug)] pub(crate) struct GenerateBatchEntries { pub(crate) state: BatchEntriesState, @@ -216,6 +236,7 @@ impl GenerateBatchEntries { } } +#[async_trait] impl BatchEntries for GenerateBatchEntries { fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) -> bool { // return false if the entry.request is not of type ValidGenerateRequest @@ -264,6 +285,38 @@ impl BatchEntries for GenerateBatchEntries { fn len(&self) -> usize { self.state.len() } + + async fn process_first( + &mut self, + client: &mut ShardedClient, + batch: Batch, + generation_health: &Arc, + ) -> Option { + prefill( + &mut client, + batch, + &mut self.state.batch_entries, + &generation_health, + ) + .instrument(self.state.next_batch_span) + .await + } + + async fn process_next( + &mut self, + client: &mut ShardedClient, + batches: Vec, + generation_health: &Arc, + ) -> Option { + decode( + &mut client, + batches, + &mut self.state.batch_entries, + &generation_health, + ) + .instrument(self.state.next_batch_span) + .await + } } #[derive(Debug)] @@ -279,6 +332,7 @@ impl EmbedBatchEntries { } } +#[async_trait] impl BatchEntries for EmbedBatchEntries { fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) -> bool { // return false if the entry.request is not of type ValidGenerateRequest @@ -327,4 +381,30 @@ impl BatchEntries for EmbedBatchEntries { fn len(&self) -> usize { self.state.len() } + + async fn process_first( + &mut self, + client: &mut ShardedClient, + batch: Batch, + generation_health: &Arc, + ) -> Option { + embed( + &mut client, + batch, + &mut self.state.batch_entries, + &generation_health, + ) + .instrument(self.state.next_batch_span) + .await + } + + async fn process_next( + &mut self, + client: &mut ShardedClient, + batches: Vec, + generation_health: &Arc, + ) -> Option { + // TODO(travis): send error (programming eroor) if we get here + None + } } diff --git a/router/src/infer.rs b/router/src/infer.rs index bb4b1b9e0..77c160d51 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -449,7 +449,7 @@ async fn batching_task( } #[instrument(skip_all)] -async fn prefill( +pub(crate) async fn prefill( client: &mut ShardedClient, batch: Batch, entries: &mut IntMap, @@ -486,7 +486,7 @@ async fn prefill( } #[instrument(skip_all)] -async fn decode( +pub(crate) async fn decode( client: &mut ShardedClient, batches: Vec, entries: &mut IntMap, @@ -523,6 +523,43 @@ async fn decode( } } +#[instrument(skip_all)] +pub(crate) async fn embed( + client: &mut ShardedClient, + batch: Batch, + entries: &mut IntMap, + generation_health: &Arc, +) -> Option { + let start_time = Instant::now(); + let batch_id = batch.id; + metrics::increment_counter!("lorax_batch_inference_count", "method" => "prefill"); + + match client.embed(batch).await { + Ok((generations, next_batch)) => { + // Update health + generation_health.store(true, Ordering::SeqCst); + // Send generated tokens and filter stopped entries + filter_send_generations(generations, entries); + + // Filter next batch and remove requests that were stopped + let next_batch = filter_batch(client, next_batch, entries).await; + + metrics::histogram!("lorax_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill"); + metrics::increment_counter!("lorax_batch_inference_success", "method" => "prefill"); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + // Update health + generation_health.store(false, Ordering::SeqCst); + let _ = client.clear_cache(Some(batch_id)).await; + send_errors(err, entries); + metrics::increment_counter!("lorax_batch_inference_failure", "method" => "prefill"); + None + } + } +} + /// Filter a `batch` and remove all requests not present in `entries` #[instrument(skip_all)] async fn filter_batch( From a6f7e3b6a4e24c127f6a94411ff874517506e3ef Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 30 May 2024 08:21:22 -0700 Subject: [PATCH 10/31] Fixed extend --- router/src/batch.rs | 22 ++++++++++++++++++++++ router/src/infer.rs | 20 ++++++++++---------- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/router/src/batch.rs b/router/src/batch.rs index 3fa492fff..5fdb160ac 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -202,11 +202,13 @@ impl BatchEntriesState { #[async_trait] pub(crate) trait BatchEntries: Sync + Send + Debug { fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) -> bool; + fn extend(&mut self, entries: Arc); fn drain(&mut self) -> Vec<(Adapter, u64, Entry)>; fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch; fn adapters_in_use(&self) -> HashSet; fn is_empty(&self) -> bool; fn len(&self) -> usize; + fn state(&self) -> &BatchEntriesState; async fn process_first( &mut self, @@ -266,6 +268,12 @@ impl BatchEntries for GenerateBatchEntries { return true; } + fn extend(&mut self, entries: Arc) { + self.state() + .batch_entries + .extend(entries.state().batch_entries); + } + fn drain(&mut self) -> Vec<(Adapter, u64, Entry)> { self.state.drain() } @@ -286,6 +294,10 @@ impl BatchEntries for GenerateBatchEntries { self.state.len() } + fn state(&self) -> &BatchEntriesState { + &self.state + } + async fn process_first( &mut self, client: &mut ShardedClient, @@ -362,6 +374,12 @@ impl BatchEntries for EmbedBatchEntries { return true; } + fn extend(&mut self, entries: Arc) { + self.state() + .batch_entries + .extend(entries.state().batch_entries); + } + fn drain(&mut self) -> Vec<(Adapter, u64, Entry)> { self.state.drain() } @@ -382,6 +400,10 @@ impl BatchEntries for EmbedBatchEntries { self.state.len() } + fn state(&self) -> &BatchEntriesState { + &self.state + } + async fn process_first( &mut self, client: &mut ShardedClient, diff --git a/router/src/infer.rs b/router/src/infer.rs index 77c160d51..cf2c694e5 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -350,8 +350,8 @@ async fn batching_task( ) .await { - let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health) - .instrument(span) + let mut cached_batch = batch_entries + .process_first(&mut client, batch, &generation_health) .await; let mut waiting_tokens = 1; @@ -382,7 +382,7 @@ async fn batching_task( let adapters_in_use = batch_entries.adapters_in_use(); // Try to get a new batch - if let Some((mut batch_entries, new_batch)) = adapter_scheduler + if let Some((mut new_entries, new_batch)) = adapter_scheduler .next_batch( adapters_in_use, min_size, @@ -410,21 +410,20 @@ async fn batching_task( }); // Generate one token for this new batch to have the attention past in cache - let new_cached_batch = - prefill(&mut client, new_batch, &mut new_entries, &generation_health) - .instrument(span) - .await; + let new_cached_batch = batch_entries + .process_first(&mut client, new_batch, &generation_health) + .await; // Reset waiting counter waiting_tokens = 1; // Extend current batch with the new batch if let Some(new_cached_batch) = new_cached_batch { - entries.extend(new_entries); + batch_entries.extend(new_entries); batches.push(new_cached_batch); } } // Create span for this batch to add context to inference calls - let next_batch_size = entries.len(); + let next_batch_size = batch_entries.len(); let next_batch_span = info_span!(parent: None, "batch", batch_size = next_batch_size); entries.iter_mut().for_each(|(_, entry)| { @@ -437,7 +436,8 @@ async fn batching_task( entry.temp_span = Some(entry_batch_span); }); - cached_batch = decode(&mut client, batches, &mut entries, &generation_health) + cached_batch = batch_entries + .process_next(&mut client, batches, &generation_health) .instrument(next_batch_span) .await; waiting_tokens += 1; From 308bf3ff40eebffcc77dbc5879fc662183a91845 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 30 May 2024 12:23:58 -0700 Subject: [PATCH 11/31] Arc to Box --- router/src/batch.rs | 52 ++++++++++++++++++++++++++++++++++------- router/src/infer.rs | 31 +++++++++--------------- router/src/scheduler.rs | 4 ++-- 3 files changed, 57 insertions(+), 30 deletions(-) diff --git a/router/src/batch.rs b/router/src/batch.rs index 5fdb160ac..21fb116ab 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -24,7 +24,7 @@ pub(crate) trait ValidRequest: Sync + Send + Debug + Any { fn input_length(&self) -> u32; fn max_new_tokens(&self) -> u32; fn adapter(&self) -> Adapter; - fn to_batch(&self, num_entries: usize, queue_len: usize) -> Arc; + fn to_batch(&self, num_entries: usize, queue_len: usize) -> Box; fn as_any(&self) -> &dyn Any; } @@ -41,8 +41,8 @@ impl ValidRequest for ValidGenerateRequest { self.adapter } - fn to_batch(&self, num_entries: usize, queue_len: usize) -> Arc { - Arc::new(GenerateBatchEntries::new(num_entries, queue_len)) + fn to_batch(&self, num_entries: usize, queue_len: usize) -> Box { + Box::new(GenerateBatchEntries::new(num_entries, queue_len)) } fn as_any(&self) -> &dyn Any { @@ -70,8 +70,8 @@ impl ValidRequest for ValidEmbedRequest { self.adapter } - fn to_batch(&self, num_entries: usize, queue_len: usize) -> Arc { - Arc::new(EmbedBatchEntries::new()) + fn to_batch(&self, num_entries: usize, queue_len: usize) -> Box { + Box::new(EmbedBatchEntries::new()) } fn as_any(&self) -> &dyn Any { @@ -202,13 +202,15 @@ impl BatchEntriesState { #[async_trait] pub(crate) trait BatchEntries: Sync + Send + Debug { fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) -> bool; - fn extend(&mut self, entries: Arc); + fn extend(&mut self, entries: Box); fn drain(&mut self) -> Vec<(Adapter, u64, Entry)>; fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch; fn adapters_in_use(&self) -> HashSet; fn is_empty(&self) -> bool; fn len(&self) -> usize; fn state(&self) -> &BatchEntriesState; + fn set_span(&mut self, span: Span); + fn update_entries_span(&mut self, create_span_fn: Box Span>); async fn process_first( &mut self, @@ -268,7 +270,7 @@ impl BatchEntries for GenerateBatchEntries { return true; } - fn extend(&mut self, entries: Arc) { + fn extend(&mut self, entries: Box) { self.state() .batch_entries .extend(entries.state().batch_entries); @@ -298,6 +300,23 @@ impl BatchEntries for GenerateBatchEntries { &self.state } + fn set_span(&mut self, span: Span) { + self.state.next_batch_span = span; + } + + fn update_entries_span(&mut self, create_span_fn: Box Span>) { + let state = self.state(); + for (_, entry) in self.state.batch_entries.iter_mut() { + // Create a new span to link the batch back to this entry + let entry_batch_span = create_span_fn(&entry.span); + // Add relationships + state.next_batch_span.follows_from(&entry_batch_span); + entry_batch_span.follows_from(&state.next_batch_span); + // Update entry + entry.temp_span = Some(entry_batch_span); + } + } + async fn process_first( &mut self, client: &mut ShardedClient, @@ -374,7 +393,7 @@ impl BatchEntries for EmbedBatchEntries { return true; } - fn extend(&mut self, entries: Arc) { + fn extend(&mut self, entries: Box) { self.state() .batch_entries .extend(entries.state().batch_entries); @@ -404,6 +423,23 @@ impl BatchEntries for EmbedBatchEntries { &self.state } + fn set_span(&mut self, span: Span) { + self.state.next_batch_span = span; + } + + fn update_entries_span(&mut self, create_span_fn: Box Span>) { + let state = self.state(); + for (_, entry) in self.state.batch_entries.iter_mut() { + // Create a new span to link the batch back to this entry + let entry_batch_span = create_span_fn(&entry.span); + // Add relationships + state.next_batch_span.follows_from(&entry_batch_span); + entry_batch_span.follows_from(&state.next_batch_span); + // Update entry + entry.temp_span = Some(entry_batch_span); + } + } + async fn process_first( &mut self, client: &mut ShardedClient, diff --git a/router/src/infer.rs b/router/src/infer.rs index cf2c694e5..2abe974ec 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -398,16 +398,7 @@ async fn batching_task( metrics::increment_counter!("lorax_batch_concat", "reason" => "wait_exceeded"); } - entries.iter_mut().for_each(|(_, entry)| { - // Create a new span to add the info that this entry is waiting - // because a new batch is being computed - let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); - // Add relationships - span.follows_from(&entry_waiting_span); - entry_waiting_span.follows_from(&span); - // Update entry - entry.temp_span = Some(entry_waiting_span); - }); + batch_entries.update_entries_span(Box::new(create_waiting_span)); // Generate one token for this new batch to have the attention past in cache let new_cached_batch = batch_entries @@ -426,19 +417,11 @@ async fn batching_task( let next_batch_size = batch_entries.len(); let next_batch_span = info_span!(parent: None, "batch", batch_size = next_batch_size); - entries.iter_mut().for_each(|(_, entry)| { - // Create a new span to link the batch back to this entry - let entry_batch_span = info_span!(parent: &entry.span, "infer"); - // Add relationships - next_batch_span.follows_from(&entry_batch_span); - entry_batch_span.follows_from(&next_batch_span); - // Update entry - entry.temp_span = Some(entry_batch_span); - }); + batch_entries.set_span(next_batch_span); + batch_entries.update_entries_span(Box::new(create_infer_span)); cached_batch = batch_entries .process_next(&mut client, batches, &generation_health) - .instrument(next_batch_span) .await; waiting_tokens += 1; } @@ -448,6 +431,14 @@ async fn batching_task( } } +fn create_waiting_span(parent: &Span) -> Span { + info_span!(parent: parent, "waiting") +} + +fn create_infer_span(parent: &Span) -> Span { + info_span!(parent:parent, "infer") +} + #[instrument(skip_all)] pub(crate) async fn prefill( client: &mut ShardedClient, diff --git a/router/src/scheduler.rs b/router/src/scheduler.rs index 86c2ebe04..f59720052 100644 --- a/router/src/scheduler.rs +++ b/router/src/scheduler.rs @@ -103,7 +103,7 @@ impl AdapterScheduler { } } -type NextBatch = (Arc, Batch); +type NextBatch = (Box, Batch); /// Background task that manages the queues of the various adapters /// TODO(geoffrey): add tracing (span object) to the various commands @@ -264,7 +264,7 @@ impl AdapterSchedulerState { ); // Pop entries starting from the front of the queue - let mut batch_entries: Option> = None; + let mut batch_entries: Option> = None; while let Some((id, mut entry, adapter)) = queues_state.next_entry() { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) From 1cc88e99a369db4f736fec7d5aeb13adefdbdd81 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 30 May 2024 12:42:48 -0700 Subject: [PATCH 12/31] Fixed grpc --- proto/generate.proto | 14 +++++++++----- router/client/src/client.rs | 16 ++++++++-------- router/client/src/sharded_client.rs | 11 ++++++----- router/src/infer.rs | 13 +++++-------- 4 files changed, 28 insertions(+), 26 deletions(-) diff --git a/proto/generate.proto b/proto/generate.proto index b47f2cd29..d95dc8981 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -232,16 +232,20 @@ message DecodeResponse { optional CachedBatch batch = 2; } -message EmbedRequest { - string inputs = 1; -} - message Embedding { repeated float values = 1; } +message EmbedRequest { + /// Batch + Batch batch = 1; +} + message EmbedResponse { - Embedding embeddings = 1; + /// Embeddings + repeated Embedding embeddings = 1; + + /// Error message on failure string errorMsg = 2; } diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 6b7157885..7a5f2f1e6 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -64,14 +64,6 @@ impl Client { Ok(response) } - /// Embed - #[instrument(skip(self))] - pub async fn embed(&mut self, inputs: String) -> Result { - let request = tonic::Request::new(EmbedRequest { inputs }).inject_context(); - let response = self.stub.embed(request).await?.into_inner(); - Ok(response) - } - /// Get model health #[instrument(skip(self))] pub async fn health(&mut self) -> Result { @@ -196,6 +188,14 @@ impl Client { Ok((response.generations, response.batch)) } + /// Embed + #[instrument(skip(self))] + pub async fn embed(&mut self, batch: Batch) -> Result> { + let request = tonic::Request::new(EmbedRequest { batch: Some(batch) }).inject_context(); + let response = self.stub.embed(request).await?.into_inner(); + Ok(response.embeddings) + } + /// Downloads the weights for an adapter. pub async fn download_adapter( &mut self, diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 1a6f885fb..c973f9251 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -1,4 +1,4 @@ -use crate::pb::generate::v1::EmbedResponse; +use crate::pb::generate::v1::{EmbedResponse, Embedding}; /// Multi shard Client use crate::{ AdapterParameters, Batch, CachedBatch, Client, DownloadAdapterResponse, Generation, @@ -154,15 +154,16 @@ impl ShardedClient { merge_generations(results?) } - /// Get the model info + /// Embed the given batch #[instrument(skip(self))] - pub async fn embed(&mut self, inputs: String) -> Result> { + pub async fn embed(&mut self, batch: Batch) -> Result> { let futures: Vec<_> = self .clients .iter_mut() - .map(|client| Box::pin(client.embed(inputs.clone()))) + .map(|client| Box::pin(client.embed(batch.clone()))) .collect(); - join_all(futures).await.into_iter().collect() + let results: Result>> = join_all(futures).await.into_iter().collect(); + Ok(results?.into_iter().flatten().collect()) } pub async fn download_adapter( diff --git a/router/src/infer.rs b/router/src/infer.rs index 2abe974ec..32e0b3a30 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -526,18 +526,15 @@ pub(crate) async fn embed( metrics::increment_counter!("lorax_batch_inference_count", "method" => "prefill"); match client.embed(batch).await { - Ok((generations, next_batch)) => { + Ok((results)) => { // Update health generation_health.store(true, Ordering::SeqCst); // Send generated tokens and filter stopped entries filter_send_generations(generations, entries); - // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries).await; - - metrics::histogram!("lorax_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill"); - metrics::increment_counter!("lorax_batch_inference_success", "method" => "prefill"); - next_batch + metrics::histogram!("lorax_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "embed"); + metrics::increment_counter!("lorax_batch_inference_success", "method" => "embed"); + None } // If we have an error, we discard the whole batch Err(err) => { @@ -545,7 +542,7 @@ pub(crate) async fn embed( generation_health.store(false, Ordering::SeqCst); let _ = client.clear_cache(Some(batch_id)).await; send_errors(err, entries); - metrics::increment_counter!("lorax_batch_inference_failure", "method" => "prefill"); + metrics::increment_counter!("lorax_batch_inference_failure", "method" => "embed"); None } } From a84c450243bd1f00e68b4022ac67d25c7a1acd07 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 30 May 2024 12:58:22 -0700 Subject: [PATCH 13/31] Fixed message forwarding --- proto/generate.proto | 6 +++- router/client/src/lib.rs | 2 +- router/src/infer.rs | 68 +++++++++++++++++++++++++++++++++++++--- 3 files changed, 70 insertions(+), 6 deletions(-) diff --git a/proto/generate.proto b/proto/generate.proto index d95dc8981..ac48cf15b 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -233,7 +233,11 @@ message DecodeResponse { } message Embedding { - repeated float values = 1; + /// Request ID + uint64 request_id = 1; + + /// Embedding values + repeated float values = 2; } message EmbedRequest { diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index 3acdb5f68..0b7242104 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -9,7 +9,7 @@ pub use client::Client; pub use pb::generate::v1::HealthResponse; pub use pb::generate::v1::InfoResponse as ShardInfo; pub use pb::generate::v1::{ - AdapterParameters, AlternativeTokens, Batch, CachedBatch, DownloadAdapterResponse, + AdapterParameters, AlternativeTokens, Batch, CachedBatch, DownloadAdapterResponse, Embedding, FinishReason, GeneratedText, Generation, MajoritySignMethod, MergeStrategy, NextTokenChooserParameters, NextTokens, PrefillTokens, Request, StoppingCriteriaParameters, }; diff --git a/router/src/infer.rs b/router/src/infer.rs index 32e0b3a30..2ed804cf2 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -11,7 +11,8 @@ use futures::future::try_join_all; use futures::stream::StreamExt; use itertools::multizip; use lorax_client::{ - Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient, + Batch, CachedBatch, ClientError, Embedding, GeneratedText, Generation, PrefillTokens, + ShardedClient, }; use nohash_hasher::IntMap; use std::collections::{HashMap, HashSet}; @@ -23,7 +24,7 @@ use std::time::Duration; use thiserror::Error; use tokio::sync::{Mutex, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError}; use tokio::time::Instant; -use tracing::{info_span, instrument, Instrument, Span}; +use tracing::{info_span, instrument, Span}; /// Inference struct #[derive(Clone)] @@ -256,6 +257,10 @@ impl Infer { result_start = Some(start); result_queued = Some(queued) } + InferStreamResponse::Embed { .. } => { + // This should not happen + tracing::error!("Received an Embed message in generate. This is a bug."); + } } } @@ -526,11 +531,37 @@ pub(crate) async fn embed( metrics::increment_counter!("lorax_batch_inference_count", "method" => "prefill"); match client.embed(batch).await { - Ok((results)) => { + Ok(results) => { // Update health generation_health.store(true, Ordering::SeqCst); // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); + results.into_iter().for_each(|embedding| { + let id = embedding.request_id; + // Get entry + // We can `expect` here as the request id should always be in the entries + let entry = entries + .get(&id) + .expect("ID not found in entries. This is a bug."); + + // Send generation responses back to the infer task + // If the receive an error from the Flume channel, it means that the client dropped the + // request and we need to stop generating hence why we unwrap_or(true) + let stopped = send_embeddings(embedding, entry) + .map_err(|err| { + if let SendTimeoutError::Timeout(_) = *err { + tracing::error!("Entry response channel timed out.") + } + + metrics::increment_counter!("lorax_request_failure", "err" => "dropped"); + err + }) + .unwrap_or(true); + if stopped { + entries + .remove(&id) + .expect("ID not found in entries. This is a bug."); + } + }); metrics::histogram!("lorax_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "embed"); metrics::increment_counter!("lorax_batch_inference_success", "method" => "embed"); @@ -706,6 +737,29 @@ fn send_responses( Ok(stopped) } +/// Send responses through the `entry` response channel +fn send_embeddings( + embedding: Embedding, + entry: &Entry, +) -> Result>>> { + // Return directly if the channel is disconnected + if entry.response_tx.is_disconnected() { + return Ok(true); + } + + entry.response_tx.send_timeout( + Ok(InferStreamResponse::Embed { + embedding: embedding.clone(), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + }), + Duration::from_millis(10), + )?; + + // TODO(travis): redundant as we always return true, just make it return nothing + Ok(true) +} + /// Send errors to Infer for all `entries` #[instrument(skip_all)] fn send_errors(error: ClientError, entries: &mut IntMap) { @@ -733,6 +787,12 @@ pub(crate) enum InferStreamResponse { }, // Intermediate messages Token(Token), + // Embeddings + Embed { + embedding: Embedding, + start: Instant, + queued: Instant, + }, // Last message End { token: Token, From 3e9c33596727cf08799bbcdce098c70ee71215c3 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 30 May 2024 13:01:14 -0700 Subject: [PATCH 14/31] Error handling --- router/src/server.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/router/src/server.rs b/router/src/server.rs index 98f13b48b..542b9b592 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -841,6 +841,15 @@ async fn generate_stream_with_callback( yield Ok(callback(stream_token)); break; + }, + InferStreamResponse::Embed { + .. + } => { + let err = InferError::from(ValidationError::EmbeddingModel); + metrics::increment_counter!("lorax_request_failure", "err" => "bad_request"); + tracing::error!("{err}"); + yield Ok(Event::from(err)); + break; } } } From cae56a6d14a55dffa1cbefdcbd09bda4b4f3eadd Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 30 May 2024 21:48:31 -0700 Subject: [PATCH 15/31] Embed infer --- router/src/infer.rs | 98 +++++++++++++++++++++++++++++++++++++++- router/src/validation.rs | 2 +- 2 files changed, 98 insertions(+), 2 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index 2ed804cf2..2431c9ddf 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -1,9 +1,10 @@ /// Batching and inference logic use crate::adapter::{extract_adapter_params, Adapter, BASE_MODEL_ADAPTER_ID}; +use crate::batch::ValidEmbedRequest; use crate::queue::AdapterEvent; use crate::scheduler::AdapterScheduler; use crate::validation::{Validation, ValidationError}; -use crate::{AdapterParameters, AlternativeToken, Entry, Token}; +use crate::{AdapterParameters, AlternativeToken, EmbedRequest, Entry, Token}; use crate::{GenerateRequest, PrefillToken}; use flume::r#async::RecvStream; use flume::SendTimeoutError; @@ -183,6 +184,101 @@ impl Infer { Ok((permit, response_rx.into_stream())) } + #[instrument(skip(self))] + pub(crate) async fn embed( + &self, + request: EmbedRequest, + ) -> Result< + ( + OwnedSemaphorePermit, + RecvStream>, + ), + InferError, + > { + // Limit concurrent requests by acquiring a permit from the semaphore + let permit = self + .clone() + .limit_concurrent_requests + .try_acquire_owned() + .map_err(|err| { + metrics::increment_counter!("lorax_request_failure", "err" => "overloaded"); + tracing::error!("{err}"); + err + })?; + + // TODO(travis): support adapters + // let (adapter_source, adapter_parameters) = extract_adapter_params( + // request.parameters.adapter_id.clone(), + // request.parameters.adapter_source.clone(), + // request.parameters.adapter_parameters.clone(), + // ); + + // let adapter_idx; + // { + // // TODO(travis): can optimize concurrency here using RWLock + // let mut adapter_to_index = self.adapter_to_index.lock().await; + // let adapter_key = adapter_parameters.clone(); + // if adapter_to_index.contains_key(&adapter_key) { + // adapter_idx = *adapter_to_index.get(&adapter_key).unwrap(); + // } else { + // adapter_idx = adapter_to_index.len() as u32; + // adapter_to_index.insert(adapter_key, adapter_idx); + // } + // } + + let adapter = Adapter::new( + AdapterParameters { + adapter_ids: vec!["".to_string()], + ..Default::default() + }, + "hf".to_string(), + 0, + None, + ); + + // TODO(travis): robust validation + // Validate request + // let valid_request = self + // .validation + // .validate(request, adapter.clone()) + // .await + // .map_err(|err| { + // metrics::increment_counter!("lorax_request_failure", "err" => "validation"); + // tracing::error!("{err}"); + // err + // })?; + + let (inputs, input_length) = self + .validation + .validate_input(request.inputs, None, Some(1)) + .await?; + + let valid_request = ValidEmbedRequest { + inputs: inputs, + input_length: input_length as u32, + adapter: adapter.clone(), + }; + + // MPSC channel to communicate with the background batching task + let (response_tx, response_rx) = flume::unbounded(); + + // Process the request by sending it to the queue associated with `adapter` + self.adapter_scheduler.process( + adapter.clone(), + Entry { + request: Arc::new(valid_request), + response_tx, + span: Span::current(), + temp_span: None, + queue_time: Instant::now(), + batch_time: None, + }, + ); + + // Return stream + Ok((permit, response_rx.into_stream())) + } + /// Tokenizer the input #[instrument(skip_all)] pub(crate) async fn tokenize( diff --git a/router/src/validation.rs b/router/src/validation.rs index 5a8f24850..c1596ca19 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -87,7 +87,7 @@ impl Validation { } #[instrument(skip(self, inputs))] - async fn validate_input( + pub(crate) async fn validate_input( &self, inputs: String, truncate: Option, From cb3ddfaf2d7fc4c48f685552b240c02bec631c9a Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 30 May 2024 22:18:21 -0700 Subject: [PATCH 16/31] Fix embed --- router/src/infer.rs | 209 ++++++++++++++++++++++++------------------- router/src/server.rs | 143 +++++++++++++++++++++++++++++ 2 files changed, 262 insertions(+), 90 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index 2431c9ddf..c802c329c 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -4,7 +4,7 @@ use crate::batch::ValidEmbedRequest; use crate::queue::AdapterEvent; use crate::scheduler::AdapterScheduler; use crate::validation::{Validation, ValidationError}; -use crate::{AdapterParameters, AlternativeToken, EmbedRequest, Entry, Token}; +use crate::{AdapterParameters, AlternativeToken, EmbedRequest, EmbedResponse, Entry, Token}; use crate::{GenerateRequest, PrefillToken}; use flume::r#async::RecvStream; use flume::SendTimeoutError; @@ -184,17 +184,109 @@ impl Infer { Ok((permit, response_rx.into_stream())) } + /// Tokenizer the input + #[instrument(skip_all)] + pub(crate) async fn tokenize( + &self, + request: GenerateRequest, + ) -> Result, InferError> { + // Tokenize request + let inputs = request.inputs; + let truncate = request.parameters.truncate; + let encoding = self + .validation + .tokenize(inputs, truncate) + .await + .map_err(|err| { + tracing::error!("Error occurred during tokenization. {err}"); + err + })?; + + // Return Encoding + Ok(encoding.map(|(encoding, _)| encoding)) + } + /// Add a new request to the queue and return a InferResponse #[instrument(skip(self))] - pub(crate) async fn embed( + pub(crate) async fn generate( &self, - request: EmbedRequest, - ) -> Result< - ( - OwnedSemaphorePermit, - RecvStream>, - ), - InferError, - > { + request: GenerateRequest, + ) -> Result { + // Create stream and keep semaphore permit as long as generate lives + let (_permit, mut stream) = self.generate_stream(request).await?; + + // Return values + let mut result_prefill = Vec::new(); + let mut result_tokens = Vec::new(); + let mut result_prefill_length = 0; + let mut result_generated_text = None; + let mut result_start = None; + let mut result_queued = None; + + // Iterate on stream + while let Some(response) = stream.next().await { + match response? { + // Add prefill tokens + InferStreamResponse::Prefill { + tokens, + tokens_length, + } => { + // Create Token objects + // We do that here instead of in the Python code as Rust for loops are faster + if let Some(tokens_val) = tokens { + result_prefill = tokens_val + .ids + .into_iter() + .zip(tokens_val.logprobs.into_iter()) + .zip(tokens_val.texts.into_iter()) + .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) + .collect(); + } + result_prefill_length = tokens_length; + } + // Push last token + InferStreamResponse::Token(token) => result_tokens.push(token), + // Final message + // Set return values + InferStreamResponse::End { + token, + generated_text, + start, + queued, + } => { + result_tokens.push(token); + result_generated_text = Some(generated_text); + result_start = Some(start); + result_queued = Some(queued) + } + InferStreamResponse::Embed { .. } => { + // This should not happen + tracing::error!("Received an Embed message in generate. This is a bug."); + } + } + } + + // Check that we received a `InferStreamResponse::End` message + if let (Some(generated_text), Some(queued), Some(start)) = + (result_generated_text, result_queued, result_start) + { + Ok(InferResponse { + prefill: result_prefill, + tokens: result_tokens, + prompt_tokens: result_prefill_length, + generated_text, + queued, + start, + }) + } else { + let err = InferError::IncompleteGeneration; + metrics::increment_counter!("lorax_request_failure", "err" => "incomplete"); + tracing::error!("{err}"); + Err(err) + } + } + + #[instrument(skip(self))] + pub(crate) async fn embed(&self, request: EmbedRequest) -> Result { // Limit concurrent requests by acquiring a permit from the semaphore let permit = self .clone() @@ -275,102 +367,38 @@ impl Infer { }, ); - // Return stream - Ok((permit, response_rx.into_stream())) - } - - /// Tokenizer the input - #[instrument(skip_all)] - pub(crate) async fn tokenize( - &self, - request: GenerateRequest, - ) -> Result, InferError> { - // Tokenize request - let inputs = request.inputs; - let truncate = request.parameters.truncate; - let encoding = self - .validation - .tokenize(inputs, truncate) - .await - .map_err(|err| { - tracing::error!("Error occurred during tokenization. {err}"); - err - })?; - - // Return Encoding - Ok(encoding.map(|(encoding, _)| encoding)) - } - /// Add a new request to the queue and return a InferResponse - #[instrument(skip(self))] - pub(crate) async fn generate( - &self, - request: GenerateRequest, - ) -> Result { - // Create stream and keep semaphore permit as long as generate lives - let (_permit, mut stream) = self.generate_stream(request).await?; - // Return values - let mut result_prefill = Vec::new(); - let mut result_tokens = Vec::new(); - let mut result_prefill_length = 0; - let mut result_generated_text = None; - let mut result_start = None; - let mut result_queued = None; + let mut return_embeddings = None; - // Iterate on stream + let mut stream = response_rx.into_stream(); while let Some(response) = stream.next().await { match response? { // Add prefill tokens - InferStreamResponse::Prefill { - tokens, - tokens_length, - } => { - // Create Token objects - // We do that here instead of in the Python code as Rust for loops are faster - if let Some(tokens_val) = tokens { - result_prefill = tokens_val - .ids - .into_iter() - .zip(tokens_val.logprobs.into_iter()) - .zip(tokens_val.texts.into_iter()) - .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) - .collect(); - } - result_prefill_length = tokens_length; + InferStreamResponse::Prefill { .. } => { + tracing::error!("Received a Prefill message in embed. This is a bug."); } // Push last token - InferStreamResponse::Token(token) => result_tokens.push(token), + InferStreamResponse::Token(..) => { + tracing::error!("Received a Token message in embed. This is a bug."); + } // Final message // Set return values - InferStreamResponse::End { - token, - generated_text, + InferStreamResponse::End { .. } => { + tracing::error!("Received an End message in embed. This is a bug."); + } + InferStreamResponse::Embed { + embedding, start, queued, } => { - result_tokens.push(token); - result_generated_text = Some(generated_text); - result_start = Some(start); - result_queued = Some(queued) - } - InferStreamResponse::Embed { .. } => { - // This should not happen - tracing::error!("Received an Embed message in generate. This is a bug."); + return_embeddings = Some(embedding.values); } } } - // Check that we received a `InferStreamResponse::End` message - if let (Some(generated_text), Some(queued), Some(start)) = - (result_generated_text, result_queued, result_start) - { - Ok(InferResponse { - prefill: result_prefill, - tokens: result_tokens, - prompt_tokens: result_prefill_length, - generated_text, - queued, - start, + if let Some(return_embeddings) = return_embeddings { + Ok(EmbedResponse { + embeddings: return_embeddings, }) } else { let err = InferError::IncompleteGeneration; @@ -379,6 +407,7 @@ impl Infer { Err(err) } } + /// Add best_of new requests to the queue and return a InferResponse of the sequence with /// the highest log probability per token #[instrument(skip(self))] diff --git a/router/src/server.rs b/router/src/server.rs index 542b9b592..c3565d3d2 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1342,10 +1342,153 @@ impl From for Event { )] #[instrument(skip_all)] async fn embed( + infer: Extension, mut client: Extension, Json(req): Json, ) -> Result, (StatusCode, Json)> { + let span = tracing::Span::current(); + let start_time = Instant::now(); + metrics::increment_counter!("lorax_request_count"); + + tracing::debug!("Input: {}", req.inputs); + let input = req.inputs.clone(); + + let stream = async_stream::stream! { + // Inference + let mut end_reached = false; + let mut error = false; + + match infer.embed(req).instrument(info_span!(parent: &span, "async_stream")).await { + // Keep permit as long as generate_stream lives + Ok((_permit, mut response_stream)) => { + // Server-Sent Event stream + while let Some(response) = response_stream.next().await { + match response { + Ok(response) => { + match response { + // Prefill is ignored + InferStreamResponse::Prefill { + .. + } => { + let err = InferError::from(ValidationError::EmbeddingModel); + metrics::increment_counter!("lorax_request_failure", "err" => "bad_request"); + tracing::error!("{err}"); + yield Ok(Event::from(err)); + break; + } + // Yield event for every new token + InferStreamResponse::Token(token) => { + let err = InferError::from(ValidationError::EmbeddingModel); + metrics::increment_counter!("lorax_request_failure", "err" => "bad_request"); + tracing::error!("{err}"); + yield Ok(Event::from(err)); + break; + } + // Yield event for last token and compute timings + InferStreamResponse::End { + .. + } => { + let err = InferError::from(ValidationError::EmbeddingModel); + metrics::increment_counter!("lorax_request_failure", "err" => "bad_request"); + tracing::error!("{err}"); + yield Ok(Event::from(err)); + break; + }, + InferStreamResponse::Embed { + embedding, + start, + queued, + } => { + // Token details + let details = match details { + true => Some(StreamDetails { + finish_reason: FinishReason::from(generated_text.finish_reason), + prompt_tokens: prefill_tokens_length, + generated_tokens: generated_text.generated_tokens, + seed: generated_text.seed, + }), + false => None, + }; + + // Timings + let total_time = start_time.elapsed(); + let validation_time = queued - start_time; + let queue_time = start - queued; + let inference_time = Instant::now() - start; + let time_per_token = inference_time / generated_text.generated_tokens; + + // Tracing metadata + span.record("total_time", format!("{total_time:?}")); + span.record("validation_time", format!("{validation_time:?}")); + span.record("queue_time", format!("{queue_time:?}")); + span.record("inference_time", format!("{inference_time:?}")); + span.record("time_per_token", format!("{time_per_token:?}")); + span.record("seed", format!("{:?}", generated_text.seed)); + + // Metrics + metrics::increment_counter!("lorax_request_success"); + metrics::histogram!("lorax_request_duration", total_time.as_secs_f64()); + metrics::histogram!("lorax_request_validation_duration", validation_time.as_secs_f64()); + metrics::histogram!("lorax_request_queue_duration", queue_time.as_secs_f64()); + metrics::histogram!("lorax_request_inference_duration", inference_time.as_secs_f64()); + metrics::histogram!("lorax_request_mean_time_per_token_duration", time_per_token.as_secs_f64()); + metrics::histogram!("lorax_request_generated_tokens", generated_text.generated_tokens as f64); + + + + // StreamResponse + end_reached = true; + + let mut output_text = generated_text.text; + if let Some(prompt) = add_prompt { + output_text = prompt + &output_text; + } + + tracing::debug!(parent: &span, "Output: {}", output_text); + tracing::info!(parent: &span, "Success"); + + let total_tokens = generated_text.generated_tokens + prefill_tokens_length; + if info.request_logger_url.is_some() { + let _ = request_logger_sender.send((total_tokens as i64, api_token.unwrap_or("".to_string()), info.model_id.clone())).await; + } + + let stream_token = StreamResponse { + token, + generated_text: Some(output_text), + details + }; + + yield Ok(callback(stream_token)); + break; + } + } + } + // yield error + Err(err) => { + error = true; + yield Ok(Event::from(err)); + break; + } + } + } + }, + // yield error + Err(err) => { + error = true; + yield Ok(Event::from(err)); + } + } + // Check if generation reached the end + // Skip if we already sent an error + if !end_reached && !error { + let err = InferError::IncompleteGeneration; + metrics::increment_counter!("lorax_request_failure", "err" => "incomplete"); + tracing::error!("{err}"); + yield Ok(Event::from(err)); + } + }; + let embeddings = client.embed(input).await.unwrap(); let embeddings = embeddings.get(0); From f99cfa14cbd36d9c4c33e8f1c205057a2f024386 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 30 May 2024 22:20:48 -0700 Subject: [PATCH 17/31] Fix server endpoint --- router/src/server.rs | 160 +------------------------------------------ 1 file changed, 3 insertions(+), 157 deletions(-) diff --git a/router/src/server.rs b/router/src/server.rs index c3565d3d2..4a4eb9a12 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1352,164 +1352,10 @@ async fn embed( tracing::debug!("Input: {}", req.inputs); - let input = req.inputs.clone(); - - let stream = async_stream::stream! { - // Inference - let mut end_reached = false; - let mut error = false; - - match infer.embed(req).instrument(info_span!(parent: &span, "async_stream")).await { - // Keep permit as long as generate_stream lives - Ok((_permit, mut response_stream)) => { - // Server-Sent Event stream - while let Some(response) = response_stream.next().await { - match response { - Ok(response) => { - match response { - // Prefill is ignored - InferStreamResponse::Prefill { - .. - } => { - let err = InferError::from(ValidationError::EmbeddingModel); - metrics::increment_counter!("lorax_request_failure", "err" => "bad_request"); - tracing::error!("{err}"); - yield Ok(Event::from(err)); - break; - } - // Yield event for every new token - InferStreamResponse::Token(token) => { - let err = InferError::from(ValidationError::EmbeddingModel); - metrics::increment_counter!("lorax_request_failure", "err" => "bad_request"); - tracing::error!("{err}"); - yield Ok(Event::from(err)); - break; - } - // Yield event for last token and compute timings - InferStreamResponse::End { - .. - } => { - let err = InferError::from(ValidationError::EmbeddingModel); - metrics::increment_counter!("lorax_request_failure", "err" => "bad_request"); - tracing::error!("{err}"); - yield Ok(Event::from(err)); - break; - }, - InferStreamResponse::Embed { - embedding, - start, - queued, - } => { - // Token details - let details = match details { - true => Some(StreamDetails { - finish_reason: FinishReason::from(generated_text.finish_reason), - prompt_tokens: prefill_tokens_length, - generated_tokens: generated_text.generated_tokens, - seed: generated_text.seed, - }), - false => None, - }; - - // Timings - let total_time = start_time.elapsed(); - let validation_time = queued - start_time; - let queue_time = start - queued; - let inference_time = Instant::now() - start; - let time_per_token = inference_time / generated_text.generated_tokens; - - // Tracing metadata - span.record("total_time", format!("{total_time:?}")); - span.record("validation_time", format!("{validation_time:?}")); - span.record("queue_time", format!("{queue_time:?}")); - span.record("inference_time", format!("{inference_time:?}")); - span.record("time_per_token", format!("{time_per_token:?}")); - span.record("seed", format!("{:?}", generated_text.seed)); - - // Metrics - metrics::increment_counter!("lorax_request_success"); - metrics::histogram!("lorax_request_duration", total_time.as_secs_f64()); - metrics::histogram!("lorax_request_validation_duration", validation_time.as_secs_f64()); - metrics::histogram!("lorax_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!("lorax_request_inference_duration", inference_time.as_secs_f64()); - metrics::histogram!("lorax_request_mean_time_per_token_duration", time_per_token.as_secs_f64()); - metrics::histogram!("lorax_request_generated_tokens", generated_text.generated_tokens as f64); - - - - // StreamResponse - end_reached = true; - - let mut output_text = generated_text.text; - if let Some(prompt) = add_prompt { - output_text = prompt + &output_text; - } - - tracing::debug!(parent: &span, "Output: {}", output_text); - tracing::info!(parent: &span, "Success"); - - let total_tokens = generated_text.generated_tokens + prefill_tokens_length; - if info.request_logger_url.is_some() { - let _ = request_logger_sender.send((total_tokens as i64, api_token.unwrap_or("".to_string()), info.model_id.clone())).await; - } - - let stream_token = StreamResponse { - token, - generated_text: Some(output_text), - details - }; - - yield Ok(callback(stream_token)); - break; - } - } - } - // yield error - Err(err) => { - error = true; - yield Ok(Event::from(err)); - break; - } - } - } - }, - // yield error - Err(err) => { - error = true; - yield Ok(Event::from(err)); - } - } - // Check if generation reached the end - // Skip if we already sent an error - if !end_reached && !error { - let err = InferError::IncompleteGeneration; - metrics::increment_counter!("lorax_request_failure", "err" => "incomplete"); - tracing::error!("{err}"); - yield Ok(Event::from(err)); - } - }; - - let embeddings = client.embed(input).await.unwrap(); - let embeddings = embeddings.get(0); - - // TODO: better error enums - if (!embeddings.unwrap().error_msg.is_empty()) { - return Err(( - StatusCode::BAD_REQUEST, - Json(ErrorResponse { - error: embeddings.unwrap().error_msg.clone(), - error_type: "model doesn't support embeddings".to_string(), - }), - )); - } + // Inference + let response = infer.embed(req).await?; - let values = embeddings - .map(|emb| emb.embeddings.as_ref().map(|emb| emb.values.clone())) - .flatten() - .unwrap_or_default(); - Ok(Json(EmbedResponse { - embeddings: values.clone(), - })) + Ok(Json(response)) } /// Tokenize inputs From a6dec7240453368f0f0c8aa339fbf5959bb4b155 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 31 May 2024 09:22:48 -0700 Subject: [PATCH 18/31] Clone --- router/src/batch.rs | 18 +++++++++--------- router/src/scheduler.rs | 4 ++-- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/router/src/batch.rs b/router/src/batch.rs index 21fb116ab..263712534 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -38,7 +38,7 @@ impl ValidRequest for ValidGenerateRequest { } fn adapter(&self) -> Adapter { - self.adapter + self.adapter.clone() } fn to_batch(&self, num_entries: usize, queue_len: usize) -> Box { @@ -67,11 +67,11 @@ impl ValidRequest for ValidEmbedRequest { } fn adapter(&self) -> Adapter { - self.adapter + self.adapter.clone() } fn to_batch(&self, num_entries: usize, queue_len: usize) -> Box { - Box::new(EmbedBatchEntries::new()) + Box::new(EmbedBatchEntries::new(num_entries, queue_len)) } fn as_any(&self) -> &dyn Any { @@ -122,11 +122,11 @@ impl BatchEntriesState { let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); next_batch_span.follows_from(&Span::current()); - let mut batch_requests = Vec::with_capacity(num_entries); - let mut batch_entries = + let batch_requests = Vec::with_capacity(num_entries); + let batch_entries = IntMap::with_capacity_and_hasher(num_entries, BuildNoHashHasher::default()); - let mut index_to_adapter = HashMap::with_capacity(queue_len); + let index_to_adapter = HashMap::with_capacity(queue_len); Self { batch_requests, @@ -356,9 +356,9 @@ pub(crate) struct EmbedBatchEntries { } impl EmbedBatchEntries { - pub(crate) fn new() -> Self { + pub(crate) fn new(num_entries: usize, queue_len: usize) -> Self { Self { - state: BatchEntriesState::new(0, 0), + state: BatchEntriesState::new(num_entries, queue_len), } } } @@ -447,7 +447,7 @@ impl BatchEntries for EmbedBatchEntries { generation_health: &Arc, ) -> Option { embed( - &mut client, + client, batch, &mut self.state.batch_entries, &generation_health, diff --git a/router/src/scheduler.rs b/router/src/scheduler.rs index f59720052..2e1094288 100644 --- a/router/src/scheduler.rs +++ b/router/src/scheduler.rs @@ -265,7 +265,7 @@ impl AdapterSchedulerState { // Pop entries starting from the front of the queue let mut batch_entries: Option> = None; - while let Some((id, mut entry, adapter)) = queues_state.next_entry() { + while let Some((id, entry, adapter)) = queues_state.next_entry() { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_disconnected() { @@ -323,7 +323,7 @@ impl AdapterSchedulerState { ); } - if !batch_entries.as_ref().unwrap().add(id, entry, adapter) { + if !batch_entries.unwrap().add(id, entry, adapter) { // Incompatible entry for this batch. Reinsert and break queues_state.push_front(&adapter, id, entry); break; From 511721858f315f42fe14ecf10b1aa7ba045ef2a1 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 31 May 2024 10:02:14 -0700 Subject: [PATCH 19/31] Fix scheduler --- router/src/batch.rs | 37 +++++++++++++++++++++++++------------ router/src/scheduler.rs | 12 +++++++----- 2 files changed, 32 insertions(+), 17 deletions(-) diff --git a/router/src/batch.rs b/router/src/batch.rs index 263712534..546f8098e 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -201,7 +201,8 @@ impl BatchEntriesState { #[async_trait] pub(crate) trait BatchEntries: Sync + Send + Debug { - fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) -> bool; + fn can_add(&self, entry: &Entry) -> bool; + fn add(&mut self, id: u64, entry: Entry, adapter: Adapter); fn extend(&mut self, entries: Box); fn drain(&mut self) -> Vec<(Adapter, u64, Entry)>; fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch; @@ -242,7 +243,7 @@ impl GenerateBatchEntries { #[async_trait] impl BatchEntries for GenerateBatchEntries { - fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) -> bool { + fn can_add(&self, entry: &Entry) -> bool { // return false if the entry.request is not of type ValidGenerateRequest let valid_request = entry .request @@ -250,9 +251,16 @@ impl BatchEntries for GenerateBatchEntries { .as_any() .downcast_ref::(); - if valid_request.is_none() { - return false; - } + let result = valid_request.is_some(); + result + } + + fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) { + let valid_request = entry + .request + .as_ref() + .as_any() + .downcast_ref::(); let request = valid_request.unwrap(); let request_proto = Request { @@ -267,7 +275,6 @@ impl BatchEntries for GenerateBatchEntries { }; self.state.add(id, entry, adapter, request_proto); - return true; } fn extend(&mut self, entries: Box) { @@ -365,17 +372,24 @@ impl EmbedBatchEntries { #[async_trait] impl BatchEntries for EmbedBatchEntries { - fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) -> bool { - // return false if the entry.request is not of type ValidGenerateRequest + fn can_add(&self, entry: &Entry) -> bool { + // return false if the entry.request is not of type ValidEmbedRequest let valid_request = entry .request .as_ref() .as_any() .downcast_ref::(); - if valid_request.is_none() { - return false; - } + let result = valid_request.is_some(); + result + } + + fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) { + let valid_request = entry + .request + .as_ref() + .as_any() + .downcast_ref::(); let request = valid_request.unwrap(); let request_proto = Request { @@ -390,7 +404,6 @@ impl BatchEntries for EmbedBatchEntries { }; self.state.add(id, entry, adapter, request_proto); - return true; } fn extend(&mut self, entries: Box) { diff --git a/router/src/scheduler.rs b/router/src/scheduler.rs index 2e1094288..579f14d96 100644 --- a/router/src/scheduler.rs +++ b/router/src/scheduler.rs @@ -275,8 +275,8 @@ impl AdapterSchedulerState { if self.requires_padding { let mut batch_requests_len = 0; - if batch_entries.is_some() { - batch_requests_len = batch_entries.as_ref().unwrap().len(); + if let Some(batch_entries) = batch_entries.as_ref() { + batch_requests_len = batch_entries.len(); } // We pad to max input length in the Python shards @@ -323,18 +323,20 @@ impl AdapterSchedulerState { ); } - if !batch_entries.unwrap().add(id, entry, adapter) { + if !batch_entries.as_ref().unwrap().can_add(&entry) { // Incompatible entry for this batch. Reinsert and break queues_state.push_front(&adapter, id, entry); break; } + + batch_entries.as_mut().unwrap().add(id, entry, adapter) } if batch_entries.is_none() { return None; } - let batch_entries = batch_entries.unwrap(); + let mut batch_entries = batch_entries.unwrap(); // Empty batch if batch_entries.is_empty() { @@ -346,7 +348,7 @@ impl AdapterSchedulerState { // Batch is too small if batch_entries.len() < min_size { // Add back entries to the queue in the correct order - for (adapter, id, entry) in batch_entries.as_ref().drain() { + for (adapter, id, entry) in batch_entries.drain() { queues_state.push_front(&adapter, id, entry); } From abe763c8c5c472b9a0f1c41a312cc1030ec847b8 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 31 May 2024 10:16:47 -0700 Subject: [PATCH 20/31] Cleanup --- router/src/batch.rs | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/router/src/batch.rs b/router/src/batch.rs index 546f8098e..fb0dee5b1 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -159,7 +159,8 @@ impl BatchEntriesState { fn drain(&mut self) -> Vec<(Adapter, u64, Entry)> { let mut entries = Vec::with_capacity(self.batch_requests.len()); - for r in self.batch_requests.into_iter().rev() { + // TODO(travis): clone is not ideal, find a way to do this cleanly in place + for r in self.batch_requests.clone().into_iter().rev() { let id = r.id; let entry = self.batch_entries.remove(&id).unwrap(); let adapter_index = r.adapter_index; @@ -175,9 +176,10 @@ impl BatchEntriesState { let size = self.len() as u32; self.next_batch_span.record("batch_size", size); + // TODO(travis): clone is not ideal, find a way to do this cleanly in place Batch { id: batch_id, - requests: self.batch_requests, + requests: self.batch_requests.clone(), size, max_tokens, } @@ -312,13 +314,13 @@ impl BatchEntries for GenerateBatchEntries { } fn update_entries_span(&mut self, create_span_fn: Box Span>) { - let state = self.state(); + let next_batch_span = &self.state().next_batch_span; for (_, entry) in self.state.batch_entries.iter_mut() { // Create a new span to link the batch back to this entry let entry_batch_span = create_span_fn(&entry.span); // Add relationships - state.next_batch_span.follows_from(&entry_batch_span); - entry_batch_span.follows_from(&state.next_batch_span); + next_batch_span.follows_from(&entry_batch_span); + entry_batch_span.follows_from(next_batch_span); // Update entry entry.temp_span = Some(entry_batch_span); } @@ -331,7 +333,7 @@ impl BatchEntries for GenerateBatchEntries { generation_health: &Arc, ) -> Option { prefill( - &mut client, + client, batch, &mut self.state.batch_entries, &generation_health, @@ -347,7 +349,7 @@ impl BatchEntries for GenerateBatchEntries { generation_health: &Arc, ) -> Option { decode( - &mut client, + client, batches, &mut self.state.batch_entries, &generation_health, From 06ed8a41d3b7788d0974516b4974d5b3a0227aa5 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 31 May 2024 16:54:14 -0700 Subject: [PATCH 21/31] Fix take --- router/src/batch.rs | 31 +++++++++++++++++++------------ router/src/infer.rs | 2 +- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/router/src/batch.rs b/router/src/batch.rs index fb0dee5b1..d8cedbbb3 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -212,6 +212,7 @@ pub(crate) trait BatchEntries: Sync + Send + Debug { fn is_empty(&self) -> bool; fn len(&self) -> usize; fn state(&self) -> &BatchEntriesState; + fn mut_state(&mut self) -> &mut BatchEntriesState; fn set_span(&mut self, span: Span); fn update_entries_span(&mut self, create_span_fn: Box Span>); @@ -279,10 +280,9 @@ impl BatchEntries for GenerateBatchEntries { self.state.add(id, entry, adapter, request_proto); } - fn extend(&mut self, entries: Box) { - self.state() - .batch_entries - .extend(entries.state().batch_entries); + fn extend(&mut self, mut entries: Box) { + let new_batch_entries = std::mem::take(&mut entries.mut_state().batch_entries); + self.state.batch_entries.extend(new_batch_entries); } fn drain(&mut self) -> Vec<(Adapter, u64, Entry)> { @@ -309,12 +309,16 @@ impl BatchEntries for GenerateBatchEntries { &self.state } + fn mut_state(&mut self) -> &mut BatchEntriesState { + &mut self.state + } + fn set_span(&mut self, span: Span) { self.state.next_batch_span = span; } fn update_entries_span(&mut self, create_span_fn: Box Span>) { - let next_batch_span = &self.state().next_batch_span; + let next_batch_span = &self.state.next_batch_span; for (_, entry) in self.state.batch_entries.iter_mut() { // Create a new span to link the batch back to this entry let entry_batch_span = create_span_fn(&entry.span); @@ -408,10 +412,9 @@ impl BatchEntries for EmbedBatchEntries { self.state.add(id, entry, adapter, request_proto); } - fn extend(&mut self, entries: Box) { - self.state() - .batch_entries - .extend(entries.state().batch_entries); + fn extend(&mut self, mut entries: Box) { + let new_batch_entries = std::mem::take(&mut entries.mut_state().batch_entries); + self.state.batch_entries.extend(new_batch_entries); } fn drain(&mut self) -> Vec<(Adapter, u64, Entry)> { @@ -438,18 +441,22 @@ impl BatchEntries for EmbedBatchEntries { &self.state } + fn mut_state(&mut self) -> &mut BatchEntriesState { + &mut self.state + } + fn set_span(&mut self, span: Span) { self.state.next_batch_span = span; } fn update_entries_span(&mut self, create_span_fn: Box Span>) { - let state = self.state(); + let next_batch_span = &self.state.next_batch_span; for (_, entry) in self.state.batch_entries.iter_mut() { // Create a new span to link the batch back to this entry let entry_batch_span = create_span_fn(&entry.span); // Add relationships - state.next_batch_span.follows_from(&entry_batch_span); - entry_batch_span.follows_from(&state.next_batch_span); + next_batch_span.follows_from(&entry_batch_span); + entry_batch_span.follows_from(next_batch_span); // Update entry entry.temp_span = Some(entry_batch_span); } diff --git a/router/src/infer.rs b/router/src/infer.rs index c802c329c..4a24b3f41 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -512,7 +512,7 @@ async fn batching_task( let adapters_in_use = batch_entries.adapters_in_use(); // Try to get a new batch - if let Some((mut new_entries, new_batch)) = adapter_scheduler + if let Some((new_entries, new_batch)) = adapter_scheduler .next_batch( adapters_in_use, min_size, From e46f2dcf220bc0c333ead5f56da23d00d95e6baa Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 31 May 2024 21:54:55 -0700 Subject: [PATCH 22/31] Fixed --- router/src/batch.rs | 65 +++++++---------------------------------- router/src/infer.rs | 49 +++++++++++++++++++++---------- router/src/scheduler.rs | 19 ++++++++++-- 3 files changed, 59 insertions(+), 74 deletions(-) diff --git a/router/src/batch.rs b/router/src/batch.rs index d8cedbbb3..f47e65b78 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -13,7 +13,7 @@ use lorax_client::{ }; use nohash_hasher::{BuildNoHashHasher, IntMap}; use tokio::time::Instant; -use tracing::{info_span, Instrument, Span}; +use tracing::{info_span, span, Instrument, Span}; use crate::{ adapter::Adapter, @@ -113,15 +113,10 @@ pub(crate) struct BatchEntriesState { pub(crate) batch_requests: Vec, pub(crate) batch_entries: HashMap>, pub(crate) index_to_adapter: HashMap, - next_batch_span: Span, } impl BatchEntriesState { pub(crate) fn new(num_entries: usize, queue_len: usize) -> Self { - // Create span for this batch to add context to inference calls - let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); - next_batch_span.follows_from(&Span::current()); - let batch_requests = Vec::with_capacity(num_entries); let batch_entries = IntMap::with_capacity_and_hasher(num_entries, BuildNoHashHasher::default()); @@ -132,19 +127,10 @@ impl BatchEntriesState { batch_requests, batch_entries, index_to_adapter, - next_batch_span, } } fn add(&mut self, id: u64, mut entry: Entry, adapter: Adapter, request: Request) { - // Create a new span to link the batch back to this entry - let entry_batch_span = info_span!(parent: &entry.span, "infer"); - // Add relationships - self.next_batch_span.follows_from(&entry_batch_span); - entry_batch_span.follows_from(&self.next_batch_span); - // Update entry - entry.temp_span = Some(entry_batch_span); - self.batch_requests.push(request); // Set batch_time @@ -174,7 +160,6 @@ impl BatchEntriesState { fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch { // Final batch size let size = self.len() as u32; - self.next_batch_span.record("batch_size", size); // TODO(travis): clone is not ideal, find a way to do this cleanly in place Batch { @@ -213,13 +198,12 @@ pub(crate) trait BatchEntries: Sync + Send + Debug { fn len(&self) -> usize; fn state(&self) -> &BatchEntriesState; fn mut_state(&mut self) -> &mut BatchEntriesState; - fn set_span(&mut self, span: Span); - fn update_entries_span(&mut self, create_span_fn: Box Span>); async fn process_first( &mut self, client: &mut ShardedClient, batch: Batch, + span: Span, generation_health: &Arc, ) -> Option; @@ -227,6 +211,7 @@ pub(crate) trait BatchEntries: Sync + Send + Debug { &mut self, client: &mut ShardedClient, batches: Vec, + span: Span, generation_health: &Arc, ) -> Option; } @@ -313,27 +298,11 @@ impl BatchEntries for GenerateBatchEntries { &mut self.state } - fn set_span(&mut self, span: Span) { - self.state.next_batch_span = span; - } - - fn update_entries_span(&mut self, create_span_fn: Box Span>) { - let next_batch_span = &self.state.next_batch_span; - for (_, entry) in self.state.batch_entries.iter_mut() { - // Create a new span to link the batch back to this entry - let entry_batch_span = create_span_fn(&entry.span); - // Add relationships - next_batch_span.follows_from(&entry_batch_span); - entry_batch_span.follows_from(next_batch_span); - // Update entry - entry.temp_span = Some(entry_batch_span); - } - } - async fn process_first( &mut self, client: &mut ShardedClient, batch: Batch, + span: Span, generation_health: &Arc, ) -> Option { prefill( @@ -342,7 +311,7 @@ impl BatchEntries for GenerateBatchEntries { &mut self.state.batch_entries, &generation_health, ) - .instrument(self.state.next_batch_span) + .instrument(span) .await } @@ -350,6 +319,7 @@ impl BatchEntries for GenerateBatchEntries { &mut self, client: &mut ShardedClient, batches: Vec, + span: Span, generation_health: &Arc, ) -> Option { decode( @@ -358,7 +328,7 @@ impl BatchEntries for GenerateBatchEntries { &mut self.state.batch_entries, &generation_health, ) - .instrument(self.state.next_batch_span) + .instrument(span) .await } } @@ -445,27 +415,11 @@ impl BatchEntries for EmbedBatchEntries { &mut self.state } - fn set_span(&mut self, span: Span) { - self.state.next_batch_span = span; - } - - fn update_entries_span(&mut self, create_span_fn: Box Span>) { - let next_batch_span = &self.state.next_batch_span; - for (_, entry) in self.state.batch_entries.iter_mut() { - // Create a new span to link the batch back to this entry - let entry_batch_span = create_span_fn(&entry.span); - // Add relationships - next_batch_span.follows_from(&entry_batch_span); - entry_batch_span.follows_from(next_batch_span); - // Update entry - entry.temp_span = Some(entry_batch_span); - } - } - async fn process_first( &mut self, client: &mut ShardedClient, batch: Batch, + span: Span, generation_health: &Arc, ) -> Option { embed( @@ -474,7 +428,7 @@ impl BatchEntries for EmbedBatchEntries { &mut self.state.batch_entries, &generation_health, ) - .instrument(self.state.next_batch_span) + .instrument(span) .await } @@ -482,6 +436,7 @@ impl BatchEntries for EmbedBatchEntries { &mut self, client: &mut ShardedClient, batches: Vec, + span: Span, generation_health: &Arc, ) -> Option { // TODO(travis): send error (programming eroor) if we get here diff --git a/router/src/infer.rs b/router/src/infer.rs index 4a24b3f41..49dc0ca81 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -471,7 +471,7 @@ async fn batching_task( // Get the next batch from the queue // This batch might be smaller than the maximum batch size if there are not enough requests // waiting in the queue - while let Some((mut batch_entries, batch)) = adapter_scheduler + while let Some((mut batch_entries, batch, span)) = adapter_scheduler .next_batch( HashSet::new(), None, @@ -481,7 +481,7 @@ async fn batching_task( .await { let mut cached_batch = batch_entries - .process_first(&mut client, batch, &generation_health) + .process_first(&mut client, batch, span, &generation_health) .await; let mut waiting_tokens = 1; @@ -512,7 +512,7 @@ async fn batching_task( let adapters_in_use = batch_entries.adapters_in_use(); // Try to get a new batch - if let Some((new_entries, new_batch)) = adapter_scheduler + if let Some((new_entries, new_batch, span)) = adapter_scheduler .next_batch( adapters_in_use, min_size, @@ -528,11 +528,24 @@ async fn batching_task( metrics::increment_counter!("lorax_batch_concat", "reason" => "wait_exceeded"); } - batch_entries.update_entries_span(Box::new(create_waiting_span)); + batch_entries + .mut_state() + .batch_entries + .iter_mut() + .for_each(|(_, entry)| { + // Create a new span to add the info that this entry is waiting + // because a new batch is being computed + let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); + // Add relationships + span.follows_from(&entry_waiting_span); + entry_waiting_span.follows_from(&span); + // Update entry + entry.temp_span = Some(entry_waiting_span); + }); // Generate one token for this new batch to have the attention past in cache let new_cached_batch = batch_entries - .process_first(&mut client, new_batch, &generation_health) + .process_first(&mut client, new_batch, span, &generation_health) .await; // Reset waiting counter waiting_tokens = 1; @@ -547,11 +560,23 @@ async fn batching_task( let next_batch_size = batch_entries.len(); let next_batch_span = info_span!(parent: None, "batch", batch_size = next_batch_size); - batch_entries.set_span(next_batch_span); - batch_entries.update_entries_span(Box::new(create_infer_span)); + + batch_entries + .mut_state() + .batch_entries + .iter_mut() + .for_each(|(_, entry)| { + // Create a new span to link the batch back to this entry + let entry_batch_span = info_span!(parent: &entry.span, "infer"); + // Add relationships + next_batch_span.follows_from(&entry_batch_span); + entry_batch_span.follows_from(&next_batch_span); + // Update entry + entry.temp_span = Some(entry_batch_span); + }); cached_batch = batch_entries - .process_next(&mut client, batches, &generation_health) + .process_next(&mut client, batches, next_batch_span, &generation_health) .await; waiting_tokens += 1; } @@ -561,14 +586,6 @@ async fn batching_task( } } -fn create_waiting_span(parent: &Span) -> Span { - info_span!(parent: parent, "waiting") -} - -fn create_infer_span(parent: &Span) -> Span { - info_span!(parent:parent, "infer") -} - #[instrument(skip_all)] pub(crate) async fn prefill( client: &mut ShardedClient, diff --git a/router/src/scheduler.rs b/router/src/scheduler.rs index 579f14d96..5f790b880 100644 --- a/router/src/scheduler.rs +++ b/router/src/scheduler.rs @@ -103,7 +103,7 @@ impl AdapterScheduler { } } -type NextBatch = (Box, Batch); +type NextBatch = (Box, Batch, Span); /// Background task that manages the queues of the various adapters /// TODO(geoffrey): add tracing (span object) to the various commands @@ -250,6 +250,10 @@ impl AdapterSchedulerState { } } + // Create span for this batch to add context to inference calls + let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); + next_batch_span.follows_from(&Span::current()); + let mut max_input_length = 0; let mut prefill_tokens: u32 = 0; let mut decode_tokens: u32 = 0; @@ -265,7 +269,7 @@ impl AdapterSchedulerState { // Pop entries starting from the front of the queue let mut batch_entries: Option> = None; - while let Some((id, entry, adapter)) = queues_state.next_entry() { + while let Some((id, mut entry, adapter)) = queues_state.next_entry() { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_disconnected() { @@ -329,6 +333,14 @@ impl AdapterSchedulerState { break; } + // Create a new span to link the batch back to this entry + let entry_batch_span = info_span!(parent: &entry.span, "infer"); + // Add relationships + next_batch_span.follows_from(&entry_batch_span); + entry_batch_span.follows_from(&next_batch_span); + // Update entry + entry.temp_span = Some(entry_batch_span); + batch_entries.as_mut().unwrap().add(id, entry, adapter) } @@ -356,6 +368,7 @@ impl AdapterSchedulerState { } } + next_batch_span.record("batch_size", batch_entries.len() as u32); let max_tokens = prefill_tokens + decode_tokens; let batch = batch_entries.create_batch_data(self.next_batch_id, max_tokens); @@ -364,7 +377,7 @@ impl AdapterSchedulerState { metrics::histogram!("lorax_batch_next_size", batch_entries.len() as f64); - Some((batch_entries, batch)) + Some((batch_entries, batch, next_batch_span)) } } From acc8f77e2400d7824a7064821a31fa5cd60dcf95 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 31 May 2024 22:15:18 -0700 Subject: [PATCH 23/31] Fixed batch --- server/lorax_server/models/flash_bert.py | 2 +- server/lorax_server/models/types.py | 54 ++++++++++++++++++++++-- server/lorax_server/server.py | 12 ++++-- 3 files changed, 61 insertions(+), 7 deletions(-) diff --git a/server/lorax_server/models/flash_bert.py b/server/lorax_server/models/flash_bert.py index 86904117d..b4441a76d 100644 --- a/server/lorax_server/models/flash_bert.py +++ b/server/lorax_server/models/flash_bert.py @@ -237,7 +237,7 @@ def embed(self, batch: FlashEmbeddingBatch) -> Embedding: return Embedding(values=cpu_results[: self.hidden_size]) - def tokenize_to_batch(self, inputs) -> FlashEmbeddingBatch: + def batch_from_pb(self, batch: generate_pb2.Batch,) -> FlashEmbeddingBatch: tokens = self.tokenizer(inputs, return_token_type_ids=True) num_tokens = len(tokens["input_ids"]) position_ids = range(num_tokens) diff --git a/server/lorax_server/models/types.py b/server/lorax_server/models/types.py index c77f784e6..76631270b 100644 --- a/server/lorax_server/models/types.py +++ b/server/lorax_server/models/types.py @@ -136,8 +136,56 @@ class FlashEmbeddingBatch(ABC): max_s: int size: int - def __len__(self): + def __len__(self) -> int: return self.size - def from_pb(self, *args, **kwargs): - return None + @classmethod + def from_pb( + self, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + tokenizers: TokenizerManager, + dtype: torch.dtype, + device: torch.device, + ) -> "FlashEmbeddingBatch": + batch_inputs = [] + max_truncation = 0 + for r in pb.requests: + inputs = tokenizers.get_inputs(r, tokenizer) + batch_inputs.append(inputs) + max_truncation = max(max_truncation, r.truncate) + + batch_tokenized_inputs = tokenizer( + batch_inputs, + return_token_type_ids=True, + truncation=True, + max_length=max_truncation, + ) + + max_s = 0 + position_ids = [] + + cumulative_length = 0 + cu_seqlens = [0] + + for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)): + tokenized_input = tokenized_input[-r.truncate :] + + input_length = len(tokenized_input) + max_s = max(max_s, input_length) + cu_seqlens.append(cumulative_length + input_length) + + # Position ids + request_position_ids = torch.arange(0, input_length, dtype=torch.int32) + position_ids.append(request_position_ids) + + cumulative_length += input_length + + return FlashEmbeddingBatch( + input_ids=torch.tensor(batch_tokenized_inputs["input_ids"], dtype=torch.int32, device=device), + token_type_ids=torch.tensor(batch_tokenized_inputs["token_type_ids"], dtype=torch.int32, device=device), + position_ids=torch.tensor(position_ids, dtype=torch.int32, device=device), + cu_seqlens=torch.tensor(cu_seqlens, dtype=torch.int32, device=device), + max_s=max_s, + size=len(batch_inputs), + ) diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 5ea9bd6cf..45962f9bf 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -101,9 +101,15 @@ async def Embed(self, request: generate_pb2.EmbedRequest, context): return generate_pb2.EmbedResponse( embeddings=generate_pb2.Embedding(), errorMsg="Model does not support embeddings" ) - batch = request.inputs - tokenised_batch = self.model.tokenize_to_batch(batch) - embeddings = self.model.embed(tokenised_batch) + + batch = self.model.batch_type.from_pb( + request.batch, + self.model.tokenizer, + self.model.tokenizers, + self.model.dtype, + self.model.device, + ) + embeddings = self.model.embed(batch) return generate_pb2.EmbedResponse(embeddings=embeddings) async def Decode(self, request: generate_pb2.DecodeRequest, context): From 5f0cac78a595402cdaec19e2d9daf5083fcf004a Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 31 May 2024 22:18:05 -0700 Subject: [PATCH 24/31] Cleanup --- server/lorax_server/models/flash_bert.py | 13 ------------- server/lorax_server/server.py | 5 +---- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/server/lorax_server/models/flash_bert.py b/server/lorax_server/models/flash_bert.py index b4441a76d..07aa96546 100644 --- a/server/lorax_server/models/flash_bert.py +++ b/server/lorax_server/models/flash_bert.py @@ -236,16 +236,3 @@ def embed(self, batch: FlashEmbeddingBatch) -> Embedding: cpu_results = embedding.view(-1).tolist() return Embedding(values=cpu_results[: self.hidden_size]) - - def batch_from_pb(self, batch: generate_pb2.Batch,) -> FlashEmbeddingBatch: - tokens = self.tokenizer(inputs, return_token_type_ids=True) - num_tokens = len(tokens["input_ids"]) - position_ids = range(num_tokens) - return FlashEmbeddingBatch( - input_ids=torch.tensor(tokens["input_ids"], dtype=torch.int32, device=self.device), - token_type_ids=torch.tensor(tokens["token_type_ids"], dtype=torch.int32, device=self.device), - position_ids=torch.tensor(position_ids, dtype=torch.int32, device=self.device), - cu_seqlens=torch.tensor([0, num_tokens], dtype=torch.int32, device=self.device), - max_s=num_tokens, - size=1, - ) diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 45962f9bf..aafab68a5 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -97,10 +97,7 @@ async def Prefill(self, request: generate_pb2.PrefillRequest, context): async def Embed(self, request: generate_pb2.EmbedRequest, context): if not self.model.supports_embeddings: - logger.error("Model does not support embeddings") - return generate_pb2.EmbedResponse( - embeddings=generate_pb2.Embedding(), errorMsg="Model does not support embeddings" - ) + raise ValueError("Model does not support embeddings") batch = self.model.batch_type.from_pb( request.batch, From 7898dba2892853be2859838c824411639a632365 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Sat, 1 Jun 2024 13:59:53 -0700 Subject: [PATCH 25/31] Fix shape --- server/lorax_server/models/flash_bert.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/server/lorax_server/models/flash_bert.py b/server/lorax_server/models/flash_bert.py index 07aa96546..81699c38f 100644 --- a/server/lorax_server/models/flash_bert.py +++ b/server/lorax_server/models/flash_bert.py @@ -226,13 +226,14 @@ def forward(self, batch: FlashEmbeddingBatch): @tracer.start_as_current_span("embed") def embed(self, batch: FlashEmbeddingBatch) -> Embedding: - embedding = self.model.forward( + embedding: torch.Tensor = self.model.forward( input_ids=batch.input_ids, token_type_ids=batch.token_type_ids, position_ids=batch.position_ids, cu_seqlens=batch.cu_seqlens, max_s=batch.max_s, ) - cpu_results = embedding.view(-1).tolist() + embedding = embedding.reshape(embedding.shape[0], -1)[:, : self.hidden_size] - return Embedding(values=cpu_results[: self.hidden_size]) + cpu_results = embedding.cpu().tolist() + return Embedding(values=cpu_results) From d3a84e934e488c2b8117a830376e732ccc696fb6 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Sat, 1 Jun 2024 14:15:29 -0700 Subject: [PATCH 26/31] Fixed warmup --- server/lorax_server/models/types.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/server/lorax_server/models/types.py b/server/lorax_server/models/types.py index 76631270b..84ca35c6c 100644 --- a/server/lorax_server/models/types.py +++ b/server/lorax_server/models/types.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from typing import List, Optional +import numpy as np import torch from transformers import PreTrainedTokenizerBase @@ -155,21 +156,24 @@ def from_pb( batch_inputs.append(inputs) max_truncation = max(max_truncation, r.truncate) - batch_tokenized_inputs = tokenizer( + batch_inputs = tokenizer( batch_inputs, return_token_type_ids=True, truncation=True, max_length=max_truncation, ) + batch_tokenized_inputs = batch_inputs["input_ids"] - max_s = 0 + all_input_ids = [] position_ids = [] + cu_seqlens = [0] + max_s = 0 cumulative_length = 0 - cu_seqlens = [0] for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)): tokenized_input = tokenized_input[-r.truncate :] + all_input_ids.append(tokenized_input) input_length = len(tokenized_input) max_s = max(max_s, input_length) @@ -180,11 +184,21 @@ def from_pb( position_ids.append(request_position_ids) cumulative_length += input_length + + if len(pb.requests) > 1: + input_ids = np.concatenate(all_input_ids, dtype=np.int64) + position_ids = torch.cat(position_ids) + else: + input_ids = all_input_ids[0] + position_ids = position_ids[0] + + input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + position_ids = position_ids.to(device) return FlashEmbeddingBatch( - input_ids=torch.tensor(batch_tokenized_inputs["input_ids"], dtype=torch.int32, device=device), - token_type_ids=torch.tensor(batch_tokenized_inputs["token_type_ids"], dtype=torch.int32, device=device), - position_ids=torch.tensor(position_ids, dtype=torch.int32, device=device), + input_ids=input_ids, + token_type_ids=torch.tensor(batch_inputs["token_type_ids"], dtype=torch.int32, device=device), + position_ids=position_ids, cu_seqlens=torch.tensor(cu_seqlens, dtype=torch.int32, device=device), max_s=max_s, size=len(batch_inputs), From ccc6c06e654362712c37311502cb2ae838761f4b Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Sat, 1 Jun 2024 14:20:39 -0700 Subject: [PATCH 27/31] Fixed adapter --- router/src/infer.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index 49dc0ca81..7184b1f77 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -320,10 +320,10 @@ impl Infer { let adapter = Adapter::new( AdapterParameters { - adapter_ids: vec!["".to_string()], + adapter_ids: vec![BASE_MODEL_ADAPTER_ID.to_string()], ..Default::default() }, - "hf".to_string(), + "hub".to_string(), 0, None, ); From bfac3b885ffb6f476b5aa4b19c04b6950740cd5b Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 3 Jun 2024 10:49:31 -0700 Subject: [PATCH 28/31] DEBUG --- README.md | 11 +++++++++++ server/lorax_server/server.py | 1 + 2 files changed, 12 insertions(+) diff --git a/README.md b/README.md index 25ce24621..637ba3a3d 100644 --- a/README.md +++ b/README.md @@ -100,6 +100,17 @@ curl 127.0.0.1:8080/generate \ -H 'Content-Type: application/json' ``` +Embed: + +```shell +curl 127.0.0.1:8080/embed \ +-X POST \ +-d '{ +"inputs": "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]" +}' \ +-H 'Content-Type: application/json' +``` + Prompt a LoRA adapter: ```shell diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index aafab68a5..f98d7a9f8 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -96,6 +96,7 @@ async def Prefill(self, request: generate_pb2.PrefillRequest, context): ) async def Embed(self, request: generate_pb2.EmbedRequest, context): + print("!!! EMBED") if not self.model.supports_embeddings: raise ValueError("Model does not support embeddings") From 652d4971a03a6332e8f86864e2c16b57b04237a4 Mon Sep 17 00:00:00 2001 From: Magdy Saleh Date: Thu, 6 Jun 2024 17:32:58 +0000 Subject: [PATCH 29/31] add warmup method --- server/lorax_server/models/flash_bert.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/server/lorax_server/models/flash_bert.py b/server/lorax_server/models/flash_bert.py index 81699c38f..d148b3471 100644 --- a/server/lorax_server/models/flash_bert.py +++ b/server/lorax_server/models/flash_bert.py @@ -172,7 +172,8 @@ def __init__( else: raise NotImplementedError("FlashSantacoderSharded is only available on GPU") - self.device = device + # self.device = device + self.device = "cpu" self.dtype = dtype tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -214,7 +215,10 @@ def supports_text_generation(self) -> bool: return False def warmup(self, batch: FlashEmbeddingBatch, max_new_tokens: int) -> int | None: - return 42 # no-op for now + # Note: This is meant to 1) preallocate the memory by doing a forward pass + # and then just returning the max seqlen since for embeddings we are never generating + _ = self.embed(batch) + return batch.max_s def generate_token(self, batch: FlashEmbeddingBatch) -> None: if not self.supports_text_generation: From 9c688c28663b113114c82fdcad1c87043019dd1d Mon Sep 17 00:00:00 2001 From: Magdy Saleh Date: Thu, 6 Jun 2024 21:14:59 +0000 Subject: [PATCH 30/31] fix --- server/lorax_server/models/flash_bert.py | 2 +- server/lorax_server/models/types.py | 17 +++++++++++++---- server/lorax_server/server.py | 6 ++++-- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/server/lorax_server/models/flash_bert.py b/server/lorax_server/models/flash_bert.py index d148b3471..beaa7f904 100644 --- a/server/lorax_server/models/flash_bert.py +++ b/server/lorax_server/models/flash_bert.py @@ -240,4 +240,4 @@ def embed(self, batch: FlashEmbeddingBatch) -> Embedding: embedding = embedding.reshape(embedding.shape[0], -1)[:, : self.hidden_size] cpu_results = embedding.cpu().tolist() - return Embedding(values=cpu_results) + return cpu_results diff --git a/server/lorax_server/models/types.py b/server/lorax_server/models/types.py index 84ca35c6c..18accb3ff 100644 --- a/server/lorax_server/models/types.py +++ b/server/lorax_server/models/types.py @@ -57,7 +57,6 @@ def to_pb(self) -> generate_pb2.GeneratedText: seed=self.seed, ) - @dataclass class PrefillTokens: token_ids: List[int] @@ -129,6 +128,7 @@ def to_pb(self) -> generate_pb2.Generation: @dataclass class FlashEmbeddingBatch(ABC): + request_ids: List[int] input_ids: torch.Tensor token_type_ids: torch.Tensor position_ids: torch.Tensor @@ -162,18 +162,23 @@ def from_pb( truncation=True, max_length=max_truncation, ) + batch_tokenized_inputs = batch_inputs["input_ids"] + batch_token_type_ids = batch_inputs["token_type_ids"] all_input_ids = [] position_ids = [] + all_token_type_ids = [] cu_seqlens = [0] max_s = 0 cumulative_length = 0 - - for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)): + + for i, (r, tokenized_input, token_type_ids) in enumerate(zip(pb.requests, batch_tokenized_inputs, batch_token_type_ids)): tokenized_input = tokenized_input[-r.truncate :] + token_type_ids = token_type_ids[-r.truncate :] all_input_ids.append(tokenized_input) + all_token_type_ids.append(token_type_ids) input_length = len(tokenized_input) max_s = max(max_s, input_length) @@ -187,17 +192,21 @@ def from_pb( if len(pb.requests) > 1: input_ids = np.concatenate(all_input_ids, dtype=np.int64) + final_token_type_ids = np.concatenate(all_token_type_ids, dtype=np.int64) position_ids = torch.cat(position_ids) else: input_ids = all_input_ids[0] + final_token_type_ids = all_token_type_ids[0] position_ids = position_ids[0] input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + final_token_type_ids = torch.tensor(final_token_type_ids, dtype=torch.int64, device=device) position_ids = position_ids.to(device) return FlashEmbeddingBatch( + request_ids=[r.id for r in pb.requests], input_ids=input_ids, - token_type_ids=torch.tensor(batch_inputs["token_type_ids"], dtype=torch.int32, device=device), + token_type_ids=final_token_type_ids, position_ids=position_ids, cu_seqlens=torch.tensor(cu_seqlens, dtype=torch.int32, device=device), max_s=max_s, diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index f98d7a9f8..63cc39c77 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -96,7 +96,6 @@ async def Prefill(self, request: generate_pb2.PrefillRequest, context): ) async def Embed(self, request: generate_pb2.EmbedRequest, context): - print("!!! EMBED") if not self.model.supports_embeddings: raise ValueError("Model does not support embeddings") @@ -108,7 +107,10 @@ async def Embed(self, request: generate_pb2.EmbedRequest, context): self.model.device, ) embeddings = self.model.embed(batch) - return generate_pb2.EmbedResponse(embeddings=embeddings) + embeddings_proto = [] + for i, embedding in enumerate(embeddings): + embeddings_proto.append(generate_pb2.Embedding(request_id=batch.request_ids[i], values=embedding)) + return generate_pb2.EmbedResponse(embeddings=embeddings_proto) async def Decode(self, request: generate_pb2.DecodeRequest, context): if len(request.batches) == 0: From 8c20c5e6d90650d5627f23fb2bab7bc131b3ca9e Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 7 Jun 2024 22:23:09 -0700 Subject: [PATCH 31/31] Fix readme --- README.md | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/README.md b/README.md index 637ba3a3d..25ce24621 100644 --- a/README.md +++ b/README.md @@ -100,17 +100,6 @@ curl 127.0.0.1:8080/generate \ -H 'Content-Type: application/json' ``` -Embed: - -```shell -curl 127.0.0.1:8080/embed \ --X POST \ --d '{ -"inputs": "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]" -}' \ --H 'Content-Type: application/json' -``` - Prompt a LoRA adapter: ```shell