diff --git a/Cargo.lock b/Cargo.lock index 281d0be98..5f986f527 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1298,6 +1298,7 @@ dependencies = [ "reqwest-retry", "serde", "serde_json", + "slotmap", "thiserror", "tokenizers", "tokio", @@ -2705,9 +2706,9 @@ dependencies = [ [[package]] name = "slotmap" -version = "1.0.6" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1e08e261d0e8f5c43123b7adf3e4ca1690d655377ac93a03b2c9d3e98de1342" +checksum = "dbff4acf519f630b3a3ddcfaea6c06b42174d9a44bc70c620e9ed1649d58b82a" dependencies = [ "version_check", ] diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 533fded4e..56263d0a5 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -298,6 +298,11 @@ struct Args { #[clap(long, env)] eager_prefill: Option, + /// Whether to use the prefix caching mechanism. + /// TODO(travis): better comment here + #[clap(long, env)] + prefix_caching: Option, + /// Maximum number of adapters that can be placed on the GPU and accept requests at a time. #[clap(default_value = "1024", long, env)] max_active_adapters: usize, @@ -440,6 +445,7 @@ fn shard_manager( watermark_delta: Option, cuda_memory_fraction: f32, adapter_memory_fraction: f32, + prefix_caching: Option, otlp_endpoint: Option, status_sender: mpsc::Sender, shutdown: Arc, @@ -548,6 +554,11 @@ fn shard_manager( adapter_memory_fraction.to_string().into(), )); + // Prefix caching + if let Some(prefix_caching) = prefix_caching { + envs.push(("PREFIX_CACHING".into(), prefix_caching.to_string().into())); + } + // Safetensors load fast envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into())); @@ -984,6 +995,7 @@ fn spawn_shards( let watermark_delta = args.watermark_delta; let cuda_memory_fraction = args.cuda_memory_fraction; let adapter_memory_fraction = args.adapter_memory_fraction; + let prefix_caching = args.prefix_caching; thread::spawn(move || { shard_manager( model_id, @@ -1009,6 +1021,7 @@ fn spawn_shards( watermark_delta, cuda_memory_fraction, adapter_memory_fraction, + prefix_caching, otlp_endpoint, status_sender, shutdown, @@ -1164,6 +1177,10 @@ fn spawn_webserver( router_args.push("--eager-prefill".to_string()); } + if args.prefix_caching.unwrap_or(false) { + router_args.push("--prefix-caching".to_string()); + } + // Ngrok if args.ngrok { router_args.push("--ngrok".to_string()); diff --git a/proto/generate.proto b/proto/generate.proto index 06e0f9787..96f64a586 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -135,6 +135,8 @@ message Request { repeated uint32 blocks = 9; /// Paged attention slots repeated uint32 slots = 10; + /// Prefix length that can be retrieved from the KV cache + uint32 prefix_len = 11; } message Batch { diff --git a/router/Cargo.toml b/router/Cargo.toml index a8d94d2fb..865dbf09c 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -35,6 +35,7 @@ reqwest-retry = "0.4.0" regex = "1.5.4" serde = "1.0.152" serde_json = { version = "1.0.93", features = ["preserve_order"] } +slotmap = "1.0.7" thiserror = "1.0.38" tokenizers = { version = "0.19.1", features = ["http"] } tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 94f0d73da..160a73905 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -122,6 +122,7 @@ impl Client { // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], slots: vec![], + prefix_len: 0, // Set 
sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, diff --git a/router/src/batch.rs b/router/src/batch.rs index 5156736cd..31047c2a9 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -24,6 +24,7 @@ use crate::{ pub(crate) trait ValidRequest: Sync + Send + Debug + Any { fn input_length(&self) -> u32; + fn input_ids(&self) -> Option>>; fn max_new_tokens(&self) -> u32; fn adapter(&self) -> Adapter; fn to_batch(&self, num_entries: usize, queue_len: usize) -> Box; @@ -35,6 +36,14 @@ impl ValidRequest for ValidGenerateRequest { self.input_length } + fn input_ids(&self) -> Option>> { + if let Some(tokenized_inputs) = &self.tokenized_inputs { + Some(Arc::new(tokenized_inputs.ids.clone())) + } else { + None + } + } + fn max_new_tokens(&self) -> u32 { self.stopping_parameters.max_new_tokens } @@ -65,6 +74,14 @@ impl ValidRequest for ValidEmbedRequest { self.input_length } + fn input_ids(&self) -> Option>> { + if let Some(tokenized_inputs) = &self.tokenized_inputs { + Some(Arc::new(tokenized_inputs.ids.clone())) + } else { + None + } + } + fn max_new_tokens(&self) -> u32 { 1 } @@ -95,6 +112,14 @@ impl ValidRequest for ValidClassifyRequest { self.input_length } + fn input_ids(&self) -> Option>> { + if let Some(tokenized_inputs) = &self.tokenized_inputs { + Some(Arc::new(tokenized_inputs.ids.clone())) + } else { + None + } + } + fn max_new_tokens(&self) -> u32 { 1 } @@ -227,7 +252,15 @@ impl BatchEntriesState { #[async_trait] pub(crate) trait BatchEntries: Sync + Send + Debug { fn can_add(&self, entry: &Entry) -> bool; - fn add(&mut self, id: u64, entry: Entry, adapter: Adapter, blocks: Vec, slots: Vec); + fn add( + &mut self, + id: u64, + entry: Entry, + adapter: Adapter, + blocks: Vec, + slots: Vec, + prefix_len: u32, + ); fn extend(&mut self, entries: Box); fn drain(&mut self) -> Vec<(Adapter, u64, Entry)>; fn create_batch_data(&self, batch_id: u64, max_tokens: u32, max_blocks: u32) -> Batch; @@ -281,7 +314,15 @@ impl BatchEntries for GenerateBatchEntries { result } - fn add(&mut self, id: u64, entry: Entry, adapter: Adapter, blocks: Vec, slots: Vec) { + fn add( + &mut self, + id: u64, + entry: Entry, + adapter: Adapter, + blocks: Vec, + slots: Vec, + prefix_len: u32, + ) { let valid_request = entry .request .as_ref() @@ -300,6 +341,7 @@ impl BatchEntries for GenerateBatchEntries { adapter_index: adapter.index(), blocks, slots, + prefix_len, }; self.state.add(id, entry, adapter, request_proto); @@ -401,7 +443,15 @@ impl BatchEntries for EmbedBatchEntries { result } - fn add(&mut self, id: u64, entry: Entry, adapter: Adapter, blocks: Vec, slots: Vec) { + fn add( + &mut self, + id: u64, + entry: Entry, + adapter: Adapter, + blocks: Vec, + slots: Vec, + prefix_len: u32, + ) { let valid_request = entry .request .as_ref() @@ -420,6 +470,7 @@ impl BatchEntries for EmbedBatchEntries { adapter_index: adapter.index(), blocks, slots, + prefix_len, }; self.state.add(id, entry, adapter, request_proto); @@ -515,7 +566,15 @@ impl BatchEntries for ClassifyBatchEntries { result } - fn add(&mut self, id: u64, entry: Entry, adapter: Adapter, blocks: Vec, slots: Vec) { + fn add( + &mut self, + id: u64, + entry: Entry, + adapter: Adapter, + blocks: Vec, + slots: Vec, + prefix_len: u32, + ) { let valid_request = entry .request .as_ref() @@ -534,6 +593,7 @@ impl BatchEntries for ClassifyBatchEntries { adapter_index: adapter.index(), blocks, slots, + prefix_len, }; self.state.add(id, entry, adapter, request_proto); diff 
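The new `prefix_len` field on `Request`, together with the `input_ids()` accessor on valid requests, is what lets the shard skip recomputation: it states how many leading prompt tokens already have KV entries in the cache, so prefill only has to run over the remaining suffix. A minimal sketch of that split, using illustrative names rather than the server's actual types:

from typing import List, Tuple

def split_prompt(prompt_ids: List[int], prefix_len: int) -> Tuple[List[int], List[int]]:
    # `prompt_ids` stands in for the request's tokenized input ids and
    # `prefix_len` for the new Request.prefix_len field; this helper itself
    # does not exist in the codebase.
    assert 0 <= prefix_len <= len(prompt_ids)
    return prompt_ids[:prefix_len], prompt_ids[prefix_len:]

cached, to_prefill = split_prompt([101, 7592, 2088, 2003, 1037, 3231], prefix_len=4)
# KV for `cached` is read from existing blocks; only `to_prefill` (2 tokens)
# goes through the prefill attention kernels.
print(len(cached), len(to_prefill))  # 4 2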
--git a/router/src/block_allocator.rs b/router/src/block_allocator.rs index 9845c24a1..05c2bd30d 100644 --- a/router/src/block_allocator.rs +++ b/router/src/block_allocator.rs @@ -1,16 +1,26 @@ -use std::cmp::min; +use std::{cmp::min, sync::Arc}; use tokio::sync::{mpsc, oneshot}; +use crate::radix::RadixAllocator; + #[derive(Debug, Clone)] pub(crate) struct BlockAllocation { + pub allocation_id: u64, pub blocks: Vec, pub slots: Vec, - block_allocator: BlockAllocator, + + /// Prefix that was cached and for which the KV does not have to + /// be recomputed. + pub prefix_len: u32, + + pub(crate) block_allocator: Option, } impl Drop for BlockAllocation { fn drop(&mut self) { - self.block_allocator.free(self.blocks.clone()) + if let Some(block_allocator) = self.block_allocator.as_mut() { + block_allocator.free(self.blocks.clone(), self.allocation_id) + } } } @@ -24,6 +34,7 @@ impl BlockAllocator { pub(crate) fn new( max_batch_total_tokens: u32, block_size: u32, + prefix_caching: bool, window_size: Option, ) -> Self { // Create channel @@ -33,6 +44,7 @@ impl BlockAllocator { tokio::spawn(block_allocator_task( max_batch_total_tokens / block_size, block_size, + prefix_caching, window_size, receiver, )); @@ -42,28 +54,32 @@ impl BlockAllocator { } } - pub(crate) async fn allocate(&self, tokens: u32) -> Option { + pub(crate) async fn allocate( + &self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option { let (response_sender, response_receiver) = oneshot::channel(); self.block_allocator .send(BlockAllocatorCommand::Allocate { tokens, + prefill_tokens, response_sender, }) .unwrap(); - response_receiver - .await - .unwrap() - .map(|(blocks, slots)| BlockAllocation { - blocks, - slots, - block_allocator: self.clone(), - }) + response_receiver.await.unwrap().map(|mut allocation| { + allocation.block_allocator = Some(self.clone()); + allocation + }) } - pub(crate) fn free(&self, blocks: Vec) { + pub(crate) fn free(&self, blocks: Vec, allocation_id: u64) { self.block_allocator - .send(BlockAllocatorCommand::Free { blocks }) + .send(BlockAllocatorCommand::Free { + allocation_id, + blocks, + }) .unwrap(); } } @@ -71,60 +87,29 @@ impl BlockAllocator { async fn block_allocator_task( blocks: u32, block_size: u32, + prefix_caching: bool, window_size: Option, mut receiver: mpsc::UnboundedReceiver, ) { - // Block 0 is reserved for health checks - let mut free_blocks: Vec = (1..blocks).collect(); + let mut allocator: Box = if prefix_caching { + Box::new(RadixAllocator::new(block_size, blocks, window_size)) + } else { + Box::new(SimpleAllocator::new(blocks, block_size, window_size)) + }; while let Some(cmd) = receiver.recv().await { match cmd { - BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks), + BlockAllocatorCommand::Free { + blocks, + allocation_id, + } => allocator.free(blocks, allocation_id), BlockAllocatorCommand::Allocate { tokens, + prefill_tokens, response_sender, } => { - // Apply window size - let (required_blocks, repeats) = { - let (tokens, repeats) = match window_size { - None => (tokens, 1), - Some(window_size) => { - let repeats = (tokens + window_size - 1) / window_size; - let tokens = min(tokens, window_size); - (tokens, repeats as usize) - } - }; - // Pad to a multiple of block size - let required_blocks = (tokens + block_size - 1) / block_size; - (required_blocks, repeats) - }; - - let tokens = tokens as usize; - tracing::trace!( - "Allocating {} tokens ({} blocks, {} repeats)", - tokens, - required_blocks, - repeats - ); - let allocation = if required_blocks > 
free_blocks.len() as u32 { - None - } else { - let blocks = - free_blocks.split_off(free_blocks.len() - required_blocks as usize); - let mut slots = Vec::with_capacity( - (required_blocks * block_size * repeats as u32) as usize, - ); - - 'slots: for block_id in blocks.repeat(repeats).iter() { - for s in (block_id * block_size)..((block_id + 1) * block_size) { - slots.push(s); - if slots.len() == tokens { - break 'slots; - } - } - } - Some((blocks, slots)) - }; - response_sender.send(allocation).unwrap(); + response_sender + .send(allocator.allocate(tokens, prefill_tokens)) + .unwrap(); } } } @@ -134,9 +119,92 @@ async fn block_allocator_task( enum BlockAllocatorCommand { Free { blocks: Vec, + allocation_id: u64, }, Allocate { tokens: u32, - response_sender: oneshot::Sender, Vec)>>, + prefill_tokens: Option>>, + response_sender: oneshot::Sender>, }, } + +pub(crate) trait Allocator { + fn allocate( + &mut self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option; + + fn free(&mut self, blocks: Vec, allocation_id: u64); +} + +pub struct SimpleAllocator { + free_blocks: Vec, + block_size: u32, + window_size: Option, +} + +impl SimpleAllocator { + fn new(blocks: u32, block_size: u32, window_size: Option) -> Self { + SimpleAllocator { + block_size, + // Block 0 is reserved for health checks + free_blocks: (1..blocks).collect(), + window_size, + } + } +} + +impl Allocator for SimpleAllocator { + fn allocate( + &mut self, + tokens: u32, + _prefill_tokens: Option>>, + ) -> Option { + // Apply window size + let (required_blocks, repeats) = { + let (tokens, repeats) = match self.window_size { + None => (tokens, 1), + Some(window_size) => { + let repeats = (tokens + window_size - 1) / window_size; + let tokens = min(tokens, window_size); + (tokens, repeats as usize) + } + }; + // Pad to a multiple of block size + let required_blocks = (tokens + self.block_size - 1) / self.block_size; + (required_blocks, repeats) + }; + + let tokens = tokens as usize; + if required_blocks > self.free_blocks.len() as u32 { + None + } else { + let blocks = self + .free_blocks + .split_off(self.free_blocks.len() - required_blocks as usize); + let mut slots = + Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize); + + 'slots: for block_id in blocks.repeat(repeats).iter() { + for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { + slots.push(s); + if slots.len() == tokens { + break 'slots; + } + } + } + Some(BlockAllocation { + allocation_id: 0, + blocks, + slots, + prefix_len: 0, + block_allocator: None, + }) + } + } + + fn free(&mut self, blocks: Vec, _allocation_id: u64) { + self.free_blocks.extend(blocks) + } +} diff --git a/router/src/health.rs b/router/src/health.rs index d0216040f..0dd000fa5 100644 --- a/router/src/health.rs +++ b/router/src/health.rs @@ -58,6 +58,7 @@ impl Health { // Block 0 is reserved for health checks blocks: vec![0], slots: (0..16).collect(), + prefix_len: 0, }; let batch = Batch { id: BATCH_ID, diff --git a/router/src/infer.rs b/router/src/infer.rs index 4b8f08dd4..4d86d18fa 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -159,6 +159,7 @@ impl Infer { block_size: u32, speculate: u32, preloaded_adapters: Vec, + prefix_caching: bool, ) -> Self { let adapter_event = Arc::new(AdapterEvent { batching_task: Notify::new(), @@ -175,6 +176,7 @@ impl Infer { adapter_cycle_time_s, speculate, max_batch_total_tokens, + prefix_caching, ); // Initialize with base model adapter (empty) mapping to index 0 diff --git a/router/src/lib.rs 
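`SimpleAllocator` preserves the previous behaviour of the block allocator task: take whole blocks from the tail of the free list and expand them into per-token slots, repeating the block list when a sliding window is configured. A rough Python rendering of that arithmetic, as a sketch only (the router runs the Rust code above):

from math import ceil
from typing import List, Optional, Tuple

def allocate_slots(
    tokens: int, block_size: int, free_blocks: List[int], window_size: Optional[int] = None
) -> Optional[Tuple[List[int], List[int]]]:
    # With a sliding window only `window_size` tokens need distinct blocks;
    # the block list is repeated to cover the rest of the sequence.
    repeats = 1 if window_size is None else ceil(tokens / window_size)
    needed = ceil(min(tokens, window_size or tokens) / block_size)
    if needed > len(free_blocks):
        return None
    blocks = free_blocks[-needed:]
    del free_blocks[-needed:]
    slots: List[int] = []
    for block in blocks * repeats:
        for s in range(block * block_size, (block + 1) * block_size):
            slots.append(s)
            if len(slots) == tokens:
                return blocks, slots
    return blocks, slots

blocks, slots = allocate_slots(tokens=10, block_size=4, free_blocks=list(range(1, 8)))
print(blocks, len(slots))  # [5, 6, 7] 10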
b/router/src/lib.rs index 8ae5ac4a1..5e9bb670d 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -6,6 +6,7 @@ mod health; mod infer; mod loader; mod queue; +mod radix; mod scheduler; pub mod server; diff --git a/router/src/main.rs b/router/src/main.rs index 9282998e8..8568349e4 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -90,6 +90,8 @@ struct Args { adapter_source: String, #[clap(long, env)] eager_prefill: bool, + #[clap(long, env)] + prefix_caching: bool, } #[tokio::main] @@ -129,6 +131,7 @@ async fn main() -> Result<(), RouterError> { ngrok_edge, adapter_source, eager_prefill, + prefix_caching, } = args; init_logging(otlp_endpoint, json_output); @@ -463,6 +466,7 @@ async fn main() -> Result<(), RouterError> { adapter_source, embedding_model, eager_prefill, + prefix_caching, ) .await?; Ok(()) diff --git a/router/src/radix.rs b/router/src/radix.rs new file mode 100644 index 000000000..0464b9f8c --- /dev/null +++ b/router/src/radix.rs @@ -0,0 +1,755 @@ +use std::{ + collections::{BTreeSet, HashMap}, + sync::Arc, +}; + +use slotmap::{DefaultKey, SlotMap}; + +use crate::block_allocator::{Allocator, BlockAllocation}; + +pub struct RadixAllocator { + allocation_id: u64, + + allocations: HashMap, + + cache_blocks: RadixTrie, + + /// Blocks that are immediately available for allocation. + free_blocks: Vec, +} + +impl RadixAllocator { + pub fn new(block_size: u32, n_blocks: u32, window_size: Option) -> Self { + assert_eq!( + block_size, 1, + "Radix tree allocator only works with block_size=1, was: {}", + block_size + ); + if window_size.is_some() { + unimplemented!("Window size not supported in the prefix-caching block allocator yet"); + } + + RadixAllocator { + allocation_id: 0, + allocations: HashMap::new(), + cache_blocks: RadixTrie::new(), + + // Block 0 is reserved for health checks. + free_blocks: (1..n_blocks).collect(), + } + } + + fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option> { + if self.free_blocks.len() < n_blocks_needed { + // This is a bit annoying, we first extend the free list and then + // split it off again below. This is because we need to put it on + // the free list if we cannot allocate enough blocks. This is only + // temporary, the trie needs to be able to report whether it can + // allocate the requested amount. Just not implemented yet. + self.free_blocks.extend( + self.cache_blocks + .evict(n_blocks_needed - self.free_blocks.len()), + ); + } + + if self.free_blocks.len() >= n_blocks_needed { + Some( + self.free_blocks + .split_off(self.free_blocks.len() - n_blocks_needed), + ) + } else { + None + } + } +} + +impl Allocator for RadixAllocator { + fn allocate( + &mut self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option { + let mut blocks = vec![]; + let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() { + let node_id = self + .cache_blocks + .find(prefill_tokens.as_slice(), &mut blocks); + // Even if this allocation fails below, we need to increase he + // refcount to ensure that the prefix that was found is not evicted. 
+
+            node_id
+        } else {
+            self.cache_blocks.root_id()
+        };
+
+        self.cache_blocks
+            .incref(prefix_node)
+            .expect("Failed to increment refcount");
+
+        let prefix_len = blocks.len();
+        let suffix_len = tokens - prefix_len as u32;
+
+        match self.alloc_or_reclaim(suffix_len as usize) {
+            Some(suffix_blocks) => blocks.extend(suffix_blocks),
+            None => {
+                self.cache_blocks
+                    .decref(prefix_node)
+                    .expect("Failed to decrement refcount");
+                return None;
+            }
+        }
+
+        // 1:1 mapping of blocks and slots.
+        let slots = blocks.clone();
+
+        let allocation = RadixAllocation {
+            prefix_node,
+            cached_prefix_len: prefix_len,
+            prefill_tokens: prefill_tokens.clone(),
+        };
+
+        self.allocation_id += 1;
+        self.allocations.insert(self.allocation_id, allocation);
+
+        Some(BlockAllocation {
+            allocation_id: self.allocation_id,
+            block_allocator: None,
+            blocks,
+            slots,
+            prefix_len: prefix_len as u32,
+        })
+    }
+
+    fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
+        let allocation = match self.allocations.remove(&allocation_id) {
+            Some(allocation) => allocation,
+            None => unreachable!("Tried to free an unknown allocation."),
+        };
+
+        self.cache_blocks
+            .decref(allocation.prefix_node)
+            .expect("Failed to decrement refcount");
+
+        if let Some(prefill_tokens) = allocation.prefill_tokens {
+            let prefill_tokens = prefill_tokens.as_slice();
+
+            // If there are prefill tokens that did not come from the cache,
+            // add them to the cache.
+            if prefill_tokens.len() > allocation.cached_prefix_len {
+                let prefix_len = self
+                    .cache_blocks
+                    .insert(prefill_tokens, &blocks[..prefill_tokens.len()])
+                    // Unwrap, failing is a programming error.
+                    .expect("Failed to store prefill tokens");
+
+                // We can have a prefill with the following structure:
+                //
+                // |---| From the prefix cache.
+                // A B C D E F G
+                //|--------| Found in the trie during insertion.
+                //
+                // This means that while processing this request there was a
+                // partially overlapping request that had A..=E in its
+                // prefill. In this case we need to free the blocks D E.
+                self.free_blocks
+                    .extend(&blocks[allocation.cached_prefix_len..prefix_len]);
+            }
+
+            // Free non-prefill blocks.
+            self.free_blocks.extend(&blocks[prefill_tokens.len()..]);
+        } else {
+            self.free_blocks.extend(blocks);
+        }
+    }
+}
+
+struct RadixAllocation {
+    prefix_node: NodeId,
+    cached_prefix_len: usize,
+    prefill_tokens: Option<Arc<Vec<u32>>>,
+}
+
+// Radix trie that is heavily inspired by radix attention from sglang.
+//
+// The trie is optimized for prefix caching:
+//
+// - A normal radix trie stores discrete values. In this radix trie,
+//   inserting *abc* with value *xyz* will also enable lookup for
+//   *a* (*x*) and *ab* (*xy*).
+// - As a result, every value is required to have the same length as
+//   the key.
+// - We store additional information in each node, such as last access
+//   time and a reference count.
+
+#[derive(Debug)]
+pub enum TrieError {
+    InvalidNodeId,
+    RefCountUnderflow,
+    BlockTokenCountMismatch,
+}
+
+pub type NodeId = DefaultKey;
+
+#[derive(Debug)]
+pub struct RadixTrie {
+    /// Identifier of the root node.
+    root: DefaultKey,
+
+    /// Leaf node identifiers ordered by increasing recency.
+    leaves: BTreeSet<(u64, NodeId)>,
+
+    /// All trie nodes.
+    nodes: SlotMap<NodeId, TrieNode>,
+
+    /// Time as a monotonically increasing counter to avoid the system
+    /// call that a real time lookup would require.
+    time: u64,
+}
+
+impl RadixTrie {
+    /// Construct a new radix trie.
+    pub fn new() -> Self {
+        let root = TrieNode::new(vec![], vec![], 0, None);
+        let mut nodes = SlotMap::new();
+        let root = nodes.insert(root);
+        RadixTrie {
+            leaves: BTreeSet::new(),
+            nodes,
+            root,
+            time: 0,
+        }
+    }
+
+    /// Find the prefix of the given tokens.
+    ///
+    /// The blocks corresponding to the part of the prefix that could be found
+    /// are written to `blocks`. The number of blocks is in `0..=tokens.len()`.
+    /// Returns the identifier of the trie node that contains the longest
+    /// prefix. The node identifier can be used by callers to e.g. increase its
+    /// reference count.
+    ///
+    /// Using this method will update the access time of the traversed nodes.
+    pub fn find(&mut self, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
+        self.time += 1;
+        self.find_(self.root, key, blocks)
+    }
+
+    /// Find worker.
+    fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
+        let node = &self.nodes[node_id];
+
+        if let Some(&child_id) = node.children.get(&key[0]) {
+            self.update_access_time(child_id);
+            let child = self.nodes.get(child_id).expect("Invalid child identifier");
+            let shared_prefix_len = child.key.shared_prefix_len(key);
+            blocks.extend(&child.blocks[..shared_prefix_len]);
+
+            let key = &key[shared_prefix_len..];
+            if !key.is_empty() {
+                node_id = self.find_(child_id, key, blocks);
+            }
+        }
+
+        node_id
+    }
+
+    /// Decrease the reference count of a node.
+    pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
+        // We don't care about refcounting for root, since it will never
+        // be evicted.
+        if node_id == self.root {
+            return Ok(());
+        }
+
+        let node = self
+            .nodes
+            .get_mut(node_id)
+            .ok_or(TrieError::InvalidNodeId)?;
+        if node.ref_count == 0 {
+            return Err(TrieError::RefCountUnderflow);
+        }
+
+        node.ref_count -= 1;
+        if node.ref_count == 0 {
+            self.leaves.insert((node.last_accessed, node_id));
+        }
+
+        Ok(())
+    }
+
+    /// Increase the reference count of a node.
+    pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
+        if node_id == self.root {
+            return Ok(());
+        }
+
+        let node = self
+            .nodes
+            .get_mut(node_id)
+            .ok_or(TrieError::InvalidNodeId)?;
+        if node.ref_count == 0 {
+            self.leaves.remove(&(node.last_accessed, node_id));
+        }
+        node.ref_count += 1;
+
+        Ok(())
+    }
+
+    /// Evict `n_blocks` from the trie.
+    ///
+    /// Returns the evicted blocks. When the length is less than `n_blocks`,
+    /// not enough blocks could be evicted.
+    pub fn evict(&mut self, n_blocks: usize) -> Vec<u32> {
+        // NOTE: we don't return Result here. If any of the unwrapping fails,
+        // it's a programming error in the trie implementation, not a user
+        // error caused by e.g. an invalid argument.
+
+        // TODO: add some bookkeeping in the future to check whether we can
+        // evict n_blocks and return `None` if we can't. We are now needlessly
+        // evicting prefixes from the cache in such a case.
+        let mut evicted = Vec::new();
+
+        while let Some((last_access, node_id)) = self.leaves.pop_first() {
+            let blocks_needed = n_blocks - evicted.len();
+
+            let node = self.nodes.get(node_id).expect("Leaf does not exist");
+            if blocks_needed >= node.blocks.len() {
+                // We need to evict the whole node if we need more blocks than it has.
+                let node = self.remove_node(node_id);
+                evicted.extend(node.blocks);
+
+                if evicted.len() >= n_blocks {
+                    break;
+                }
+            } else {
+                // The node has more blocks than needed, so we'll just remove
+                // the required number of blocks and leave the remaining blocks
+                // untouched.
+ let node = self.nodes.get_mut(node_id).expect("Leave does not exist"); + node.key.truncate(node.blocks.len() - blocks_needed); + evicted.extend(node.blocks.split_off(node.blocks.len() - blocks_needed)); + self.leaves.insert((last_access, node_id)); + break; + } + } + + evicted + } + + /// Insert a prefill along with its blocks. + /// + /// This method returns the length of the prefix that was already + /// in the trie. E.g. if the length is 10, this means that for + /// the first 10 elements of the tree **the blocks are not updated**. + pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result { + self.time += 1; + self.insert_(self.root, tokens, blocks) + } + + /// Insertion worker. + fn insert_( + &mut self, + node_id: NodeId, + tokens: &[u32], + blocks: &[u32], + ) -> Result { + // TODO: in the future we may want to check that the blocks match for + // the part of the prefix that is already in the trie to detect + // mismatches. + + if tokens.len() != blocks.len() { + return Err(TrieError::BlockTokenCountMismatch); + } + + if let Some(&child_id) = self.nodes[node_id].children.get(&tokens[0]) { + self.update_access_time(child_id); + let child = self + .nodes + .get_mut(child_id) + // Unwrap here, since failure is a bug. + .expect("Child node does not exist"); + let shared_prefix_len = child.key.shared_prefix_len(tokens); + + // We are done, the prefix is already in the trie. + if shared_prefix_len == tokens.len() { + return Ok(shared_prefix_len); + } + + // The node's prefix is a prefix of the insertion prefix. + if shared_prefix_len == child.key.len() { + return Ok(shared_prefix_len + + self.insert_( + child_id, + &tokens[shared_prefix_len..], + &blocks[shared_prefix_len..], + )?); + } + + // The node's prefix and the insertion prefix only match partially, + // split the node to just contain the matching part. Then insert the + // remainder of the prefix into the node again + let child_id = self.split_node(child_id, shared_prefix_len); + let key = &tokens[shared_prefix_len..]; + let blocks = &blocks[shared_prefix_len..]; + Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?) + } else { + self.add_node(node_id, tokens, blocks); + Ok(0) + } + } + + fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId { + // We have to make the current node a child to ensure that its + // properties and node id stay the same. + + // This funcion unwraps, an invalid node_id is a programming error. + + let node = self + .nodes + .get_mut(node_id) + .expect("Node to-be split does not exist"); + let mut parent_key = node.key.split_off(prefix_len); + let mut parent_blocks = node.blocks.split_off(prefix_len); + + // Move first part of the prefix to the parent. We swap to avoid + // an allocation + copy for both splits of the key/blocks. + std::mem::swap(&mut node.key, &mut parent_key); + std::mem::swap(&mut node.blocks, &mut parent_blocks); + + let node_key = node.key[0]; + + let grandparent_id = node.parent.expect("Node does not have a parent"); + let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks); + self.add_node_to_parent(parent_id, node_key, node_id); + + // Reborrow to make the borrow checker happy. + let node = self + .nodes + .get_mut(node_id) + .expect("Node to-be split does not exist"); + node.parent = Some(parent_id); + + parent_id + } + + /// Create a node and add it to the parent. 
+ fn add_node( + &mut self, + parent_id: NodeId, + key: impl Into>, + blocks: impl Into>, + ) -> NodeId { + let key = key.into(); + let blocks = blocks.into(); + let first = key[0]; + + let child = TrieNode::new(key, blocks, self.time, Some(parent_id)); + let child_id = self.nodes.insert(child); + + self.add_node_to_parent(parent_id, first, child_id); + self.leaves.insert((self.time, child_id)); + + child_id + } + + /// Add a node to the parent. + fn add_node_to_parent(&mut self, parent_id: NodeId, first: u32, child_id: NodeId) { + // Unwrap here, passing in an unknown id is a programming error. + let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); + if parent.children.insert(first, child_id).is_none() { + // Only increase reference count if child does not replace another child. + self.incref(parent_id) + .expect("Failed to increase parent refcount"); + } + } + + /// Remove a node from the trie. + fn remove_node(&mut self, node_id: NodeId) -> TrieNode { + // Unwrap here, passing in an unknown id is a programming error. + let node = self.nodes.remove(node_id).expect("Unknown node"); + let parent_id = node.parent.expect("Attempted to remove root node"); + let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); + parent.children.remove(&node.key[0]); + self.decref(parent_id) + .expect("Failed to decrease parent refcount"); + self.nodes.remove(node_id); + node + } + + fn update_access_time(&mut self, node_id: NodeId) { + // Unwrap here, passing in an unknown id is a programming error. + let node = self.nodes.get_mut(node_id).expect("Unknown node"); + + // Update the ordered leaves set if the node is a leave. + if self.leaves.remove(&(node.last_accessed, node_id)) { + self.leaves.insert((self.time, node_id)); + } + + node.last_accessed = self.time; + } + + #[allow(dead_code)] + #[doc(hidden)] + /// Print debugging output for the trie. + /// + /// In contrast to `Debug` nicely formatted. + pub fn print_debug(&self) { + self.print_debug_(self.root, 0); + } + + fn print_debug_(&self, node_id: NodeId, indent: usize) { + let node = &self.nodes[node_id]; + eprintln!( + "{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}", + " ".repeat(indent), + node_id, + node.key, + node.blocks, + node.ref_count, + node.last_accessed, + node.parent, + node.children + ); + for child_id in self.nodes[node_id].children.values() { + self.print_debug_(*child_id, indent + 2); + } + } + + pub(crate) fn root_id(&self) -> DefaultKey { + self.root + } +} + +/// Trie node. +#[derive(Debug)] +struct TrieNode { + blocks: Vec, + children: HashMap, + key: Vec, + last_accessed: u64, + parent: Option, + ref_count: usize, +} + +impl TrieNode { + fn new(key: Vec, blocks: Vec, last_accessed: u64, parent: Option) -> Self { + TrieNode { + children: HashMap::new(), + key, + blocks, + last_accessed, + parent, + ref_count: 0, + } + } +} + +/// Helper trait to get the length of the shared prefix of two sequences. 
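The shared prefix length is the quantity the trie operations hinge on: `find` uses it to decide how many cached blocks can be reused, and `insert` uses it to decide whether a node is fully contained, extended, or has to be split. For intuition, the same computation in Python (a sketch only, not code from the router):

from itertools import takewhile
from typing import Sequence

def shared_prefix_len(a: Sequence[int], b: Sequence[int]) -> int:
    # Mirrors the SharedPrefixLen helper below: count leading positions
    # where both sequences agree.
    return sum(1 for _ in takewhile(lambda pair: pair[0] == pair[1], zip(a, b)))

assert shared_prefix_len([0, 1, 2, 3], [0, 1, 2, 3]) == 4  # full match, nothing new to insert
assert shared_prefix_len([0, 1, 2, 3], [0, 1, 9, 9]) == 2  # partial match, node is split at 2
assert shared_prefix_len([0, 1], [5, 6]) == 0              # no overlap, new child of the root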
+trait SharedPrefixLen { + fn shared_prefix_len(&self, other: &Self) -> usize; +} + +impl SharedPrefixLen for [T] +where + T: PartialEq, +{ + fn shared_prefix_len(&self, other: &Self) -> usize { + self.iter().zip(other).take_while(|(a, b)| a == b).count() + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::block_allocator::Allocator; + + use super::RadixAllocator; + + #[test] + fn allocator_reuses_prefixes() { + let mut cache = RadixAllocator::new(1, 12, None); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation.slots, allocation.slots); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation.prefix_len, 4); + } + + #[test] + fn allocator_collects_older_prefixes_first() { + let mut cache = RadixAllocator::new(1, 7, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]); + assert_eq!(allocation1.prefix_len, 0); + + let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap(); + assert_eq!(allocation2.blocks, vec![1, 2]); + assert_eq!(allocation2.prefix_len, 0); + + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + + // We should get the blocks of the first allocation, since they are more recent. + let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap(); + assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]); + assert_eq!(allocation3.prefix_len, 0); + } + + #[test] + fn allocator_frees_fully_overlapping_prefills() { + let mut cache = RadixAllocator::new(1, 10, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + + let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation3.prefix_len, 4); + + // 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks. + assert_eq!(cache.free_blocks.len(), 5); + } + + #[test] + fn allocator_frees_partially_overlapping_prefills() { + let mut cache = RadixAllocator::new(1, 20, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap(); + assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]); + assert_eq!(allocation1.prefix_len, 0); + + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + + let allocation2 = cache + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) + .unwrap(); + assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]); + assert_eq!(allocation2.prefix_len, 2); + + let allocation3 = cache + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) + .unwrap(); + assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation3.prefix_len, 2); + + cache.free(allocation3.blocks.clone(), allocation3.allocation_id); + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + + // 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2. 
+ assert_eq!(cache.free_blocks.len(), 11); + + let allocation4 = cache + .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) + .unwrap(); + assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]); + assert_eq!(allocation4.prefix_len, 6); + assert_eq!(cache.free_blocks.len(), 11); + + let allocation5 = cache + .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) + .unwrap(); + assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]); + assert_eq!(allocation5.prefix_len, 6); + assert_eq!(cache.free_blocks.len(), 11); + } + + #[test] + fn trie_insertions_have_correct_prefix_len() { + let mut trie = super::RadixTrie::new(); + + assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0); + + // Already exists. + assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3); + + // Completely new at root-level + assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0); + + // Contains full prefix, but longer. + assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3); + + // Shares partial prefix, we need a split. + assert_eq!( + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(), + 4 + ); + } + + #[test] + fn trie_get_returns_correct_blocks() { + let mut trie = super::RadixTrie::new(); + trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); + trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(); + trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(); + + let mut blocks = Vec::new(); + trie.find(&[0], &mut blocks); + assert_eq!(blocks, vec![0]); + + blocks.clear(); + trie.find(&[0, 1, 2], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2]); + + blocks.clear(); + trie.find(&[1, 2, 3], &mut blocks); + assert_eq!(blocks, vec![1, 2, 3]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 4]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3, 5], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 5]); + } + + #[test] + fn trie_evict_removes_correct_blocks() { + let mut trie = super::RadixTrie::new(); + trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(); + trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); + trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(); + + let mut blocks = Vec::new(); + + // Remove less than the leave blocks. + assert_eq!(trie.evict(1), vec![7]); + trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]); + + // Refresh other leaf. + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + trie.find(&[1, 2, 3], &mut blocks); + + // Remove the leave blocks exactly. + assert_eq!(trie.evict(2), vec![5, 6]); + blocks.clear(); + trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3]); + + trie.find(&[1, 2, 3], &mut blocks); + + // Remove more than the leave blocks. + assert_eq!(trie.evict(3), vec![4, 3, 2]); + blocks.clear(); + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + assert_eq!(blocks, vec![0, 1]); + + // Clear out the whole trie. 
+ assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]); + } +} diff --git a/router/src/scheduler.rs b/router/src/scheduler.rs index 4569e99f3..815fdda8e 100644 --- a/router/src/scheduler.rs +++ b/router/src/scheduler.rs @@ -45,6 +45,7 @@ impl AdapterScheduler { adapter_cycle_time_s: u64, speculate: u32, max_batch_total_tokens: u32, + prefix_caching: bool, ) -> Self { let (sender, receiver) = flume::unbounded(); @@ -60,6 +61,7 @@ impl AdapterScheduler { adapter_cycle_time_s, speculate, max_batch_total_tokens, + prefix_caching, )); Self { sender } @@ -123,6 +125,7 @@ async fn adapter_scheduler_task( adapter_cycle_time_s: u64, speculate: u32, max_batch_total_tokens: u32, + prefix_caching: bool, ) { let mut state = AdapterSchedulerState::new( client, @@ -133,6 +136,7 @@ async fn adapter_scheduler_task( adapter_cycle_time_s, speculate, max_batch_total_tokens, + prefix_caching, ); while let Ok(cmd) = receiver.recv_async().await { @@ -204,6 +208,7 @@ impl AdapterSchedulerState { adapter_cycle_time_s: u64, speculate: u32, max_batch_total_tokens: u32, + prefix_caching: bool, ) -> Self { let queues_state = Arc::new(Mutex::new(AdapterQueuesState::new( max_active_adapters, @@ -211,8 +216,14 @@ impl AdapterSchedulerState { ))); let loader = AdapterLoader::new(client.clone()); - let block_allocator = (!requires_padding) - .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size)); + let block_allocator = (!requires_padding).then(|| { + BlockAllocator::new( + max_batch_total_tokens, + block_size, + prefix_caching, + window_size, + ) + }); Self { queues_state, @@ -373,7 +384,10 @@ impl AdapterSchedulerState { self.speculate ); - match block_allocator.allocate(tokens).await { + match block_allocator + .allocate(tokens, entry.request.input_ids()) + .await + { None => { // Entry is over budget // Add it back to the front @@ -418,11 +432,12 @@ impl AdapterSchedulerState { // Update entry entry.temp_span = Some(entry_batch_span); - let (blocks, slots) = match &block_allocation { - None => (Vec::new(), Vec::new()), + let (blocks, slots, prefix_len) = match &block_allocation { + None => (Vec::new(), Vec::new(), 0), Some(block_allocation) => ( block_allocation.blocks.clone(), block_allocation.slots.clone(), + block_allocation.prefix_len, ), }; @@ -431,7 +446,7 @@ impl AdapterSchedulerState { batch_entries .as_mut() .unwrap() - .add(id, entry, adapter, blocks, slots); + .add(id, entry, adapter, blocks, slots, prefix_len); } if batch_entries.is_none() { diff --git a/router/src/server.rs b/router/src/server.rs index a9a00af80..49a43ca84 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1014,6 +1014,7 @@ pub async fn run( adapter_source: String, embedding_model: bool, eager_prefill: bool, + prefix_caching: bool, ) -> Result<(), axum::BoxError> { // OpenAPI documentation #[derive(OpenApi)] @@ -1115,6 +1116,7 @@ pub async fn run( shard_info.block_size, shard_info.speculate, shard_info.preloaded_adapters, + prefix_caching, ); // Duration buckets diff --git a/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py b/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py index 6f39c42b0..36c2e2aa3 100644 --- a/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py @@ -286,6 +286,8 @@ def forward( query, key, value, + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py 
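On the scheduling side, the allocator now receives the request's token ids so it can look them up in the radix cache, and the resulting `prefix_len` travels alongside the blocks and slots into the proto `Request`. A condensed sketch of that hand-off with illustrative names (none of these types are the router's):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Allocation:
    # Conceptual stand-in for BlockAllocation.
    blocks: List[int]
    slots: List[int]
    prefix_len: int

def request_fields(allocation: Optional[Allocation]) -> dict:
    if allocation is None:
        # Padded (non paged-attention) models: no blocks and no cached prefix.
        return {"blocks": [], "slots": [], "prefix_len": 0}
    return {
        "blocks": allocation.blocks,
        "slots": allocation.slots,
        # Number of leading tokens whose KV can be read from the cache.
        "prefix_len": allocation.prefix_len,
    }

alloc = Allocation(blocks=[4, 5, 6, 7], slots=[4, 5, 6, 7], prefix_len=2)
print(request_fields(alloc)["prefix_len"])  # 2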
b/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py index dd2e371f9..8880a3d5a 100644 --- a/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py @@ -423,6 +423,8 @@ def forward( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py index c67ca89f2..d0d8285c4 100644 --- a/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py @@ -265,6 +265,8 @@ def forward( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py index 15cd944fa..239c00bb4 100644 --- a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py @@ -271,6 +271,8 @@ def forward( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py index 760086f1d..f76ed28d7 100644 --- a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py @@ -164,6 +164,8 @@ def forward( qkv[:, 0], qkv[:, 1], qkv[:, 2], + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py index b84b9a0a5..dc1f9face 100644 --- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py @@ -327,6 +327,8 @@ def forward( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index a14e64938..b262a3abf 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -341,6 +341,8 @@ def forward( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py index bfcc2e71c..7fbeff767 100644 --- a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py @@ -408,6 +408,8 @@ def forward( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git 
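The modeling files all receive the same mechanical change: the prefill attention call is handed the paged KV cache tensors (by convention `kv_cache[0]` holds keys and `kv_cache[1]` values) in addition to the freshly computed query, key, and value. With a cached prefix, the prefill kernel must read the prefix keys and values from the cache rather than recompute them. A plain-PyTorch, single-head sketch of that idea, independent of the actual kernels:

import torch

head_dim, prefix_len, suffix_len = 64, 4, 2
k_cache = torch.randn(prefix_len, head_dim)  # stands in for kv_cache[0]
v_cache = torch.randn(prefix_len, head_dim)  # stands in for kv_cache[1]
q_new = torch.randn(suffix_len, head_dim)    # queries for the uncached tokens
k_new = torch.randn(suffix_len, head_dim)
v_new = torch.randn(suffix_len, head_dim)

k = torch.cat([k_cache, k_new])              # (prefix + suffix, head_dim)
v = torch.cat([v_cache, v_new])
scores = q_new @ k.T / head_dim ** 0.5       # (suffix, prefix + suffix)
# Causal mask: suffix token i may attend to all prefix tokens and to
# suffix tokens 0..=i.
mask = torch.arange(prefix_len + suffix_len) > (prefix_len + torch.arange(suffix_len)).unsqueeze(1)
out = scores.masked_fill(mask, float("-inf")).softmax(dim=-1) @ v
print(out.shape)  # torch.Size([2, 64])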
a/server/lorax_server/models/custom_modeling/flash_neox_modeling.py b/server/lorax_server/models/custom_modeling/flash_neox_modeling.py index 337ba90c6..395025849 100644 --- a/server/lorax_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_neox_modeling.py @@ -140,6 +140,8 @@ def forward( qkv[:, 0], qkv[:, 1], qkv[:, 2], + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py index e6a01765a..3a8bc2f73 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py @@ -280,6 +280,8 @@ def forward( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py index 8e5f0abe1..aeaf389f7 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py @@ -166,6 +166,8 @@ def forward( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py index 1e97759d2..97651434c 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py @@ -239,6 +239,8 @@ def forward( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py index 232ca99c3..9d564c287 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py @@ -253,6 +253,8 @@ def forward( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py index a44967562..5558c2f4d 100644 --- a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py @@ -180,6 +180,8 @@ def forward( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, @@ -290,6 +292,8 @@ def forward( query, torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py b/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py index 277b4df10..42f6e65e5 100644 --- a/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py @@ -240,6 +240,8 @@ def 
forward( query, torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 9456605a3..68c2a1fc1 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -25,12 +25,12 @@ from lorax_server.pb import generate_pb2 from lorax_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID -from lorax_server.utils.constants import BLOCK_SIZE +from lorax_server.utils.attention.utils import block_tables_to_ragged from lorax_server.utils.dist import MEMORY_FRACTION from lorax_server.utils.graph import GraphCache from lorax_server.utils.segments import SegmentConcatBuilder, find_segments from lorax_server.utils.sources import HUB -from lorax_server.utils.state import FLASH_INFER, get_speculative_tokens, warmup_mode +from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER, PREFIX_CACHING, get_speculative_tokens, warmup_mode from lorax_server.utils.tokenizer import TokenizerManager ADAPTER_MEMORY_FRACTION = float(os.getenv("ADAPTER_MEMORY_FRACTION", "0.1")) @@ -77,6 +77,10 @@ class FlashCausalLMBatch(Batch): # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences slots: torch.Tensor + # size [b], containing the number of blocks that can be retrieved from the cache + prefix_lens: List[int] + prefix_lens_tensor: torch.Tensor + max_seqlen: int # Prefill metadata tensors to efficiently compute logprobs @@ -84,6 +88,9 @@ class FlashCausalLMBatch(Batch): prefill_next_token_indices: Optional[torch.tensor] prefill_cu_outlens: Optional[List[int]] + # Prefixes + prefix_ids: List[List[int]] + # All tokens all_input_ids: List[List[int]] all_input_ids_tensor: torch.Tensor @@ -152,6 +159,7 @@ def from_pb( prefix_offsets = [] read_offsets = [] all_input_ids = [] + prefix_ids = [] requests_idx_mapping = {} all_prefill_logprobs = True @@ -168,7 +176,7 @@ def from_pb( # Cumulative length cumulative_length = 0 - cumulative_max_length = 0 + cumulative_slot_tokens = 0 prefill_out_cumulative_length = 0 num_blocks = 0 @@ -178,6 +186,7 @@ def from_pb( block_tables = [] slots = [] + prefix_lens = [] # Parse batch for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)): @@ -186,6 +195,18 @@ def from_pb( tokenized_input = tokenized_input[-r.truncate :] + orig_input_length = len(tokenized_input) + if PREFIX_CACHING: + prefix_len = r.prefix_len + if prefix_len == orig_input_length: + assert prefix_len > 0 + prefix_len -= 1 + else: + prefix_len = 0 + + prefix_ids.append(tokenized_input[:prefix_len]) + tokenized_input = tokenized_input[prefix_len:] + input_length = len(tokenized_input) input_lengths.append(input_length) @@ -195,7 +216,7 @@ def from_pb( all_input_ids.append(tokenized_input) # Position ids - request_position_ids = torch.arange(0, input_length, dtype=torch.int32) + request_position_ids = torch.arange(prefix_len, orig_input_length, dtype=torch.int32) position_ids.append(request_position_ids) # Add cumulative lengths of all previous inputs @@ -210,29 +231,36 @@ def from_pb( adapter_indices_list.append(torch.full((input_length,), r.adapter_index)) adapter_set.add(r.adapter_index) - # Paged attention - # Remove one as the first token des not have a past speculative_tokens = get_speculative_tokens() - total_tokens = 
input_length + max_new_tokens + speculative_tokens - 1 + + # Tokens that need to be mapped to blocks. + # Remove one as the first token des not have a past + block_tokens = orig_input_length + max_new_tokens - 1 + speculative_tokens + + # Tokens that need to be mapped to slots. We don't need slots for the + # cached prefix (if present). + slot_tokens = input_length + max_new_tokens - 1 + speculative_tokens # blocks and slots can be empty (for example in warmup) if not r.blocks: - needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) + needed_blocks = math.ceil(block_tokens / BLOCK_SIZE) request_blocks = [b for b in range(num_blocks, num_blocks + needed_blocks)] request_slots = [s for b in request_blocks for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)] else: request_blocks = r.blocks - request_slots = r.slots + request_slots = r.slots[ + prefix_len: #: orig_input_length + max_new_tokens + speculative_length + ] block_tables.append(request_blocks) - slots.extend(request_slots[:total_tokens]) + slots.extend(request_slots) + prefix_lens.append(prefix_len) num_blocks += len(request_blocks) - - start_slots.append(cumulative_max_length) + start_slots.append(cumulative_slot_tokens) request_slot_indices = torch.arange( - cumulative_max_length, - cumulative_max_length + input_length, + cumulative_slot_tokens, + cumulative_slot_tokens + input_length, dtype=torch.int64, ) slot_indices.append(request_slot_indices) @@ -262,7 +290,7 @@ def from_pb( # Update cumulative_length += input_length - cumulative_max_length += total_tokens + cumulative_slot_tokens += slot_tokens max_seqlen = max(max_seqlen, input_length) max_blocks = max(max_blocks, len(request_blocks)) max_length = max(max_length, input_length + max_new_tokens + speculative_tokens) @@ -325,6 +353,7 @@ def from_pb( for i, request_blocks in enumerate(block_tables): block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) block_tables_tensor = block_tables_tensor.to(device) + prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device) return cls( batch_id=pb.id, @@ -339,6 +368,8 @@ def from_pb( block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, max_seqlen=max_seqlen, prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, @@ -349,6 +380,7 @@ def from_pb( read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, num_blocks=num_blocks, @@ -389,8 +421,10 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": start_slots = [] block_tables = [] all_input_ids = [] + prefix_ids = [] input_lengths = [] + prefix_lens = [] prefix_offsets = [] read_offsets = [] @@ -411,11 +445,14 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": # Get length request_input_length = self.input_lengths[idx] + prefix_len = self.prefix_lens[idx] max_seqlen = max(max_seqlen, request_input_length) all_input_ids.append(self.all_input_ids[idx]) + prefix_ids.append(self.prefix_ids[idx]) input_lengths.append(request_input_length) + prefix_lens.append(prefix_len) prefix_offsets.append(self.prefix_offsets[idx]) read_offsets.append(self.read_offsets[idx]) @@ -451,6 +488,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": block_tables_tensor = self.block_tables_tensor[indices] input_lengths_tensor = 
self.input_lengths_tensor[indices] slots = self.slots[slot_filtering_indices] + prefix_lens_tensor = self.prefix_lens_tensor[indices] next_token_chooser = self.next_token_chooser.filter(indices) speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None @@ -481,10 +519,13 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": prefill_cu_outlens=None, input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, num_blocks=num_blocks, @@ -535,6 +576,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch slot_indices = batches[0].slot_indices.new_empty(total_batch_size) input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(total_batch_size) block_tables_tensor = batches[0].block_tables_tensor.new_zeros((total_batch_size, max_blocks)) + prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size) all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros((total_batch_size, max_length)) total_indices_size = sum(b.adapter_meta.adapter_indices.shape[0] for b in batches) @@ -545,7 +587,9 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch start_slots = [] block_tables = [] + prefix_lens = [] all_input_ids = [] + prefix_ids = [] input_lengths = [] prefix_offsets = [] @@ -600,10 +644,14 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch batch.block_tables_tensor[:, :max_blocks] ) + prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor + start_slots.append(batch.start_slots + cumulative_slots) block_tables.extend(batch.block_tables) + prefix_lens.extend(batch.prefix_lens) all_input_ids.extend(batch.all_input_ids) + prefix_ids.extend(batch.prefix_ids) input_lengths.extend(batch.input_lengths) prefix_offsets.extend(batch.prefix_offsets) @@ -649,6 +697,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, slots=slots, max_seqlen=max_seqlen, prefill_head_indices=None, @@ -660,6 +710,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, num_blocks=num_blocks, @@ -730,14 +781,17 @@ def __init__( self.kv_cache = [] self.prefill_state = None + self.prefill_with_paged_kv_state = None self.decode_state = None if FLASH_INFER: from lorax_server.utils.flashinfer_attention import ( create_decode_state, create_prefill_state, + create_prefill_with_paged_kv_state, ) self.prefill_state = create_prefill_state(device=device) + self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(device=device) self.decode_state = create_decode_state( device=device, num_heads=self.num_heads, @@ -926,7 +980,10 @@ def _forward_context( *, block_tables: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], - input_lengths: torch.Tensor, + input_lengths: List[int], + input_lengths_tensor: torch.Tensor, + 
prefix_lens: List[int], + prefix_lens_tensor: torch.Tensor, state: Optional[Any] = None, ) -> ContextManager: if not FLASH_INFER: @@ -934,23 +991,32 @@ def _forward_context( from lorax_server.utils.flashinfer_attention import ( use_decode_state, - use_prefill_state, + use_prefill_with_paged_kv_state, ) + # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens) if cu_seqlen_prefill is not None: - return use_prefill_state( - state=state if state is not None else self.prefill_state, + return use_prefill_with_paged_kv_state( + state=(state if state is not None else self.prefill_with_paged_kv_state), + # block_tables=block_tables_to_ragged( + # block_tables=block_tables, + # input_lengths=input_lengths, + # prefix_lens=prefix_lens, + # ), + block_tables=block_tables, cu_seqlens=cu_seqlen_prefill, + input_lengths=input_lengths_tensor, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_size=self.head_size, + page_size=BLOCK_SIZE, ) else: - assert input_lengths is not None + assert input_lengths_tensor is not None return use_decode_state( state=state if state is not None else self.decode_state, - input_lengths=input_lengths, - block_tables=block_tables.view(-1), + input_lengths=input_lengths_tensor, + block_tables=block_tables, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_size=self.head_size, @@ -974,6 +1040,7 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor + prefix_lens_tensor = batch.prefix_lens_tensor max_s = batch.max_seqlen if batch.speculative_ids is not None: @@ -987,6 +1054,7 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1) slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + prefix_lens_tensor = (batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)).reshape(-1) block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous() max_s = max_s + speculative_length @@ -995,16 +1063,40 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> position_ids = new_position_ids # Model Forward - with ( - self._forward_context( + if not use_graph: + # eager mode + input_lengths = input_lengths + prefix_lens_tensor + if PREFIX_CACHING: + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) + + with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=batch.cu_seqlen_prefill, - input_lengths=input_lengths, - ) - if not use_graph - else nullcontext() - ): - logits = model.forward( + input_lengths=batch.input_lengths, + input_lengths_tensor=input_lengths, + prefix_lens=batch.prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, + ): + out = model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=batch.cu_seqlen_prefill, + kv_cache=self.kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + adapter_data=adapter_data, + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=batch.prefill_head_indices, + ) + else: + # CUDA graph mode + out = model.forward( input_ids=input_ids, position_ids=position_ids, 
cu_seqlen_prefill=batch.cu_seqlen_prefill, @@ -1012,14 +1104,18 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> block_tables=block_tables, slots=slots, input_lengths=input_lengths, + prefix_lens=batch.prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, max_s=max_s, adapter_data=adapter_data, prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=batch.prefill_head_indices, ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - return logits + + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + + return out @tracer.start_as_current_span("generate_token") def generate_token( @@ -1200,6 +1296,7 @@ def generate_token( batch.read_offsets, batch.stopping_criterias, batch.all_input_ids, + batch.prefix_ids, batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, accepted_ids, @@ -1214,6 +1311,7 @@ def generate_token( read_offset, stopping_criteria, all_input_ids, + prefix_ids, do_sample, seed, num_accepted_ids, @@ -1302,10 +1400,12 @@ def generate_token( out_end_index = batch.prefill_cu_outlens[i + 1] # Remove generated token to only have prefill and add nan for first prompt token - request_prefill_logprobs = [float("nan")] + prefill_logprobs[out_start_index : out_end_index - 1] + request_prefill_logprobs = ([float("nan")] * (len(prefix_ids) + 1)) + prefill_logprobs[ + out_start_index : out_end_index - 1 + ] prefill_token_ids = all_input_ids[:-1] prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, + prefix_ids + prefill_token_ids, clean_up_tokenization_spaces=False, skip_special_tokens=False, ) diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index c5c669292..f906602d7 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -16,9 +16,8 @@ BASE_MODEL_ADAPTER_ID, load_and_merge_adapters, ) -from lorax_server.utils.constants import BLOCK_SIZE from lorax_server.utils.sources import HUB -from lorax_server.utils.state import get_speculative_tokens +from lorax_server.utils.state import BLOCK_SIZE, get_speculative_tokens from lorax_server.utils.tokenizer import TokenizerManager from lorax_server.utils.weights import shard_on_dim diff --git a/server/lorax_server/utils/attention/__init__.py b/server/lorax_server/utils/attention/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/lorax_server/utils/attention/utils.py b/server/lorax_server/utils/attention/utils.py new file mode 100644 index 000000000..d2c767246 --- /dev/null +++ b/server/lorax_server/utils/attention/utils.py @@ -0,0 +1,21 @@ +from typing import List + +import torch + + +def block_tables_to_ragged( + *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int] +) -> torch.Tensor: + """Convert block table to ragged format compatible with FlashInfer.""" + assert len(input_lengths) == len(prefix_lens) + + total_len = sum(input_lengths) + sum(prefix_lens) + block_tables_ragged = torch.empty(total_len, dtype=torch.int32, device=block_tables.device) + + offset = 0 + for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)): + seq_len = prefix_len + input_length + block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] + offset += seq_len + + return block_tables_ragged diff --git a/server/lorax_server/utils/constants.py b/server/lorax_server/utils/constants.py deleted file mode 100644 index 582bd1b66..000000000 --- 
a/server/lorax_server/utils/constants.py +++ /dev/null @@ -1 +0,0 @@ -BLOCK_SIZE: int = 16 diff --git a/server/lorax_server/utils/flash_attn.py b/server/lorax_server/utils/flash_attn.py index c51f2f2e1..7a7849591 100644 --- a/server/lorax_server/utils/flash_attn.py +++ b/server/lorax_server/utils/flash_attn.py @@ -115,9 +115,11 @@ def attention( if FLASH_INFER: def attention( - q, - k, - v, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cu_seqlens, max_s, softmax_scale, @@ -125,14 +127,15 @@ def attention( causal=True, softcap=0.0, ): - from lorax_server.utils.flashinfer_attention import prefill_state + assert window_size_left == -1, "Windowing is not supported with flash infer" + from lorax_server.utils.flashinfer_attention import ( + prefill_with_paged_kv_state, + ) - return prefill_state.get().forward( - q, - k, - v, + return prefill_with_paged_kv_state.get().forward( + q.contiguous(), causal=causal, - window_left=window_size_left, + paged_kv_cache=(key_cache, value_cache), logits_soft_cap=softcap, sm_scale=softmax_scale, ) @@ -140,9 +143,11 @@ def attention( elif HAS_FLASH_ATTN_V2_CUDA: def attention( - q, - k, - v, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cu_seqlens, max_s, softmax_scale, @@ -179,9 +184,11 @@ def attention( elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK: def attention( - q, - k, - v, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cu_seqlens, max_s, softmax_scale, @@ -218,9 +225,11 @@ def attention( elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON: def attention( - q, - k, - v, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cu_seqlens, max_s, softmax_scale, @@ -246,9 +255,11 @@ def attention( elif HAS_FLASH_ATTN: def attention( - q, - k, - v, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cu_seqlens, max_s, softmax_scale, diff --git a/server/lorax_server/utils/flashinfer_attention.py b/server/lorax_server/utils/flashinfer_attention.py index 0dd0abcda..1ffea21a7 100644 --- a/server/lorax_server/utils/flashinfer_attention.py +++ b/server/lorax_server/utils/flashinfer_attention.py @@ -7,6 +7,10 @@ prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar("prefill_state") +prefill_with_paged_kv_state: ContextVar[flashinfer.BatchPrefillWithPagedKVCacheWrapper] = ContextVar( + "prefill_with_paged_kv_state" +) + decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar("decode_state") workspace: Optional[torch.Tensor] = None @@ -20,6 +24,70 @@ def get_workspace(device): return workspace +def create_prefill_with_paged_kv_state( + *, + device: torch.device, +): + """Create a prefill state that uses the KV cache.""" + workspace_buffer = get_workspace(device) + return flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, kv_layout="NHD", use_cuda_graph=False) + + +@contextmanager +def use_prefill_with_paged_kv_state( + *, + state: flashinfer.BatchPrefillWithPagedKVCacheWrapper, + block_tables: torch.Tensor, + cu_seqlens: torch.Tensor, + input_lengths: torch.Tensor, + num_heads: int, + num_kv_heads: int, + head_size: int, + page_size: int, + query_dtype: str = "float16", +): + """ + Context manager to set the active flashinfer prefill state to the given + 
`state` and parameters. This state will be used by all calls to the + `attention` function while the context manager is active. + """ + + indptr = torch.zeros(input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32) + # Round up to page size and then calculate the cumulative sum to get + # the indices into the block table. + torch.add(input_lengths, page_size - 1, out=indptr[1:]) + indptr[1:].div_(page_size, rounding_mode="floor") + indptr[1:].cumsum_(-1) + + # Get the lengths of the last page in a block. + if page_size == 1: + last_page_len = torch.ones(input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device) + else: + last_page_len = torch.empty(input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device) + torch.sub(input_lengths, 1, out=last_page_len) + last_page_len.remainder_(page_size) + last_page_len += 1 + + token = prefill_with_paged_kv_state.set(state) + try: + state.begin_forward( + qo_indptr=cu_seqlens, + paged_kv_indptr=indptr, + paged_kv_indices=block_tables, + paged_kv_last_page_len=last_page_len, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_size, + q_data_type=query_dtype, + page_size=page_size, + ) + yield + finally: + state.end_forward() + if token is not None: + prefill_with_paged_kv_state.reset(token) + + def create_prefill_state( *, device: torch.device, diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index af835a26e..3a6b42c63 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -15,10 +15,10 @@ from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata from lorax_server.adapters.lora import BatchLoraWeights, RankSegments from lorax_server.adapters.types import LORA -from lorax_server.utils.constants import BLOCK_SIZE +from lorax_server.utils.attention.utils import block_tables_to_ragged from lorax_server.utils.lora import LM_HEAD from lorax_server.utils.sgmv import BGMV_MAX_RANK -from lorax_server.utils.state import FLASH_INFER +from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER if TYPE_CHECKING: from lorax_server.models.flash_causal_lm import FlashCausalLMBatch @@ -78,6 +78,8 @@ class GraphState: block_tables: torch.Tensor slots: torch.Tensor input_lengths: torch.Tensor + prefix_lens: List[int] + prefix_lens_tensor: torch.Tensor adapter_data: AdapterBatchData traced_adapter_layer_names: Set[str] state: Any = None @@ -223,17 +225,27 @@ def trace( } block_tables = max_input_state.block_tables[:batch_size] + input_lengths = max_input_state.input_lengths[:batch_size] + prefix_lengths = [0] * batch_size + prefix_lengths_tensor = torch.zeros(batch_size, dtype=torch.int32, device=device) state = None + if FLASH_INFER: from lorax_server.utils.flashinfer_attention import ( create_decode_state_cuda_graphs, ) + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=input_lengths.tolist(), + prefix_lens=prefix_lengths, + ) + block_tables_ptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) last_page_len = torch.ones(batch_size, dtype=torch.int32, device=device) state = create_decode_state_cuda_graphs( device=max_input_state.input_ids.device, - block_tables=block_tables.view(-1), + block_tables=block_tables, block_tables_ptr=block_tables_ptr, last_page_len=last_page_len, num_heads=num_heads, @@ -245,7 +257,9 @@ def trace( position_ids=max_input_state.position_ids[:batch_size], block_tables=block_tables, slots=max_input_state.slots[:batch_size], - 
input_lengths=max_input_state.input_lengths[:batch_size], + input_lengths=input_lengths, + prefix_lens=prefix_lengths, + prefix_lens_tensor=prefix_lengths_tensor, adapter_data=AdapterBatchData( meta=AdapterBatchMetadata( adapter_indices=max_input_state.adapter_data.meta.adapter_indices[:batch_size], @@ -265,7 +279,10 @@ def trace( with forward_context( block_tables=input_state.block_tables, cu_seqlen_prefill=None, - input_lengths=input_state.input_lengths, + input_lengths=input_lengths, + input_lengths_tensor=input_state.input_lengths, + prefix_lens=prefix_lengths, + prefix_lens_tensor=prefix_lengths_tensor, state=input_state.state, ): graph = torch.cuda.CUDAGraph() @@ -296,6 +313,8 @@ def forward( block_tables: torch.Tensor, slots: torch.Tensor, input_lengths: torch.Tensor, + prefix_lens: List[int], + prefix_lens_tensor: torch.Tensor, max_s: int, adapter_data: AdapterBatchData, lm_head_indices: Optional[torch.Tensor] = None, @@ -303,11 +322,21 @@ def forward( pad_and_fill(self.input_state.input_ids, input_ids, 0) pad_and_fill(self.input_state.position_ids, position_ids, 0) pad_and_fill(self.input_state.slots, slots, SLOT_PAD_VALUE) - pad_and_fill(self.input_state.input_lengths, input_lengths, 0) + pad_and_fill(self.input_state.input_lengths, input_lengths + prefix_lens_tensor, 0) self.input_state.block_tables.zero_() self.input_state.block_tables[: block_tables.shape[0], : block_tables.shape[1]] = block_tables + if FLASH_INFER: + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=input_lengths, + prefix_lens=prefix_lens, + ) + self.input_state.block_tables[: block_tables.shape[0]] = block_tables + else: + self.input_state.block_tables[: block_tables.shape[0], : block_tables.shape[1]] = block_tables + for layer_name, weight_data in self.input_state.adapter_data.data.items(): # TODO(travis): generalize this to support other adapter types lora_data = weight_data[LORA] @@ -328,7 +357,10 @@ def forward( with self.forward_context( block_tables=self.input_state.block_tables, cu_seqlen_prefill=None, - input_lengths=self.input_state.input_lengths, + input_lengths=input_lengths, + input_lengths_tensor=self.input_state.input_lengths, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, state=self.input_state.state, ): self.graph.replay() diff --git a/server/lorax_server/utils/state.py b/server/lorax_server/utils/state.py index 1adfee1c6..63e73f55f 100644 --- a/server/lorax_server/utils/state.py +++ b/server/lorax_server/utils/state.py @@ -6,11 +6,24 @@ WARMUP = False SPECULATIVE_TOKENS = 0 -FLASH_INFER = bool(os.environ.get("FLASH_INFER", "")) + +PREFIX_CACHING = bool(os.environ.get("PREFIX_CACHING", "")) +logger.info(f"Prefix caching = {PREFIX_CACHING}") + + +# Always use flashinfer when prefix caching is enabled +FLASH_INFER = bool(os.environ.get("FLASH_INFER", "")) or PREFIX_CACHING if FLASH_INFER: logger.info("Using flashinfer") +BLOCK_SIZE: int +if FLASH_INFER: + BLOCK_SIZE = 1 +else: + BLOCK_SIZE = 16 + + def set_warmup(value: bool): global WARMUP WARMUP = value
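
The batch-level bookkeeping above threads a per-request prefix length through `FlashCausalLMBatch`: `prefix_lens` (a Python list used in CPU-side loops) and `prefix_lens_tensor` (the kernel-facing copy) are filtered and concatenated with the same indices as the input lengths, and the attention kernels are handed the sum of cached and newly supplied tokens. A minimal sketch of that invariant, using a hypothetical `MiniBatch` rather than the real batch class:

```python
import torch

# Hypothetical, simplified stand-in for the batch bookkeeping (not the real
# FlashCausalLMBatch): prefix lengths are kept both as a Python list (for
# CPU-side loops) and as a tensor (for kernel inputs), and filter() slices
# both with the same surviving indices, mirroring the diff above.
class MiniBatch:
    def __init__(self, input_lengths, prefix_lens, device="cpu"):
        self.input_lengths = list(input_lengths)
        self.prefix_lens = list(prefix_lens)
        self.input_lengths_tensor = torch.tensor(input_lengths, dtype=torch.int32, device=device)
        self.prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)

    def filter(self, indices):
        self.input_lengths = [self.input_lengths[i] for i in indices]
        self.prefix_lens = [self.prefix_lens[i] for i in indices]
        idx = torch.tensor(indices, dtype=torch.long, device=self.input_lengths_tensor.device)
        self.input_lengths_tensor = self.input_lengths_tensor[idx]
        self.prefix_lens_tensor = self.prefix_lens_tensor[idx]
        return self

    def kernel_lengths(self):
        # Mirrors `input_lengths = input_lengths + prefix_lens_tensor` in forward():
        # the paged attention kernel attends over cached and new tokens alike.
        return self.input_lengths_tensor + self.prefix_lens_tensor


batch = MiniBatch(input_lengths=[4, 9, 3], prefix_lens=[16, 0, 32])
batch.filter([0, 2])
print(batch.prefix_lens)       # [16, 32]
print(batch.kernel_lengths())  # tensor([20, 35], dtype=torch.int32)
```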
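
FlashInfer's paged wrappers expect one flat (ragged) list of page indices rather than a fixed-width block table row per request, which is what the new `block_tables_to_ragged` helper in `server/lorax_server/utils/attention/utils.py` produces; each request contributes `prefix_len + input_length` entries. A usage sketch with made-up page indices (assuming `BLOCK_SIZE = 1`, i.e. one page per token):

```python
import torch

from lorax_server.utils.attention.utils import block_tables_to_ragged

# Two requests with fixed-width block table rows; the page indices are made up.
block_tables = torch.tensor(
    [
        [10, 11, 12, 13, 0, 0],    # request 0
        [20, 21, 22, 23, 24, 25],  # request 1
    ],
    dtype=torch.int32,
)
input_lengths = [3, 2]  # newly supplied tokens per request
prefix_lens = [1, 4]    # tokens already resident in the KV cache

ragged = block_tables_to_ragged(
    block_tables=block_tables, input_lengths=input_lengths, prefix_lens=prefix_lens
)
# Request 0 contributes 1 + 3 = 4 pages, request 1 contributes 4 + 2 = 6 pages.
print(ragged)  # tensor([10, 11, 12, 13, 20, 21, 22, 23, 24, 25], dtype=torch.int32)
```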
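
`use_prefill_with_paged_kv_state` derives FlashInfer's `paged_kv_indptr` by rounding each sequence length up to a whole number of pages and taking a cumulative sum. A standalone sketch of just that arithmetic (the helper name is ours, not part of the diff); with a page size of 1 it reduces to a plain cumulative sum of the lengths:

```python
import torch

# Sketch of the indptr computation from use_prefill_with_paged_kv_state:
# ceil(length / page_size) pages per sequence, cumulated into offsets that
# index the flat (ragged) page list.
def paged_kv_indptr(input_lengths: torch.Tensor, page_size: int) -> torch.Tensor:
    indptr = torch.zeros(input_lengths.shape[0] + 1, dtype=torch.int32)
    torch.add(input_lengths, page_size - 1, out=indptr[1:])  # prepare ceil division
    indptr[1:].div_(page_size, rounding_mode="floor")        # pages per sequence
    indptr[1:].cumsum_(-1)                                   # offsets into the page list
    return indptr


lengths = torch.tensor([5, 8, 1], dtype=torch.int32)
print(paged_kv_indptr(lengths, page_size=4))  # tensor([0, 2, 4, 5], dtype=torch.int32)
print(paged_kv_indptr(lengths, page_size=1))  # tensor([0, 5, 13, 14], dtype=torch.int32)
```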
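
The same context manager also supplies `paged_kv_last_page_len`, the number of occupied slots in each sequence's final page: `(length - 1) % page_size + 1`, with an all-ones fast path when the page size is 1. Again a standalone sketch with an illustrative helper name:

```python
import torch

# Sketch of the last-page-length computation from use_prefill_with_paged_kv_state.
def last_page_lengths(input_lengths: torch.Tensor, page_size: int) -> torch.Tensor:
    if page_size == 1:
        # With one token per page, the last page is always exactly full.
        return torch.ones_like(input_lengths)
    last = torch.empty_like(input_lengths)
    torch.sub(input_lengths, 1, out=last)
    last.remainder_(page_size)
    last += 1
    return last


lengths = torch.tensor([5, 8, 1], dtype=torch.int32)
print(last_page_lengths(lengths, page_size=4))  # tensor([1, 4, 1], dtype=torch.int32)
```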
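
Finally, the runtime switches live in `lorax_server/utils/state.py`: `PREFIX_CACHING` is read from the environment, it forces the FlashInfer backend on, and under FlashInfer the KV block (page) size is 1 instead of 16, so a cached prefix of any length can be reused exactly. Condensed, the logic is:

```python
import os

# Condensed view of the toggles in lorax_server/utils/state.py. Note that any
# non-empty value (including "0" or "false") counts as enabled, since only
# emptiness is checked; unset the variable to disable the feature.
PREFIX_CACHING = bool(os.environ.get("PREFIX_CACHING", ""))
FLASH_INFER = bool(os.environ.get("FLASH_INFER", "")) or PREFIX_CACHING
BLOCK_SIZE = 1 if FLASH_INFER else 16  # FlashInfer paths use one token per KV page
```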