diff --git a/router/src/batch.rs b/router/src/batch.rs index 50b91550b..2be4bcdbe 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -259,6 +259,7 @@ pub(crate) trait BatchEntries: Sync + Send + Debug { blocks: Vec, slots: Vec, prefix_len: u32, + chunk_len: Option, ); fn extend(&mut self, entries: Box); fn drain(&mut self) -> Vec<(Adapter, u64, Entry)>; @@ -323,6 +324,7 @@ impl BatchEntries for GenerateBatchEntries { blocks: Vec, slots: Vec, prefix_len: u32, + chunk_len: Option, ) { let valid_request = entry .request @@ -343,7 +345,7 @@ impl BatchEntries for GenerateBatchEntries { blocks, slots, cache_len: prefix_len, - chunk_len: None, + chunk_len: chunk_len, }; self.state.add(id, entry, adapter, request_proto); @@ -455,6 +457,7 @@ impl BatchEntries for EmbedBatchEntries { blocks: Vec, slots: Vec, prefix_len: u32, + chunk_len: Option, ) { let valid_request = entry .request @@ -475,7 +478,7 @@ impl BatchEntries for EmbedBatchEntries { blocks, slots, cache_len: prefix_len, - chunk_len: None, + chunk_len: chunk_len, }; self.state.add(id, entry, adapter, request_proto); @@ -580,6 +583,7 @@ impl BatchEntries for ClassifyBatchEntries { blocks: Vec, slots: Vec, prefix_len: u32, + chunk_len: Option, ) { let valid_request = entry .request @@ -600,7 +604,7 @@ impl BatchEntries for ClassifyBatchEntries { blocks, slots, cache_len: prefix_len, - chunk_len: None, + chunk_len: chunk_len, }; self.state.add(id, entry, adapter, request_proto); diff --git a/router/src/infer.rs b/router/src/infer.rs index ba57d206d..74788bda0 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -200,6 +200,7 @@ impl Infer { speculate, max_batch_total_tokens, prefix_caching, + chunked_prefill, is_causal_lm, ); diff --git a/router/src/scheduler.rs b/router/src/scheduler.rs index ff5f4be73..f7154e891 100644 --- a/router/src/scheduler.rs +++ b/router/src/scheduler.rs @@ -44,6 +44,7 @@ impl AdapterScheduler { speculate: u32, max_batch_total_tokens: u32, prefix_caching: bool, + chunked_prefill: bool, is_causal_lm: bool, ) -> Self { let (sender, receiver) = flume::unbounded(); @@ -61,6 +62,7 @@ impl AdapterScheduler { speculate, max_batch_total_tokens, prefix_caching, + chunked_prefill, is_causal_lm, )); @@ -126,6 +128,7 @@ async fn adapter_scheduler_task( speculate: u32, max_batch_total_tokens: u32, prefix_caching: bool, + chunked_prefill: bool, is_causal_lm: bool, ) { let mut state = AdapterSchedulerState::new( @@ -138,6 +141,7 @@ async fn adapter_scheduler_task( speculate, max_batch_total_tokens, prefix_caching, + chunked_prefill, is_causal_lm, ); @@ -198,6 +202,9 @@ struct AdapterSchedulerState { /// Speculation amount speculate: u32, + /// Chunked prefill + chunked_prefill: bool, + /// Paged Attention Block Allocation block_allocator: Option, } @@ -213,6 +220,7 @@ impl AdapterSchedulerState { speculate: u32, max_batch_total_tokens: u32, prefix_caching: bool, + chunked_prefill: bool, is_causal_lm: bool, ) -> Self { let queues_state = Arc::new(Mutex::new(AdapterQueuesState::new( @@ -239,6 +247,7 @@ impl AdapterSchedulerState { block_size, window_size, speculate, + chunked_prefill, block_allocator, } } @@ -282,6 +291,10 @@ impl AdapterSchedulerState { prefill_token_budget: u32, token_budget: u32, ) -> Option { + if prefill_token_budget == 0 || token_budget == 0 { + return None; + }; + let num_entries = self.queues_state.lock().await.len(); if num_entries == 0 { return None; @@ -331,6 +344,8 @@ impl AdapterSchedulerState { batch_requests_len = batch_entries.len(); } + let mut should_break = false; + let mut chunk_len = None; let block_allocation = match &self.block_allocator { None => { // We pad to max input length in the Python shards @@ -354,31 +369,6 @@ impl AdapterSchedulerState { None } Some(block_allocator) => { - prefill_tokens += entry.request.input_length(); - let max_new_tokens = match self.window_size { - None => entry.request.max_new_tokens(), - Some(window_size) => min( - window_size.saturating_sub(entry.request.input_length()), - entry.request.max_new_tokens(), - ), - }; - decode_tokens += max_new_tokens; - - // If we're prefix caching, this check could be under-estimating the number of available blocks - // due to shared prefixes, so we'll let the block allocator determine whether we have enough space. - if prefill_tokens > prefill_token_budget - || (prefill_tokens + decode_tokens + self.speculate) > token_budget - { - // Entry is over budget - // Add it back to the front - tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); - self.queues_state - .lock() - .await - .push_front(&adapter, id, entry); - break; - } - let tokens = entry.request.input_length() + entry.request.max_new_tokens() + self.speculate @@ -392,7 +382,7 @@ impl AdapterSchedulerState { self.speculate ); - match block_allocator + let block_allocation = match block_allocator .allocate(adapter.index(), tokens, entry.request.input_ids()) .await { @@ -406,12 +396,63 @@ impl AdapterSchedulerState { .push_front(&adapter, id, entry); break 'entry_loop; } - Some(block_allocation) => { + Some(mut block_allocation) => { tracing::debug!("Allocation: {block_allocation:?}"); max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); - Some(block_allocation) + + if block_allocation.prefix_len == entry.request.input_length() { + // The whole request was found in the radix trie + // However, for the transformer forward to work, we need to + // have at least one token of postfix. + block_allocation.prefix_len -= 1; + } + + block_allocation + } + }; + + let postfix_len = entry.request.input_length() - block_allocation.prefix_len; + + if prefill_tokens + postfix_len > prefill_token_budget { + // Entry is over budget + if self.chunked_prefill { + // We support chunking, just set postfix_len to exactly match prefill_token_budget + let entry_chunk_len = + prefill_token_budget.saturating_sub(prefill_tokens); + if entry_chunk_len > 0 { + chunk_len = Some(entry_chunk_len); + } else { + // We cannot prefill even one token for this entry + // Add it back to the queue + self.queues_state + .lock() + .await + .push_front(&adapter, id, entry); + break 'entry_loop; + } + tracing::debug!( + "Matched budget: prefill_tokens={} == {prefill_token_budget}", + prefill_tokens + postfix_len + ); + should_break = true; + } else { + // We don't support chunking, this entry needs to go back to the buffer + // Add it back to the front + tracing::debug!( + "Over budget: prefill_tokens={} > {prefill_token_budget}", + prefill_tokens + postfix_len + ); + self.queues_state + .lock() + .await + .push_front(&adapter, id, entry); + break 'entry_loop; } } + + prefill_tokens += postfix_len; + + Some(block_allocation) } }; @@ -454,7 +495,11 @@ impl AdapterSchedulerState { batch_entries .as_mut() .unwrap() - .add(id, entry, adapter, blocks, slots, prefix_len); + .add(id, entry, adapter, blocks, slots, prefix_len, chunk_len); + + if should_break { + break 'entry_loop; + } } if batch_entries.is_none() {