diff --git a/docs/reference/openapi.json b/docs/reference/openapi.json index cb0f11a6b..483b507f7 100644 --- a/docs/reference/openapi.json +++ b/docs/reference/openapi.json @@ -705,7 +705,7 @@ "example": 0.5, "nullable": false, "minimum": 0.0, - "maximim": 1.0 + "maximum": 1.0 }, "majority_sign_method": { "type": "string", @@ -727,7 +727,7 @@ "example": 1, "nullable": true, "minimum": 0.0, - "exclusiveMinimum": 0.0 + "exclusiveMinimum": true }, "decoder_input_details": { "type": "boolean", @@ -748,7 +748,7 @@ "default": "null", "nullable": true, "minimum": 0.0, - "exclusiveMinimum": 0.0 + "exclusiveMinimum": true }, "ignore_eos_token": { "type": "boolean", @@ -761,7 +761,8 @@ "default": "null", "example": 1.03, "nullable": true, - "exclusiveMinimum": 0.0 + "minimum": 0.0, + "exclusiveMinimum": true }, "return_full_text": { "type": "boolean", @@ -776,7 +777,7 @@ "example": "null", "nullable": true, "minimum": 0.0, - "exclusiveMinimum": 0.0 + "exclusiveMinimum": true }, "stop": { "type": "array", @@ -794,7 +795,8 @@ "default": "null", "example": 0.5, "nullable": true, - "exclusiveMinimum": 0.0 + "minimum": 0.0, + "exclusiveMinimum": false }, "top_k": { "type": "integer", @@ -802,7 +804,8 @@ "default": "null", "example": 10, "nullable": true, - "exclusiveMinimum": 0.0 + "minimum": 0.0, + "exclusiveMinimum": true }, "top_p": { "type": "number", @@ -811,7 +814,7 @@ "example": 0.95, "nullable": true, "maximum": 1.0, - "exclusiveMinimum": 0.0 + "exclusiveMinimum": true }, "truncate": { "type": "integer", @@ -827,7 +830,7 @@ "example": 0.95, "nullable": true, "maximum": 1.0, - "exclusiveMinimum": 0.0 + "exclusiveMinimum": true }, "watermark": { "type": "boolean", @@ -930,7 +933,8 @@ "default": "null", "example": 0.5, "nullable": true, - "exclusiveMinimum": 0.0 + "minimum": 0.0, + "exclusiveMinimum": false }, "top_p": { "type": "number", @@ -939,22 +943,23 @@ "example": 0.95, "nullable": true, "maximum": 1.0, - "exclusiveMinimum": 0.0 + "minimum": 0.0, + "exclusiveMinimum": true }, "n": { "type": "integer", "default": "null", "example": 3, "nullable": true, - "exclusiveMinimum": 0 + "minimum": 0, + "exclusiveMinimum": true }, "max_tokens": { "type": "integer", "format": "int32", "default": "20", "minimum": 0.0, - "exclusiveMaximum": 512.0, - "exclusiveMinimum": 0.0 + "exclusiveMinimum": true }, "stop": { "type": "array", @@ -997,7 +1002,8 @@ "default": "null", "example": 0.5, "nullable": true, - "exclusiveMinimum": 0.0 + "minimum": 0.0, + "exclusiveMinimum": false }, "top_p": { "type": "number", @@ -1006,22 +1012,23 @@ "example": 0.95, "nullable": true, "maximum": 1.0, - "exclusiveMinimum": 0.0 + "minimum": 0.0, + "exclusiveMinimum": true }, "n": { "type": "integer", "default": "null", "example": 3, "nullable": true, - "exclusiveMinimum": 0 + "minimum": 0, + "exclusiveMinimum": true }, "max_tokens": { "type": "integer", "format": "int32", "default": "20", "minimum": 0.0, - "exclusiveMaximum": 512.0, - "exclusiveMinimum": 0.0 + "exclusiveMinimum": true }, "stop": { "type": "array", diff --git a/router/src/batch.rs b/router/src/batch.rs index 50b91550b..2be4bcdbe 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -259,6 +259,7 @@ pub(crate) trait BatchEntries: Sync + Send + Debug { blocks: Vec<u32>, slots: Vec<u32>, prefix_len: u32, + chunk_len: Option<u32>, ); fn extend(&mut self, entries: Box<dyn BatchEntries>); fn drain(&mut self) -> Vec<(Adapter, u64, Entry)>; @@ -323,6 +324,7 @@ impl BatchEntries for GenerateBatchEntries { blocks: Vec<u32>, slots: Vec<u32>, prefix_len: u32, + chunk_len: Option<u32>, ) { let valid_request = entry
.request @@ -343,7 +345,7 @@ impl BatchEntries for GenerateBatchEntries { blocks, slots, cache_len: prefix_len, - chunk_len: None, + chunk_len, }; self.state.add(id, entry, adapter, request_proto); @@ -455,6 +457,7 @@ impl BatchEntries for EmbedBatchEntries { blocks: Vec<u32>, slots: Vec<u32>, prefix_len: u32, + chunk_len: Option<u32>, ) { let valid_request = entry .request @@ -475,7 +478,7 @@ impl BatchEntries for EmbedBatchEntries { blocks, slots, cache_len: prefix_len, - chunk_len: None, + chunk_len, }; self.state.add(id, entry, adapter, request_proto); @@ -580,6 +583,7 @@ impl BatchEntries for ClassifyBatchEntries { blocks: Vec<u32>, slots: Vec<u32>, prefix_len: u32, + chunk_len: Option<u32>, ) { let valid_request = entry .request @@ -600,7 +604,7 @@ impl BatchEntries for ClassifyBatchEntries { blocks, slots, cache_len: prefix_len, - chunk_len: None, + chunk_len, }; self.state.add(id, entry, adapter, request_proto); diff --git a/router/src/block_allocator.rs b/router/src/block_allocator.rs index 05c2bd30d..3b1374536 100644 --- a/router/src/block_allocator.rs +++ b/router/src/block_allocator.rs @@ -1,4 +1,4 @@ -use std::{cmp::min, sync::Arc}; +use std::sync::Arc; use tokio::sync::{mpsc, oneshot}; use crate::radix::RadixAllocator; @@ -56,12 +56,14 @@ impl BlockAllocator { pub(crate) async fn allocate( &self, + adapter_index: u32, tokens: u32, prefill_tokens: Option<Arc<Vec<u32>>>, ) -> Option<BlockAllocation> { let (response_sender, response_receiver) = oneshot::channel(); self.block_allocator .send(BlockAllocatorCommand::Allocate { + adapter_index, tokens, prefill_tokens, response_sender, @@ -103,12 +105,13 @@ async fn block_allocator_task( allocation_id, } => allocator.free(blocks, allocation_id), BlockAllocatorCommand::Allocate { + adapter_index, tokens, prefill_tokens, response_sender, } => { response_sender - .send(allocator.allocate(tokens, prefill_tokens)) + .send(allocator.allocate(adapter_index, tokens, prefill_tokens)) .unwrap(); } } @@ -122,6 +125,7 @@ enum BlockAllocatorCommand { allocation_id: u64, }, Allocate { + adapter_index: u32, tokens: u32, prefill_tokens: Option<Arc<Vec<u32>>>, response_sender: oneshot::Sender<Option<BlockAllocation>>, @@ -131,6 +135,7 @@ pub(crate) trait Allocator { fn allocate( &mut self, + adapter_index: u32, tokens: u32, prefill_tokens: Option<Arc<Vec<u32>>>, ) -> Option<BlockAllocation>; @@ -158,6 +163,7 @@ impl SimpleAllocator { impl Allocator for SimpleAllocator { fn allocate( &mut self, + _adapter_index: u32, tokens: u32, _prefill_tokens: Option<Arc<Vec<u32>>>, ) -> Option<BlockAllocation> { @@ -167,7 +173,7 @@ impl Allocator for SimpleAllocator { None => (tokens, 1), Some(window_size) => { let repeats = (tokens + window_size - 1) / window_size; - let tokens = min(tokens, window_size); + let tokens = core::cmp::min(tokens, window_size); (tokens, repeats as usize) } }; diff --git a/router/src/infer.rs b/router/src/infer.rs index ba57d206d..74788bda0 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -200,6 +200,7 @@ impl Infer { speculate, max_batch_total_tokens, prefix_caching, + chunked_prefill, is_causal_lm, ); diff --git a/router/src/radix.rs b/router/src/radix.rs index 8543b236a..243df370a 100644 --- a/router/src/radix.rs +++ b/router/src/radix.rs @@ -1,11 +1,22 @@ +use crate::block_allocator::{Allocator, BlockAllocation}; +use slotmap::{DefaultKey, SlotMap}; +use std::hash::{Hash, Hasher}; use std::{ collections::{BTreeSet, HashMap}, sync::Arc, }; -use slotmap::{DefaultKey, SlotMap}; - -use crate::block_allocator::{Allocator, BlockAllocation}; +fn hash(adapter_index: u32, slice: &[u32]) -> u64 {
assert!(!slice.is_empty()); + if slice.len() == 1 && adapter_index == 0 { + slice[0] as u64 + } else { + let mut s = std::hash::DefaultHasher::new(); + adapter_index.hash(&mut s); + slice.hash(&mut s); + s.finish() + } +} pub struct RadixAllocator { allocation_id: u64, @@ -16,39 +27,43 @@ pub struct RadixAllocator { /// Blocks that are immediately available for allocation. free_blocks: Vec<u32>, + + #[allow(dead_code)] + // This isn't used because the prefix needs to match without the windowing + // mechanism. This at worst is overallocating, not necessarily being wrong. + window_size: Option<u32>, + + block_size: u32, } impl RadixAllocator { pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self { - assert_eq!( - block_size, 1, - "Radix tree allocator only works with block_size=1, was: {}", - block_size - ); - if window_size.is_some() { - unimplemented!("Window size not supported in the prefix-caching block allocator yet"); - } - RadixAllocator { allocation_id: 0, allocations: HashMap::new(), - cache_blocks: RadixTrie::new(), + cache_blocks: RadixTrie::new(block_size as usize), // Block 0 is reserved for health checks. free_blocks: (1..n_blocks).collect(), + window_size, + block_size, } } - fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option<Vec<u32>> { + fn alloc_or_reclaim(&mut self, adapter_index: u32, n_blocks_needed: usize) -> Option<Vec<u32>> { if self.free_blocks.len() < n_blocks_needed { // This is a bit annoying, we first extend the free list and then // split it off again below. This is because we need to put it on // the free list if we cannot allocate enough blocks. This is only // temporary, the trie needs to be able to report whether it can // allocate the requested amount. Just not implemented yet. + tracing::debug!( + "Free blocks {} need {n_blocks_needed}", + self.free_blocks.len() + ); self.free_blocks.extend( self.cache_blocks - .evict(n_blocks_needed - self.free_blocks.len()), + .evict(adapter_index, n_blocks_needed - self.free_blocks.len()), ); } @@ -63,35 +78,43 @@ impl RadixAllocator { } } +// Allocator trait impl Allocator for RadixAllocator { fn allocate( &mut self, + adapter_index: u32, tokens: u32, prefill_tokens: Option<Arc<Vec<u32>>>, ) -> Option<BlockAllocation> { let mut blocks = vec![]; let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() { - let node_id = self - .cache_blocks - .find(prefill_tokens.as_slice(), &mut blocks); - // Even if this allocation fails below, we need to increase he - // refcount to ensure that the prefix that was found is not evicted. - + let node_id = + self.cache_blocks + .find(adapter_index, prefill_tokens.as_slice(), &mut blocks); node_id } else { self.cache_blocks.root_id() }; + // Even if this allocation fails below, we need to increase the + // refcount to ensure that the prefix that was found is not evicted.
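// An illustrative sequence, not part of this patch, showing what the incref below protects against (the trie API is the one defined in this file):
//
//     let mut blocks = vec![];
//     let node = trie.find(0, &tokens, &mut blocks); // longest cached prefix
//     trie.incref(node).unwrap();       // pins the node: it leaves the leaf set
//     let evicted = trie.evict(0, 100); // reclaim cannot touch the pinned node
//     trie.decref(node).unwrap();       // evictable again once refcount hits 0
//
// Without the pin, the alloc_or_reclaim call for the suffix blocks could evict the very prefix that was just matched.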
self.cache_blocks .incref(prefix_node) .expect("Failed to increment refcount"); - let prefix_len = blocks.len(); + let prefix_len = blocks.len() * self.block_size as usize; let suffix_len = tokens - prefix_len as u32; - match self.alloc_or_reclaim(suffix_len as usize) { + let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size; + + tracing::debug!("Prefix {prefix_len} - Suffix {suffix_len}"); + + match self.alloc_or_reclaim(adapter_index, suffix_blocks as usize) { Some(suffix_blocks) => blocks.extend(suffix_blocks), None => { + tracing::debug!("Cannot allocate {:?}", self.cache_blocks); + tracing::debug!("Found {prefix_len} prefix tokens need {suffix_blocks} suffix blocks for {tokens} tokens"); + tracing::debug!("Block size {}", self.block_size); self.cache_blocks .decref(prefix_node) .expect("Failed to decrement refcount"); @@ -100,9 +123,23 @@ impl Allocator for RadixAllocator { } // 1:1 mapping of blocks and slots. - let slots = blocks.clone(); + let slots = if self.block_size == 1 { + blocks.clone() + } else { + let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize); + 'slots: for block_id in &blocks { + for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { + slots.push(s); + if slots.len() as u32 == tokens { + break 'slots; + } + } + } + slots + }; let allocation = RadixAllocation { + adapter_index, prefix_node, cached_prefix_len: prefix_len, prefill_tokens: prefill_tokens.clone(), @@ -136,29 +173,39 @@ impl Allocator for RadixAllocator { // If there are prefill tokens that did not come from the cache, // add them to the cache. if prefill_tokens.len() > allocation.cached_prefix_len { - let prefix_len = self - .cache_blocks - .insert(prefill_tokens, &blocks[..prefill_tokens.len()]) - // Unwrap, failing is a programming error. - .expect("Failed to store prefill tokens"); - - // We can have a prefill with the following structure: - // - // |---| From the prefix cache. - // A B C D E F G - //|--------| Found in the trie during insertion. - // - // This means that while processing this request there was a - // partially overlapping request that had A..=E in its - // prefill. In this case we need to free the blocks D E. - if prefix_len > allocation.cached_prefix_len { - self.free_blocks - .extend(&blocks[allocation.cached_prefix_len..prefix_len]); + let aligned = + (prefill_tokens.len() / self.block_size as usize) * self.block_size as usize; + if aligned > 0 { + let prefix_len = self + .cache_blocks + .insert( + allocation.adapter_index, + &prefill_tokens[..aligned], + &blocks[..aligned / self.block_size as usize], + ) + // Unwrap, failing is a programming error. + .expect("Failed to store prefill tokens"); + // We can have a prefill with the following structure: + // + // |---| From the prefix cache. + // A B C D E F G + //|--------| Found in the trie during insertion. + // + // This means that while processing this request there was a + // partially overlapping request that had A..=E in its + // prefill. In this case we need to free the blocks D E. + if prefix_len > allocation.cached_prefix_len { + self.free_blocks.extend( + &blocks[allocation.cached_prefix_len / self.block_size as usize + ..prefix_len / self.block_size as usize], + ); + } } } // Free non-prefill blocks. 
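// A worked example with illustrative numbers, not part of the patch: block_size = 2 and a finished request with 7 prefill tokens in blocks [8, 9, 10, 11]. Only aligned = (7 / 2) * 2 = 6 tokens, i.e. blocks[..3], can be cached, since a half-filled block can never be a shared prefix. The extend below then frees blocks[7 / 2..] = blocks[3..], returning block 11 plus all decode blocks to the free list.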
- self.free_blocks.extend(&blocks[prefill_tokens.len()..]); + self.free_blocks + .extend(&blocks[prefill_tokens.len() / self.block_size as usize..]); } else { self.free_blocks.extend(blocks); } @@ -166,6 +213,7 @@ struct RadixAllocation { + adapter_index: u32, prefix_node: NodeId, cached_prefix_len: usize, prefill_tokens: Option<Arc<Vec<u32>>>, @@ -187,7 +235,6 @@ struct RadixAllocation { pub enum TrieError { InvalidNodeId, RefCountUnderflow, - BlockTokenCountMismatch, } pub type NodeId = DefaultKey; @@ -206,11 +253,14 @@ pub struct RadixTrie { /// Time as a monotonically increating counter to avoid the system /// call that a real time lookup would require. time: u64, + + /// All blocks need to be aligned with this + block_size: usize, } impl RadixTrie { /// Construct a new radix trie. - pub fn new() -> Self { + pub fn new(block_size: usize) -> Self { let root = TrieNode::new(vec![], vec![], 0, None); let mut nodes = SlotMap::new(); let root = nodes.insert(root); @@ -219,36 +269,47 @@ nodes, root, time: 0, + block_size, } } /// Find the prefix of the given tokens. /// /// The blocks corresponding to the part of the prefix that could be found - /// are writteng to `blocks`. The number of blocks is in `0..=tokens.len()`. + /// are written to `blocks`. The number of blocks is in `0..=tokens.len()`. /// Returns the identifier of the trie node that contains the longest /// prefix. The node identifier can be used by callers to e.g. increase its /// reference count. /// /// Using this method will update the access time of the traversed nodes. - pub fn find(&mut self, key: &[u32], blocks: &mut Vec<u32>) -> NodeId { + pub fn find(&mut self, adapter_index: u32, key: &[u32], blocks: &mut Vec<u32>) -> NodeId { self.time += 1; - self.find_(self.root, key, blocks) + self.find_(adapter_index, self.root, key, blocks) } /// Find worker. - fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId { + fn find_( + &mut self, + adapter_index: u32, + mut node_id: NodeId, + key: &[u32], + blocks: &mut Vec<u32>, + ) -> NodeId { let node = &self.nodes[node_id]; - if let Some(&child_id) = node.children.get(&key[0]) { - self.update_access_time(child_id); - let child = self.nodes.get(child_id).expect("Invalid child identifier"); - let shared_prefix_len = child.key.shared_prefix_len(key); - blocks.extend(&child.blocks[..shared_prefix_len]); - - let key = &key[shared_prefix_len..]; - if !key.is_empty() { - node_id = self.find_(child_id, key, blocks); + if key.len() >= self.block_size { + let node_key = hash(adapter_index, &key[..self.block_size]); + if let Some(&child_id) = node.children.get(&node_key) { + self.update_access_time(child_id); + let child = self.nodes.get(child_id).expect("Invalid child identifier"); + let shared_prefix_len = shared_prefix(&child.key, key, self.block_size); + assert_eq!(shared_prefix_len % self.block_size, 0); + blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]); + + let key = &key[shared_prefix_len..]; + if !key.is_empty() { + node_id = self.find_(adapter_index, child_id, key, blocks); + } } } @@ -273,6 +334,11 @@ node.ref_count -= 1; if node.ref_count == 0 { + assert!( + node.children.is_empty(), + "Nodes with children must have refcount > 0" + ); + self.leaves.insert((node.last_accessed, node_id)); } @@ -300,8 +366,8 @@ /// Evict `n_blocks` from the trie. /// /// Returns the evicted blocks. When the length is less than `n_blocks`, - /// not enough blocks could beevicted.
- pub fn evict(&mut self, n_blocks: usize) -> Vec<u32> { + /// not enough blocks could be evicted. + pub fn evict(&mut self, adapter_index: u32, n_blocks: usize) -> Vec<u32> { // NOTE: we don't return Result here. If any of the unwrapping fails, // it's a programming error in the trie implementation, not a user // error caused by e.g. an invalid argument. @@ -310,14 +376,22 @@ // evict n_blocks and return `None` if we can't. We are now needlessly // evicting prefixes from the cache in such a case. let mut evicted = Vec::new(); + tracing::debug!("Evicting in search of {n_blocks}"); while let Some((last_access, node_id)) = self.leaves.pop_first() { - let blocks_needed = n_blocks - evicted.len(); + let blocks_needed = n_blocks.saturating_sub(evicted.len()); + tracing::debug!("Evicting node {node_id:?} "); let node = self.nodes.get(node_id).expect("Leave does not exist"); + assert_eq!( + node.ref_count, 0, + "Leaf must have refcount of 0, got {}", + node.ref_count + ); + if blocks_needed >= node.blocks.len() { // We need to evict the whole node if we need more blocks than it has. - let node = self.remove_node(node_id); + let node = self.remove_node(adapter_index, node_id); evicted.extend(node.blocks); if evicted.len() >= n_blocks { @@ -328,8 +402,11 @@ // the required number of blocks and leave the remaining blocks // untouched. let node = self.nodes.get_mut(node_id).expect("Leave does not exist"); - node.key.truncate(node.blocks.len() - blocks_needed); - evicted.extend(node.blocks.split_off(node.blocks.len() - blocks_needed)); + + let truncate_blocks = node.blocks.len() - blocks_needed; + let truncate_tokens = truncate_blocks * self.block_size; + node.key.truncate(truncate_tokens); + evicted.extend(node.blocks.split_off(truncate_blocks)); self.leaves.insert((last_access, node_id)); break; } @@ -343,14 +420,21 @@ /// This method returns the length of the prefix that was already /// in the trie. E.g. if the length is 10, this means that for /// the first 10 elements of the tree **the blocks are not updated**. - pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result<usize, TrieError> { + pub fn insert( + &mut self, + adapter_index: u32, + tokens: &[u32], + blocks: &[u32], + ) -> Result<usize, TrieError> { self.time += 1; - self.insert_(self.root, tokens, blocks) + let common = self.insert_(adapter_index, self.root, tokens, blocks)?; + Ok(common) } /// Insertion worker. fn insert_( &mut self, + adapter_index: u32, node_id: NodeId, tokens: &[u32], blocks: &[u32], ) -> Result<usize, TrieError> { // ... the part of the prefix that is already in the trie to detect // mismatches. - if tokens.len() != blocks.len() { - return Err(TrieError::BlockTokenCountMismatch); - } + assert_eq!(tokens.len(), blocks.len() * self.block_size); - if let Some(&child_id) = self.nodes[node_id].children.get(&tokens[0]) { + let node_key = hash(adapter_index, &tokens[..self.block_size]); + if let Some(&child_id) = self.nodes[node_id].children.get(&node_key) { self.update_access_time(child_id); let child = self .nodes .get_mut(child_id) // Unwrap here, since failure is a bug. .expect("Child node does not exist"); - let shared_prefix_len = child.key.shared_prefix_len(tokens); + let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size); // We are done, the prefix is already in the trie.
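// Illustrative detail: with block_size > 1 the shared prefix is rounded down to whole blocks, so the condition below also returns early when a match rounds to zero full blocks, e.g. shared_prefix(&[0, 1], &[0, 9], 2) == 0 because a single matching token is less than one block.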
- if shared_prefix_len == tokens.len() { + if shared_prefix_len == tokens.len() || shared_prefix_len == 0 { return Ok(shared_prefix_len); } @@ -381,26 +464,27 @@ if shared_prefix_len == child.key.len() { return Ok(shared_prefix_len + self.insert_( + adapter_index, child_id, &tokens[shared_prefix_len..], - &blocks[shared_prefix_len..], + &blocks[shared_prefix_len / self.block_size..], )?); } // The node's prefix and the insertion prefix only match partially, // split the node to just contain the matching part. Then insert the // remainder of the prefix into the node again - let child_id = self.split_node(child_id, shared_prefix_len); + let child_id = self.split_node(adapter_index, child_id, shared_prefix_len); let key = &tokens[shared_prefix_len..]; - let blocks = &blocks[shared_prefix_len..]; - Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?) + let blocks = &blocks[shared_prefix_len / self.block_size..]; + Ok(shared_prefix_len + self.insert_(adapter_index, child_id, key, blocks)?) } else { - self.add_node(node_id, tokens, blocks); + self.add_node(adapter_index, node_id, tokens, blocks); Ok(0) } } - fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId { + fn split_node(&mut self, adapter_index: u32, node_id: NodeId, prefix_len: usize) -> NodeId { // We have to make the current node a child to ensure that its // properties and node id stay the same. @@ -411,17 +495,18 @@ .get_mut(node_id) .expect("Node to-be split does not exist"); let mut parent_key = node.key.split_off(prefix_len); - let mut parent_blocks = node.blocks.split_off(prefix_len); + let prefix_blocks = prefix_len / self.block_size; + let mut parent_blocks = node.blocks.split_off(prefix_blocks); // Move first part of the prefix to the parent. We swap to avoid // an allocation + copy for both splits of the key/blocks. std::mem::swap(&mut node.key, &mut parent_key); std::mem::swap(&mut node.blocks, &mut parent_blocks); - let node_key = node.key[0]; + let node_key = hash(adapter_index, &node.key[..self.block_size]); let grandparent_id = node.parent.expect("Node does not have a parent"); - let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks); + let parent_id = self.add_node(adapter_index, grandparent_id, parent_key, parent_blocks); self.add_node_to_parent(parent_id, node_key, node_id); // Reborrow to make the borrow checker happy. @@ -437,13 +522,14 @@ /// Create a node and add it to the parent. fn add_node( &mut self, + adapter_index: u32, parent_id: NodeId, key: impl Into<Vec<u32>>, blocks: impl Into<Vec<u32>>, ) -> NodeId { let key = key.into(); let blocks = blocks.into(); - let first = key[0]; + let first = hash(adapter_index, &key[..self.block_size]); let child = TrieNode::new(key, blocks, self.time, Some(parent_id)); let child_id = self.nodes.insert(child); @@ -455,10 +541,10 @@ } /// Add a node to the parent. - fn add_node_to_parent(&mut self, parent_id: NodeId, first: u32, child_id: NodeId) { + fn add_node_to_parent(&mut self, parent_id: NodeId, hash: u64, child_id: NodeId) { // Unwrap here, passing in an unknown id is a programming error. let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); - if parent.children.insert(first, child_id).is_none() { + if parent.children.insert(hash, child_id).is_none() { // Only increase reference count if child does not replace another child.
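// An illustrative property of the u64 edge keys used here (hypothetical values, nothing in this patch asserts it): the adapter index is folded into the hash, so identical token blocks from different adapters land on different children:
//
//     let k0 = hash(0, &[7, 8]); // adapter 0, one block of tokens
//     let k1 = hash(1, &[7, 8]); // adapter 1, same tokens
//     // k0 != k1 in practice, giving each adapter its own subtree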
self.incref(parent_id) .expect("Failed to increase parent refcount"); } } /// Remove a node from the trie. - fn remove_node(&mut self, node_id: NodeId) -> TrieNode { + fn remove_node(&mut self, adapter_index: u32, node_id: NodeId) -> TrieNode { // Unwrap here, passing in an unknown id is a programming error. let node = self.nodes.remove(node_id).expect("Unknown node"); + assert!( + node.children.is_empty(), + "Tried to remove a node with {} children", + node.children.len() + ); let parent_id = node.parent.expect("Attempted to remove root node"); let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); - parent.children.remove(&node.key[0]); + + let node_key = hash(adapter_index, &node.key[..self.block_size]); + parent.children.remove(&node_key); self.decref(parent_id) .expect("Failed to decrease parent refcount"); - self.nodes.remove(node_id); node } @@ -526,7 +618,7 @@ impl RadixTrie { #[derive(Debug)] struct TrieNode { blocks: Vec<u32>, - children: HashMap<u32, NodeId>, + children: HashMap<u64, NodeId>, key: Vec<u32>, last_accessed: u64, parent: Option<NodeId>, @@ -546,38 +638,68 @@ impl TrieNode { } } -/// Helper trait to get the length of the shared prefix of two sequences. -trait SharedPrefixLen { - fn shared_prefix_len(&self, other: &Self) -> usize; -} - -impl<T> SharedPrefixLen for [T] -where - T: PartialEq, -{ - fn shared_prefix_len(&self, other: &Self) -> usize { - self.iter().zip(other).take_while(|(a, b)| a == b).count() - } +fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize { + let full = left.iter().zip(right).take_while(|(a, b)| a == b).count(); + // NOTE: this is the case because the child node was chosen based on + // matching the first character of the key/prefix. + assert!(full > 0, "Prefixes must at least share 1 token"); + (full / block_size) * block_size } #[cfg(test)] mod tests { use std::sync::Arc; - use crate::block_allocator::Allocator; + use super::*; - use super::RadixAllocator; + #[test] + fn allocator_block_size() { + let mut cache = RadixAllocator::new(2, 12, None); + let allocation = cache + .allocate(0, 8, Some(Arc::new(vec![0, 1, 2, 3]))) + .unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache + .allocate(0, 8, Some(Arc::new(vec![0, 1, 2, 3]))) + .unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]); + assert_eq!(allocation.prefix_len, 4); + } + + #[test] + fn allocator_block_size_non_aligned() { + let mut cache = RadixAllocator::new(2, 12, None); + let allocation = cache.allocate(0, 7, Some(Arc::new(vec![0, 1, 2]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(0, 7, Some(Arc::new(vec![0, 1, 2]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]); + assert_eq!(allocation.prefix_len, 2); + } #[test] fn allocator_reuses_prefixes() { let mut cache = RadixAllocator::new(1, 12, None); - let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation = cache + .allocate(0, 8, Some(Arc::new(vec![0, 1, 2, 3]))) +
.unwrap(); assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); - assert_eq!(allocation.slots, allocation.slots); + assert_eq!(allocation.blocks, allocation.slots); assert_eq!(allocation.prefix_len, 0); cache.free(allocation.blocks.clone(), allocation.allocation_id); - let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation = cache + .allocate(0, 8, Some(Arc::new(vec![0, 1, 2, 3]))) + .unwrap(); assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); assert_eq!(allocation.prefix_len, 4); } @@ -585,11 +707,13 @@ mod tests { #[test] fn allocator_collects_older_prefixes_first() { let mut cache = RadixAllocator::new(1, 7, None); - let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation1 = cache + .allocate(0, 4, Some(Arc::new(vec![0, 1, 2, 3]))) + .unwrap(); assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]); assert_eq!(allocation1.prefix_len, 0); - let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap(); + let allocation2 = cache.allocate(0, 2, Some(Arc::new(vec![4, 5]))).unwrap(); assert_eq!(allocation2.blocks, vec![1, 2]); assert_eq!(allocation2.prefix_len, 0); @@ -597,7 +721,9 @@ mod tests { cache.free(allocation2.blocks.clone(), allocation2.allocation_id); // We should get the blocks of the first allocation, since they are more recent. - let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap(); + let allocation3 = cache + .allocate(0, 4, Some(Arc::new(vec![6, 7, 8, 9]))) + .unwrap(); assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]); assert_eq!(allocation3.prefix_len, 0); } @@ -605,13 +731,19 @@ mod tests { #[test] fn allocator_frees_fully_overlapping_prefills() { let mut cache = RadixAllocator::new(1, 10, None); - let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); - let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation1 = cache + .allocate(0, 4, Some(Arc::new(vec![0, 1, 2, 3]))) + .unwrap(); + let allocation2 = cache + .allocate(0, 4, Some(Arc::new(vec![0, 1, 2, 3]))) + .unwrap(); cache.free(allocation2.blocks.clone(), allocation2.allocation_id); cache.free(allocation1.blocks.clone(), allocation1.allocation_id); - let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation3 = cache + .allocate(0, 4, Some(Arc::new(vec![0, 1, 2, 3]))) + .unwrap(); assert_eq!(allocation3.prefix_len, 4); // 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks. 
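Spelling out the accounting in that comment (an illustrative continuation; the assertion mirrors the arithmetic rather than quoting the file):

    // 10 blocks total: block 0 is reserved for health checks, and the cached
    // prefix [0, 1, 2, 3] pins 4 blocks in the trie, leaving 5 free blocks.
    assert_eq!(cache.free_blocks.len(), 10 - 1 - 4);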
@@ -621,20 +753,20 @@ mod tests { #[test] fn allocator_frees_partially_overlapping_prefills() { let mut cache = RadixAllocator::new(1, 20, None); - let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap(); + let allocation1 = cache.allocate(0, 4, Some(Arc::new(vec![0, 1]))).unwrap(); assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]); assert_eq!(allocation1.prefix_len, 0); cache.free(allocation1.blocks.clone(), allocation1.allocation_id); let allocation2 = cache - .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) + .allocate(0, 8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) .unwrap(); assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]); assert_eq!(allocation2.prefix_len, 2); let allocation3 = cache - .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) + .allocate(0, 8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) .unwrap(); assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]); assert_eq!(allocation3.prefix_len, 2); @@ -646,14 +778,14 @@ mod tests { assert_eq!(cache.free_blocks.len(), 11); let allocation4 = cache - .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) + .allocate(0, 6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) .unwrap(); assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]); assert_eq!(allocation4.prefix_len, 6); assert_eq!(cache.free_blocks.len(), 11); let allocation5 = cache - .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) + .allocate(0, 6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) .unwrap(); assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]); assert_eq!(allocation5.prefix_len, 6); @@ -662,96 +794,123 @@ mod tests { #[test] fn trie_insertions_have_correct_prefix_len() { - let mut trie = super::RadixTrie::new(); + let mut trie = RadixTrie::new(1); - assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0); + assert_eq!(trie.insert(0, &[0, 1, 2], &[0, 1, 2]).unwrap(), 0); // Already exists. - assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3); + assert_eq!(trie.insert(0, &[0, 1, 2], &[0, 1, 2]).unwrap(), 3); // Completely new at root-level - assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0); + assert_eq!(trie.insert(0, &[1, 2, 3], &[1, 2, 3]).unwrap(), 0); // Contains full prefix, but longer. - assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3); + assert_eq!( + trie.insert(0, &[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), + 3 + ); // Shares partial prefix, we need a split. assert_eq!( - trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + trie.insert(0, &[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) .unwrap(), 4 ); } + #[test] + fn trie_insertions_block_size() { + let mut trie = RadixTrie::new(2); + + assert_eq!(trie.insert(0, &[0, 1, 2, 3], &[0, 1]).unwrap(), 0); + + // Already exists. + // But needs to be block_size aligned + assert_eq!(trie.insert(0, &[0, 1, 2, 3], &[0, 1]).unwrap(), 4); + + // Completely new at root-level + assert_eq!(trie.insert(0, &[1, 2, 3, 4], &[1, 2]).unwrap(), 0); + + // Contains full prefix, but longer. + assert_eq!(trie.insert(0, &[0, 1, 2, 3, 4, 5], &[0, 1, 2]).unwrap(), 4); + + // Shares partial prefix, we need a split. 
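// Illustrative arithmetic for the expected value below: the child key [0, 1, 2, 3] and the new key [0, 1, 3, ...] share two raw tokens, and with block_size = 2 that rounds to (2 / 2) * 2 = 2 tokens, exactly one full block:
//
//     assert_eq!(shared_prefix(&[0, 1, 2, 3], &[0, 1, 3, 4], 2), 2);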
+ assert_eq!( + trie.insert(0, &[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3]) + .unwrap(), + 2 + ); + } + #[test] fn trie_get_returns_correct_blocks() { - let mut trie = super::RadixTrie::new(); - trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); - trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(); - trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); - trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + let mut trie = RadixTrie::new(1); + trie.insert(0, &[0, 1, 2], &[0, 1, 2]).unwrap(); + trie.insert(0, &[1, 2, 3], &[1, 2, 3]).unwrap(); + trie.insert(0, &[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); + trie.insert(0, &[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) .unwrap(); let mut blocks = Vec::new(); - trie.find(&[0], &mut blocks); + trie.find(0, &[0], &mut blocks); assert_eq!(blocks, vec![0]); blocks.clear(); - trie.find(&[0, 1, 2], &mut blocks); + trie.find(0, &[0, 1, 2], &mut blocks); assert_eq!(blocks, vec![0, 1, 2]); blocks.clear(); - trie.find(&[1, 2, 3], &mut blocks); + trie.find(0, &[1, 2, 3], &mut blocks); assert_eq!(blocks, vec![1, 2, 3]); blocks.clear(); - trie.find(&[0, 1, 2, 3], &mut blocks); + trie.find(0, &[0, 1, 2, 3], &mut blocks); assert_eq!(blocks, vec![0, 1, 2, 3]); blocks.clear(); - trie.find(&[0, 1, 2, 3, 4], &mut blocks); + trie.find(0, &[0, 1, 2, 3, 4], &mut blocks); assert_eq!(blocks, vec![0, 1, 2, 3, 4]); blocks.clear(); - trie.find(&[0, 1, 2, 3, 5], &mut blocks); + trie.find(0, &[0, 1, 2, 3, 5], &mut blocks); assert_eq!(blocks, vec![0, 1, 2, 3, 5]); } #[test] fn trie_evict_removes_correct_blocks() { - let mut trie = super::RadixTrie::new(); - trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); - trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + let mut trie = RadixTrie::new(1); + trie.insert(0, &[0, 1, 2], &[0, 1, 2]).unwrap(); + trie.insert(0, &[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) .unwrap(); - trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); - trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(); + trie.insert(0, &[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); + trie.insert(0, &[1, 2, 3], &[1, 2, 3]).unwrap(); let mut blocks = Vec::new(); // Remove less than the leave blocks. - assert_eq!(trie.evict(1), vec![7]); - trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks); + assert_eq!(trie.evict(0, 1), vec![7]); + trie.find(0, &[0, 1, 2, 3, 5, 6, 7], &mut blocks); assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]); // Refresh other leaf. - trie.find(&[0, 1, 2, 3, 4], &mut blocks); - trie.find(&[1, 2, 3], &mut blocks); + trie.find(0, &[0, 1, 2, 3, 4], &mut blocks); + trie.find(0, &[1, 2, 3], &mut blocks); // Remove the leave blocks exactly. - assert_eq!(trie.evict(2), vec![5, 6]); + assert_eq!(trie.evict(0, 2), vec![5, 6]); blocks.clear(); - trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks); + trie.find(0, &[0, 1, 2, 3, 5, 6, 7], &mut blocks); assert_eq!(blocks, vec![0, 1, 2, 3]); - trie.find(&[1, 2, 3], &mut blocks); + trie.find(0, &[1, 2, 3], &mut blocks); // Remove more than the leave blocks. - assert_eq!(trie.evict(3), vec![4, 3, 2]); + assert_eq!(trie.evict(0, 3), vec![4, 3, 2]); blocks.clear(); - trie.find(&[0, 1, 2, 3, 4], &mut blocks); + trie.find(0, &[0, 1, 2, 3, 4], &mut blocks); assert_eq!(blocks, vec![0, 1]); // Clear out the whole trie. 
- assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]); + assert_eq!(trie.evict(0, 10), vec![1, 2, 3, 0, 1]); } } diff --git a/router/src/scheduler.rs b/router/src/scheduler.rs index 7777938be..4231b3894 100644 --- a/router/src/scheduler.rs +++ b/router/src/scheduler.rs @@ -6,11 +6,7 @@ use crate::{ AdapterLoader, }; use lorax_client::{Batch, ShardedClient}; -use std::{ - cmp::{max, min}, - collections::HashSet, - sync::Arc, -}; +use std::{cmp::max, collections::HashSet, sync::Arc}; use tokio::sync::{oneshot, Mutex}; use tracing::{info_span, instrument, Instrument, Span}; @@ -44,6 +40,7 @@ impl AdapterScheduler { speculate: u32, max_batch_total_tokens: u32, prefix_caching: bool, + chunked_prefill: bool, is_causal_lm: bool, ) -> Self { let (sender, receiver) = flume::unbounded(); @@ -61,6 +58,7 @@ impl AdapterScheduler { speculate, max_batch_total_tokens, prefix_caching, + chunked_prefill, is_causal_lm, )); @@ -126,6 +124,7 @@ async fn adapter_scheduler_task( speculate: u32, max_batch_total_tokens: u32, prefix_caching: bool, + chunked_prefill: bool, is_causal_lm: bool, ) { let mut state = AdapterSchedulerState::new( @@ -138,6 +137,7 @@ async fn adapter_scheduler_task( speculate, max_batch_total_tokens, prefix_caching, + chunked_prefill, is_causal_lm, ); @@ -193,11 +193,14 @@ struct AdapterSchedulerState { block_size: u32, /// Sliding window - window_size: Option<u32>, + // window_size: Option<u32>, /// Speculation amount speculate: u32, + /// Chunked prefill + chunked_prefill: bool, + /// Paged Attention Block Allocation block_allocator: Option<BlockAllocator>, } @@ -213,6 +216,7 @@ impl AdapterSchedulerState { speculate: u32, max_batch_total_tokens: u32, prefix_caching: bool, + chunked_prefill: bool, is_causal_lm: bool, ) -> Self { let queues_state = Arc::new(Mutex::new(AdapterQueuesState::new( @@ -237,8 +241,9 @@ impl AdapterSchedulerState { next_batch_id: 0, requires_padding, block_size, - window_size, + // window_size, speculate, + chunked_prefill, block_allocator, } } @@ -282,6 +287,10 @@ impl AdapterSchedulerState { prefill_token_budget: u32, token_budget: u32, ) -> Option<NextBatch> { + if prefill_token_budget == 0 || token_budget == 0 { + return None; + }; + let num_entries = self.queues_state.lock().await.len(); if num_entries == 0 { return None; } @@ -331,6 +340,8 @@ batch_requests_len = batch_entries.len(); } + let mut should_break = false; + let mut chunk_len = None; let block_allocation = match &self.block_allocator { None => { // We pad to max input length in the Python shards @@ -354,31 +365,6 @@ None } Some(block_allocator) => { - prefill_tokens += entry.request.input_length(); - let max_new_tokens = match self.window_size { - None => entry.request.max_new_tokens(), - Some(window_size) => min( - window_size.saturating_sub(entry.request.input_length()), - entry.request.max_new_tokens(), - ), - }; - decode_tokens += max_new_tokens; - - // If we're prefix caching, this check could be under-estimating the number of available blocks - // due to shared prefixes, so we'll let the block allocator determine whether we have enough space.
- if prefill_tokens > prefill_token_budget - || (prefill_tokens + decode_tokens + self.speculate) > token_budget - { - // Entry is over budget - // Add it back to the front - tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); - self.queues_state - .lock() - .await - .push_front(&adapter, id, entry); - break; - } - let tokens = entry.request.input_length() + entry.request.max_new_tokens() + self.speculate @@ -392,8 +378,8 @@ impl AdapterSchedulerState { self.speculate ); - match block_allocator - .allocate(tokens, entry.request.input_ids()) + let block_allocation = match block_allocator + .allocate(adapter.index(), tokens, entry.request.input_ids()) .await { None => { @@ -406,12 +392,62 @@ impl AdapterSchedulerState { .push_front(&adapter, id, entry); break 'entry_loop; } - Some(block_allocation) => { + Some(mut block_allocation) => { tracing::debug!("Allocation: {block_allocation:?}"); max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); - Some(block_allocation) + + if block_allocation.prefix_len == entry.request.input_length() { + // The whole request was found in the radix trie + // However, for the transformer forward to work, we need to + // have at least one token of postfix. + block_allocation.prefix_len -= 1; + } + + block_allocation + } + }; + + let postfix_len = entry.request.input_length() - block_allocation.prefix_len; + if prefill_tokens + postfix_len > prefill_token_budget { + // Entry is over budget + if self.chunked_prefill { + // We support chunking, just set postfix_len to exactly match prefill_token_budget + let entry_chunk_len = + prefill_token_budget.saturating_sub(prefill_tokens); + if entry_chunk_len > 0 { + chunk_len = Some(entry_chunk_len); + } else { + // We cannot prefill even one token for this entry + // Add it back to the queue + self.queues_state + .lock() + .await + .push_front(&adapter, id, entry); + break 'entry_loop; + } + tracing::debug!( + "Matched budget: prefill_tokens={} == {prefill_token_budget}", + prefill_tokens + postfix_len + ); + should_break = true; + } else { + // We don't support chunking, this entry needs to go back to the buffer + // Add it back to the front + tracing::debug!( + "Over budget: prefill_tokens={} > {prefill_token_budget}", + prefill_tokens + postfix_len + ); + self.queues_state + .lock() + .await + .push_front(&adapter, id, entry); + break 'entry_loop; } } + + prefill_tokens += postfix_len; + + Some(block_allocation) } }; @@ -454,7 +490,11 @@ impl AdapterSchedulerState { batch_entries .as_mut() .unwrap() - .add(id, entry, adapter, blocks, slots, prefix_len); + .add(id, entry, adapter, blocks, slots, prefix_len, chunk_len); + + if should_break { + break 'entry_loop; + } } if batch_entries.is_none() { diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 3952f01d1..2e667ef84 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -219,15 +219,6 @@ def from_pb( if cache_length == prompt_length: assert False, "unreachable" - # TODO(travis): double-check prefix caching - # if PREFIX_CACHING: - # prefix_len = r.prefix_len - # if prefix_len == orig_input_length: - # assert prefix_len > 0 - # prefix_len -= 1 - # else: - # prefix_len = 0 - # `chunk_len` is an optional field in the protobuf # It is only set if the model support chunking if r.HasField("chunk_len"):
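A minimal sketch of the chunking decision the scheduler now applies, with hypothetical numbers (the `Admit` enum and `admit` helper are illustrative names, not part of this patch):

    // Decide how to admit an entry whose remaining prefill is `postfix_len`,
    // given the prefill tokens already admitted and the batch-wide budget.
    enum Admit {
        Whole,      // fits the budget: prefill the whole postfix this step
        Chunk(u32), // chunked prefill: only this many tokens now
        Requeue,    // no budget left at all: push the entry back to the queue
    }

    fn admit(prefill_tokens: u32, postfix_len: u32, budget: u32, chunking: bool) -> Admit {
        if prefill_tokens + postfix_len <= budget {
            return Admit::Whole;
        }
        if !chunking {
            // Without chunked prefill, an over-budget entry must wait.
            return Admit::Requeue;
        }
        match budget.saturating_sub(prefill_tokens) {
            0 => Admit::Requeue,  // cannot prefill even one token
            n => Admit::Chunk(n), // e.g. budget 512, 400 used, 300 needed: Chunk(112)
        }
    }

When a chunk is emitted the scheduler also stops extending the batch (the `should_break` flag above), since the prefill budget is exhausted by construction.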